add: metaretriever

This commit is contained in:
出蛰 2023-12-11 13:45:36 +08:00
parent 7c59c46682
commit c2f193af79
142 changed files with 15046 additions and 0 deletions

metaretriever/.gitignore
View File

@ -0,0 +1,18 @@
# UIE
/data
/pretrain_data
/hf_models
/pd_models
/runs
/models
*.lock
# mac
.DS_Store
# env
/.vscode
/.idea
**/__pycache__
*.pyc
.pytest_cache

metaretriever/README.md
View File

@ -0,0 +1,90 @@
# Universal Information Extraction with Meta-Pretrained Self-Retrieval
This repository contains the code for the ACL 2023 Findings paper "Universal Information Extraction with Meta-Pretrained Self-Retrieval".
## Overview
![](img/MetaRetriever.png)
Universal Information Extraction (Universal IE) aims to solve different extraction tasks in a uniform text-to-structure generation manner. Such a generation procedure tends to struggle when complex information structures must be extracted. Retrieving knowledge from external knowledge bases may help models overcome this problem, but it is impossible to construct a knowledge base suitable for every IE task. Inspired by the fact that a large amount of knowledge is stored in pretrained language models (PLMs) and can be retrieved explicitly, we propose MetaRetriever, which retrieves task-specific knowledge from PLMs to enhance universal IE. Because different IE tasks need different knowledge, we further propose a Meta-Pretraining Algorithm that allows MetaRetriever to quickly reach its maximum task-specific retrieval performance when fine-tuning on downstream IE tasks. Experimental results show that MetaRetriever achieves new state-of-the-art results on 4 IE tasks and 12 datasets under fully-supervised, low-resource, and few-shot scenarios.
## Requirements
**General**
- Python (verified on 3.8)
- CUDA (verified on 10.2)
**Python Packages**
``` bash
conda create -n metaretriever python=3.8
conda install -y pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch
pip install -r requirements.txt
```
**NOTE**: Different versions of packages (such as `pytorch`, `transformers`, etc.) may lead to different results from the paper. However, the trend should still hold no matter what versions of packages you use.
## Usage
### Data Preprocess
``` bash
cd ./dataset_processing/ours
bash download_and_preprocess_data_clean.sh > clean_log.txt
```
### Model Preparation
Please refer to [UIE](https://github.com/universal-ie/UIE) to download the UIE model checkpoint and put it under the `models` directory.
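Optionally, you can verify that the checkpoint loads before launching pretraining. This is only a minimal sketch, assuming the released checkpoint is in the standard Hugging Face seq2seq format expected by the training scripts; the path matches the `-m ./models/uie-base-en` argument used below:
``` python
# Quick sanity check for the downloaded checkpoint (sketch; adjust the path
# if you store the model elsewhere).
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

ckpt = "./models/uie-base-en"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt)
print(type(model).__name__, "loaded;", len(tokenizer), "tokens in the vocabulary")
```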
### Meta-Pretraining
``` bash
bash run_seq2seq_pretrain.bash -v -d 0,1,2,3,4,5,6,7 -b 64 -k 1 --lr 1e-4 --warmup_ratio 0.06 -i relation/ours_clean --spot_noise 0.0 --asoc_noise 0.0 -f spotasoc --map_config config/offset_map/closest_offset_en.yaml -m ./models/uie-base-en --random_prompt --epoch 4 --trainer_type meta_pretrain_v2 --use_prompt_tuning_model False --output_dir output/meta-pretrained-model
```
### Meta-Finetuning
1. Full Supervision Scenario
``` bash
. config/exp_conf/large_model_conf.ini && trainer_type=meta_finetune_v2 model_name=meta-pretrained-model dataset_name=relation/conll04 selected_gpus=0,1,2,3,4,5,6,7 BATCH_SIZE=4 use_prompt_tuning_model=False run_time=1 bash scripts_exp/run_exp.bash
```
2. Few-Shot Scenario
``` bash
. config/exp_conf/base_model_conf_sa_shot.ini && trainer_type=meta_finetune_v2 model_name=meta-pretrained-model dataset_name=relation/conll04 selected_gpus=0,1,2,3,4,5,6,7 BATCH_SIZE=16 use_prompt_tuning_model=False bash scripts_exp/run_exp_shot.bash
```
3. Low-Resource Scenario
``` bash
. config/exp_conf/base_model_conf_sa_ratio.ini && trainer_type=meta_finetune_v2 model_name=meta-pretrained-model dataset_name=relation/conll04 selected_gpus=0,1,2,3,4,5,6,7 BATCH_SIZE=16 use_prompt_tuning_model=False bash scripts_exp/run_exp_ratio.bash
```
## Citation
If this repository helps you, please cite this paper:
```
@inproceedings{cong-etal-2023-universal,
title = "Universal Information Extraction with Meta-Pretrained Self-Retrieval",
author = "Cong, Xin and
Yu, Bowen and
Fang, Mengcheng and
Liu, Tingwen and
Yu, Haiyang and
Hu, Zhongkai and
Huang, Fei and
Li, Yongbin and
Wang, Bin",
editor = "Rogers, Anna and
Boyd-Graber, Jordan and
Okazaki, Naoaki",
booktitle = "Findings of the Association for Computational Linguistics: ACL 2023",
month = jul,
year = "2023",
address = "Toronto, Canada",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2023.findings-acl.251",
doi = "10.18653/v1/2023.findings-acl.251",
}
```

View File

@ -0,0 +1,16 @@
export k8s_gpu_cards=1
export gpu_node=${k8s_gpu_cards}
export eval_steps=0
export epoch=50
export run_time=5
export max_source_length=384
export BATCH_SIZE="16"
export LR_RATE="1e-4 3e-4 5e-4"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.1"
export map_config='config/offset_map/closest_offset_en.yaml'

View File

@ -0,0 +1,21 @@
export job_name=FT_Multi
export k8s_gpu_cards=4
export gpu_node=${k8s_gpu_cards}
export eval_steps=0
export epoch=50
export run_time=3
export max_source_length=384
export BATCH_SIZE="16"
export LR_RATE="3e-4"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.1"
export map_config='config/offset_map/closest_offset_en.yaml'
export job_tags=""
export job_remark="3e-4,0.1"
export eval_match_mode="set"

View File

@ -0,0 +1,23 @@
export k8s_gpu_cards=1
export gpu_node=${k8s_gpu_cards}
export eval_steps=0
export epoch=50
export run_time=5
export max_source_length=384
export job_tags="${dataset_name},${model_name}_rp"
export job_remark="d${dataset_name},m${model_name}"
export BATCH_SIZE="16"
export LR_RATE="5e-4"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.1"
export map_config='config/offset_map/closest_offset_en.yaml'
export job_tags="${dataset_name},${model_name}"
export job_remark="5e-4,0.1"
export start_eval_step=3000

View File

@ -0,0 +1,22 @@
export k8s_gpu_cards=4
export gpu_node=${k8s_gpu_cards}
export eval_steps=0
export epoch=50
export run_time=3
export max_source_length=384
export job_tags="${dataset_name},${model_name}_rp"
export job_remark="d${dataset_name},m${model_name}"
export BATCH_SIZE="8"
export LR_RATE="1e-4"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.1 0.2"
export map_config='config/offset_map/closest_offset_en.yaml'
export job_tags="${dataset_name},${model_name}"
export job_remark="1e-4,0.1,0.2"

View File

@ -0,0 +1,22 @@
export k8s_gpu_cards=4
export gpu_node=${k8s_gpu_cards}
export eval_steps=0
export epoch=50
export run_time=3
export max_source_length=384
export job_tags="${dataset_name},${model_name}_rp"
export job_remark="d${dataset_name},m${model_name}"
export BATCH_SIZE="8"
export LR_RATE="1e-4"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.1 0.2"
export map_config='config/offset_map/closest_offset_en.yaml'
export job_tags="${dataset_name},${model_name}"
export job_remark="1e-4,0.1,0.2"

View File

@ -0,0 +1,22 @@
export job_name=FT_spotasocname
export k8s_gpu_cards=4
export gpu_node=${k8s_gpu_cards}
export eval_steps=2000
export epoch=50
export run_time=3
export max_source_length=256
export job_tags="${dataset_name},${model_name}_rp"
export job_remark="d${dataset_name},m${model_name}"
export BATCH_SIZE="8"
export LR_RATE="1e-4"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.1"
export map_config='config/offset_map/closest_offset_en.yaml'
export start_eval_step=15000
export job_remark="1e-4,0.1"

View File

@ -0,0 +1,22 @@
export k8s_gpu_cards=4
export gpu_node=${k8s_gpu_cards}
export eval_steps=0
export epoch=50
export run_time=3
export max_source_length=384
export job_tags="${dataset_name},${model_name}_rp"
export job_remark="d${dataset_name},m${model_name}"
export BATCH_SIZE="8"
export LR_RATE="1e-4 3e-4"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.2"
export map_config='config/offset_map/closest_offset_en.yaml'
export job_tags="${dataset_name},${model_name}"
export job_remark="1e-4,3e-4,0.2"

View File

@ -0,0 +1,22 @@
export k8s_gpu_cards=4
export gpu_node=${k8s_gpu_cards}
export eval_steps=0
export epoch=50
export run_time=3
export max_source_length=256
export job_tags="${dataset_name},${model_name}_rp"
export job_remark="d${dataset_name},m${model_name}"
export BATCH_SIZE="8"
export LR_RATE="3e-4"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.2"
export map_config='config/offset_map/closest_offset_en.yaml'
export job_tags="${dataset_name},${model_name}"
export job_remark="3e-4,0.2"

View File

@ -0,0 +1,22 @@
export k8s_gpu_cards=4
export gpu_node=${k8s_gpu_cards}
export eval_steps=0
export epoch=50
export run_time=3
export max_source_length=256
export job_tags="${dataset_name},${model_name}_rp"
export job_remark="d${dataset_name},m${model_name}"
export BATCH_SIZE="8"
export LR_RATE="1e-4"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.1"
export map_config='config/offset_map/first_offset_en.yaml'
export job_tags="${dataset_name},${model_name}"
export job_remark="1e-4,0.1"

View File

@ -0,0 +1,22 @@
export k8s_gpu_cards=1
export gpu_node=${k8s_gpu_cards}
export eval_steps=0
export epoch=50
export run_time=3
export max_source_length=256
export job_tags="${dataset_name},${model_name}_rp"
export job_remark="d${dataset_name},m${model_name}"
export BATCH_SIZE="8"
export LR_RATE="5e-5"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.1"
export map_config='config/offset_map/first_offset_en.yaml'
export job_tags="${dataset_name},${model_name}"
export job_remark="5e-5,0.1"

View File

@ -0,0 +1,22 @@
export k8s_gpu_cards=4
export gpu_node=${k8s_gpu_cards}
export eval_steps=0
export epoch=50
export run_time=3
export max_source_length=384
export job_tags="${dataset_name},${model_name}_rp"
export job_remark="d${dataset_name},m${model_name}"
export BATCH_SIZE="8"
export LR_RATE="3e-4"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.2 0.1"
export map_config='config/offset_map/closest_offset_en.yaml'
export job_tags="${dataset_name},${model_name}"
export job_remark="3e-4,0.1,0.2"

View File

@ -0,0 +1,16 @@
export k8s_gpu_cards=1
export gpu_node=${k8s_gpu_cards}
export eval_steps=0
export epoch=50
export run_time=3
export max_source_length=384
export BATCH_SIZE="8"
export LR_RATE="1e-4 3e-5 5e-5"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.1"
export map_config='config/offset_map/closest_offset_en.yaml'

View File

@ -0,0 +1,23 @@
export k8s_gpu_cards=4
export gpu_node=${k8s_gpu_cards}
export eval_steps=0
export epoch=50
export run_time=3
export max_source_length=384
export job_tags="${dataset_name},${model_name}_rp"
export job_remark="d${dataset_name},m${model_name}"
export BATCH_SIZE="8"
export LR_RATE="5e-5"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.2"
export map_config='config/offset_map/closest_offset_en.yaml'
export job_tags="${dataset_name},${model_name}"
export job_remark="5e-5,0.2"
export eval_match_mode="set"

View File

@ -0,0 +1,22 @@
export k8s_gpu_cards=4
export gpu_node=${k8s_gpu_cards}
export eval_steps=0
export epoch=50
export run_time=3
export max_source_length=384
export job_tags="${dataset_name},${model_name}_rp"
export job_remark="d${dataset_name},m${model_name}"
export BATCH_SIZE="8"
export LR_RATE="3e-4"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.2 0.1 0"
export map_config='config/offset_map/closest_offset_en.yaml'
export job_tags="${dataset_name},${model_name}"
export job_remark="3e-4,0,0.1,0.2"

View File

@ -0,0 +1,22 @@
description: uie
environment:
image: docker.cipsup.cn/uie/uie:transformers4.6.2
environment_variables:
- DET_TASK_OWNER=luyaojie
resources:
slots: 4
bind_mounts:
# Data Folder Bind Mount
- host_path: /shared_home/luyaojie/uie/data
container_path: /run/determined/workdir/data
# Pre-trained Model Folder Bind Mount
- host_path: /shared_home/luyaojie/uie/model
container_path: /run/determined/workdir/hf_models
# Output Folder Bind Mount
- host_path: /shared_home/luyaojie/uie/output
container_path: /run/determined/workdir/output

View File

@ -0,0 +1,22 @@
export k8s_gpu_cards=4
export gpu_node=${k8s_gpu_cards}
export eval_steps=0
export epoch=50
export run_time=3
export max_source_length=384
export job_tags="${dataset_name},${model_name}_rp"
export job_remark="d${dataset_name},m${model_name}"
export BATCH_SIZE="16"
export LR_RATE="1e-4 3e-4 5e-4"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0 0.1 0.2"
export map_config='config/offset_map/closest_offset_en.yaml'
export job_tags="${dataset_name},${model_name}"
export job_remark="1e-4,3e-4,5e-5,0,0.1,0.2"

View File

@ -0,0 +1,21 @@
export gpu_node=1
export eval_steps=0
export epoch=50
export run_time=3
export max_source_length=384
export job_tags="${dataset_name},${model_name}_rp"
export job_remark="d${dataset_name},m${model_name}"
export BATCH_SIZE="16"
export LR_RATE="1e-4 3e-4 5e-4"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.1 0 0.2"
export map_config='config/offset_map/closest_offset_en.yaml'
export job_tags="${dataset_name},${model_name}"
export job_remark="1e-4,3e-4,5e-5,0,0.1,0.2"

View File

@ -0,0 +1,15 @@
export gpu_node=1
export eval_steps=0
export epoch=200
export run_time=10
export max_source_length=256
export decoding_format="spotasoc"
export BATCH_SIZE="16"
export LR_RATE="1e-4"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.1"
export map_config='config/offset_map/closest_offset_en.yaml'

View File

@ -0,0 +1,15 @@
export gpu_node=1
export eval_steps=0
export epoch=200
export run_time=10
export max_source_length=256
export decoding_format="spotasoc"
export BATCH_SIZE="16"
export LR_RATE="1e-4"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.1"
export map_config='config/offset_map/closest_offset_en.yaml'

View File

@ -0,0 +1,22 @@
export k8s_gpu_cards=4
export gpu_node=${k8s_gpu_cards}
export eval_steps=0
export epoch=50
export run_time=3
export max_source_length=384
export job_tags="${dataset_name},${model_name}_rp"
export job_remark="d${dataset_name},m${model_name}"
export BATCH_SIZE="8"
export LR_RATE="5e-5 1e-4 3e-4"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.2 0.1 0"
export map_config='config/offset_map/closest_offset_en.yaml'
export job_tags="${dataset_name},${model_name}"
export job_remark="5e-5,1e-4,3e-4,0,0.1,0.2"

View File

@ -0,0 +1,22 @@
export k8s_gpu_cards=4
export gpu_node=${k8s_gpu_cards}
export eval_steps=0
export epoch=50
export run_time=3
export max_source_length=384
export job_tags="${dataset_name},${model_name}_rp"
export job_remark="d${dataset_name},m${model_name}"
export BATCH_SIZE="8"
export LR_RATE="5e-5 1e-4 3e-4"
export WARMUP_PROP="0.06"
export LABEL_SMOOTHING="0"
export NEGATIVE="-1"
export NOISE="0.2 0.1 0"
export map_config='config/offset_map/closest_offset_en.yaml'
export job_tags="${dataset_name},${model_name}"
export job_remark="5e-5,1e-4,3e-4,0,0.1,0.2"

View File

@ -0,0 +1,3 @@
map_strategy: "closest"
de_duplicate: True
span_to_token: "space"

View File

@ -0,0 +1,3 @@
map_strategy: "closest"
de_duplicate: True
span_to_token: "list"

View File

@ -0,0 +1,3 @@
map_strategy: "first"
de_duplicate: True
span_to_token: "space"

View File

@ -0,0 +1,3 @@
map_strategy: "first"
de_duplicate: True
span_to_token: "list"

View File

@ -0,0 +1,3 @@
map_strategy: "longer_first"
de_duplicate: True
span_to_token: "list"

View File

@ -0,0 +1,25 @@
/data
/converted_data
/lightning_logs
/model
/models
/*log
/thirdparty
/tmp
.lock
# mac
.DS_Store
# env
/.vscode
/.idea
**/__pycache__
*.pyc
.pytest_cache
# doc
docs/build
docs/.vscode
docs/source/_build

View File

@ -0,0 +1,3 @@
# Universal IE Dataset Preparation
Please refer to [UIE](https://github.com/universal-ie/UIE).

View File

@ -0,0 +1,15 @@
name: 14lap
path: data/absa/pengb/14lap
data_class: ABSA
split:
train: train_convert.json
val: dev_convert.json
test: test_convert.json
language: en
mapper:
POS: positive
NEG: negative
NEU: neutral
aspect: aspect
opinion: opinion

View File

@ -0,0 +1,15 @@
name: 14res
path: data/absa/pengb/14res
data_class: ABSA
split:
train: train_convert.json
val: dev_convert.json
test: test_convert.json
language: en
mapper:
POS: positive
NEG: negative
NEU: neutral
aspect: aspect
opinion: opinion

View File

@ -0,0 +1,15 @@
name: 15res
path: data/absa/pengb/15res
data_class: ABSA
split:
train: train_convert.json
val: dev_convert.json
test: test_convert.json
language: en
mapper:
POS: positive
NEG: negative
NEU: neutral
aspect: aspect
opinion: opinion

View File

@ -0,0 +1,15 @@
name: 16res
path: data/absa/pengb/16res
data_class: ABSA
split:
train: train_convert.json
val: dev_convert.json
test: test_convert.json
language: en
mapper:
POS: positive
NEG: negative
NEU: neutral
aspect: aspect
opinion: opinion

View File

@ -0,0 +1,13 @@
name: conll03
path: data/conll03/conll03
data_class: CoNLL03
split:
train: eng.train
val: eng.testa
test: eng.testb
language: en
mapper:
LOC: location
ORG: organization
PER: person
MISC: miscellaneous

View File

@ -0,0 +1,17 @@
name: mrc_ace04
path: data/mrc_ner/ace2004
data_class: MRCNER
split:
train: mrc-ner.train
val: mrc-ner.dev
test: mrc-ner.test
language: en
mapper:
FAC: facility
GPE: geographical social political
LOC: location
ORG: organization
PER: person
VEH: vehicle
WEA: weapon

View File

@ -0,0 +1,17 @@
name: mrc_ace05
path: data/mrc_ner/ace2005
data_class: MRCNER
split:
train: mrc-ner.train
val: mrc-ner.dev
test: mrc-ner.test
language: en
mapper:
FAC: facility
GPE: geographical social political
LOC: location
ORG: organization
PER: person
VEH: vehicle
WEA: weapon

View File

@ -0,0 +1,62 @@
name: casie
path: data/casie
data_class: CASIE
split:
train: train.jsonlines
val: dev.jsonlines
test: test.jsonlines
language: en
# https://github.com/Ebiquity/CASIE
# https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9006444
mapper:
File: file
System: system
Person: person
Phishing: phishing
Data: data
Purpose: purpose
Website: website
Organization: organization
Capabilities: capabilities
Malware: malware
Software: software
PII: personally identifiable information
Databreach: databreach
Time: time
Number: number
GPE: geopolitical entity
Ransom: ransom
Money: money
Device: device
Vulnerability: vulnerability
DiscoverVulnerability: discover vulnerability
Patch: patch
PatchVulnerability: patch vulnerability
Version: version
PaymentMethod: payment method
CVE: common vulnerabilities and exposures
Issues-Addressed: issues addressed
Vulnerable_System: vulnerable system
Number-of-Data: number of data
Capabilities: capabilities
Patch: patch
Time: time
Vulnerable_System_Version: vulnerable system version
Releaser: releaser
Damage-Amount: damage amount
Number-of-Victim: number of victim
Tool: tool
Attack-Pattern: attack pattern
Compromised-Data: compromised data
Attacker: attacker
Price: price
Discoverer: discoverer
Patch-Number: patch number
Payment-Method: payment method
Supported_Platform: supported platform
Vulnerability: vulnerability
Place: place
Vulnerable_System_Owner: vulnerable system owner
Victim: victim
Trusted-Entity: trusted entity
Purpose: purpose

View File

@ -0,0 +1,78 @@
name: oneie_ace05_en_event
path: data/oneie/ace05-EN
data_class: OneIEEvent
split:
train: train.oneie.json
val: dev.oneie.json
test: test.oneie.json
language: en
mapper:
FAC: facility
GPE: geographical social political
LOC: location
ORG: organization
PER: person
VEH: vehicle
WEA: weapon
ORG-AFF: organization affiliation
GEN-AFF: general affiliation
PHYS: physical
PART-WHOLE: part whole
PER-SOC: personal social
ART: agent artifact
Personnel:Elect: elect
Life:Be-Born: born
Movement:Transport: transport
Contact:Phone-Write: phone write
Life:Marry: marry
Life:Die: die
Personnel:Start-Position: start position
Life:Injure: injure
Transaction:Transfer-Ownership: transfer ownership
Contact:Meet: meet
Personnel:Nominate: nominate
Conflict:Attack: attack
Business:Start-Org: start organization
Justice:Trial-Hearing: trial hearing
Justice:Convict: convict
Justice:Sentence: sentence
Personnel:End-Position: end position
Life:Divorce: divorce
Justice:Acquit: acquit
Justice:Charge-Indict: charge indict
Transaction:Transfer-Money: transfer money
Justice:Appeal: appeal
Justice:Sue: sue
Business:Merge-Org: merge organization
Business:Declare-Bankruptcy: declare bankruptcy
Justice:Execute: execute
Justice:Arrest-Jail: arrest jail
Justice:Extradite: extradite
Conflict:Demonstrate: demonstrate
Business:End-Org: end organization
Justice:Release-Parole: release parole
Justice:Fine: fine
Justice:Pardon: pardon
Defendant: defendant
Prosecutor: prosecutor
Person: person
Origin: origin
Buyer: buyer
Plaintiff: plaintiff
Victim: victim
Org: organization
Adjudicator: adjudicator
Seller: seller
Beneficiary: beneficiary
Giver: giver
Target: target
Agent: agent
Instrument: instrument
Vehicle: vehicle
Entity: entity
Destination: destination
Recipient: recipient
Attacker: attacker
Artifact: artifact
Place: place

View File

@ -0,0 +1,36 @@
name: NYT
path: data/NYT-multi
data_class: JointER
split:
train: train.json
val: dev.json
test: test.json
language: en
mapper:
ORGANIZATION: organization
LOCATION: location
PERSON: person
/location/location/contains: contains
/people/person/place_of_birth: place of birth
/business/person/company: company
/people/person/place_lived: place lived
/location/administrative_division/country: country
/location/country/administrative_divisions: administrative divisions
/people/person/religion: religion
/people/person/nationality: nationality
/people/person/children: children
/location/country/capital: capital
/business/company/place_founded: place founded
/people/deceased_person/place_of_death: place of death
/business/company/founders: founders
/location/neighborhood/neighborhood_of: neighborhood of
/business/company/advisors: advisors
/people/ethnicity/geographic_distribution: geographic distribution
/sports/sports_team_location/teams: teams
/sports/sports_team/location: location
/business/company_shareholder/major_shareholder_of: major shareholder of
/business/company/major_shareholders: major shareholders
/people/person/ethnicity: ethnicity
/people/ethnicity/people: people
/people/person/profession: profession
/business/company/industry: industry

View File

@ -0,0 +1,23 @@
name: ace05-rel
path: data/spannet_data/relation/ace05
data_class: Spannet
split:
train: train.jsonlines
val: dev.jsonlines
test: test.jsonlines
language: en
mapper:
FAC: facility
GPE: geographical social political
LOC: location
ORG: organization
PER: person
VEH: vehicle
WEA: weapon
ORG-AFF: organization affiliation
GEN-AFF: general affiliation
PHYS: physical
PART-WHOLE: part whole
PER-SOC: personal social
ART: agent artifact

View File

@ -0,0 +1,19 @@
name: conll04
path: data/spannet_data/relation/conll04
data_class: Spannet
split:
train: train.jsonlines
val: dev.jsonlines
test: test.jsonlines
language: en
mapper:
Loc: location
Org: organization
Other: other
Peop: people
OrgBased_In: organization in
Work_For: work for
Located_In: located in
Live_In: live in
Kill: kill

View File

@ -0,0 +1,22 @@
name: scierc
path: data/spannet_data/relation/dygiepp/scierc
data_class: Spannet
split:
train: train.jsonlines
val: dev.jsonlines
test: test.jsonlines
language: en
mapper:
Method: method
Generic: generic
Material: material
Task: task
Metric: metric
OtherScientificTerm: other scientific term
USED-FOR: used for
FEATURE-OF: feature of
COMPARE: compare
EVALUATE-FOR: evaluate for
CONJUNCTION: conjunction
HYPONYM-OF: hyponym of
PART-OF: part of

View File

@ -0,0 +1,10 @@
# Data Statistics Script
``` bash
python scripts/data_statistics.py \
    -data converted_data/text2spotasoc/ \
    -f csv
```
- `-data`: target folder; the script walks all subfolders that contain a `record.schema` and skips any folder whose name contains `shot` or `ratio`
- `-f`: output table format; common choices are `simple` (default), `latex`, and `html`

View File

@ -0,0 +1,54 @@
# Low-Resource Data Sampling
See `run_sample.bash` for the full script, which generates all sampled data automatically.
## Sampling by Data Ratio
``` text
$ python scripts/sample_data_ratio.py -h
usage: sample_data_ratio.py [-h] [-src SRC] [-tgt TGT] [-seed SEED]
optional arguments:
-h, --help show this help message and exit
-src SRC
-tgt TGT
-seed SEED
```
Example:
``` bash
python scripts/sample_data_ratio.py \
-src converted_data/text2spotasoc/entity/mrc_conll03 \
-tgt test_conll03_ratio
```
Samples the specified ratios (0.01, 0.05, 0.1) of `train.json` from every data folder.
## N-Shot Data Sampling
``` text
$ python scripts/sample_data_shot.py -h
usage: sample_data_shot.py [-h] -src SRC -tgt TGT -task {entity,relation,event} [-seed SEED]
optional arguments:
-h, --help show this help message and exit
-src SRC Source Folder Name
-tgt TGT Target Folder Name, n shot sampled
-task {entity,relation,event}
N-Shot Task name
-seed SEED Default is None, no random
```
Example:
``` bash
python scripts/sample_data_shot.py \
-src converted_data/text2spotasoc/entity/mrc_conll03 \
-tgt test_conll03_shot \
-task entity
```
1. Reads `entity.schema` from the data folder
2. Samples 1 / 5 / 10 instances per class to build the final data

View File

@ -0,0 +1,131 @@
import json
import os
import random
import argparse
from collections import OrderedDict
from tqdm import tqdm
import pdb
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--data_dir", default="./final_data5/data_1", type=str)
parser.add_argument("-o", "--output_dir", default="./output_fewshot", type=str)
parser.add_argument("--entity_category_dir", default="entity_category", type=str)
parser.add_argument("--relation_category_dir", default="relation_category", type=str)
parser.add_argument("--step", default=1, type=int)
opt = parser.parse_args()
data_dir = opt.data_dir
output_dir = opt.output_dir
entity_category_dir = opt.entity_category_dir
relation_category_dir = opt.relation_category_dir
step = opt.step
all_file = os.path.join(output_dir, "all.json")
entity_stat_file = os.path.join(output_dir, "entity_stat.json")
relation_stat_file = os.path.join(output_dir, "relation_stat.json")
target_entity_category_dir = os.path.join(output_dir, entity_category_dir)
target_relation_category_dir = os.path.join(output_dir, relation_category_dir)
if not os.path.exists(target_entity_category_dir):
os.makedirs(target_entity_category_dir)
if not os.path.exists(target_relation_category_dir):
os.makedirs(target_relation_category_dir)
entity_instance_dict_file = os.path.join(output_dir, "entity_instance_dict.json")
relation_instance_dict_file = os.path.join(output_dir, "relation_instance_dict.json")
metainfo_file = os.path.join(output_dir, "metainfo.json")
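# This script splits all.json by type: every entity/relation type gets its own
# <type_id>.json file under the category directories, and metainfo.json stores
# the mapping between type names and type ids.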
# %% load all instance line
print("Reading all data...")
instance_list = []
with open(all_file) as all:
for idx, line in tqdm(enumerate(all)):
if len(line) == 0:
continue
instance_list.append(line)
print("All data read.")
# %% rearrange instance by class
print("Stat entity type and relation type...")
entity_type_instance_dict = {}
relation_type_instance_dict = {}
for line in tqdm(instance_list):
if len(line) == 0:
continue
record = json.loads(line)
entity_type_list = record["spot"]
relation_type_list = record["asoc"]
for entity_type in entity_type_list:
if entity_type not in entity_type_instance_dict:
entity_type_instance_dict[entity_type] = {
"type_id": len(entity_type_instance_dict),
"instance_list": []
}
entity_type_instance_dict[entity_type]["instance_list"].append(line)
for relation_type in relation_type_list:
if relation_type not in relation_type_instance_dict:
relation_type_instance_dict[relation_type] = {
"type_id": len(relation_type_instance_dict),
"instance_list": []
}
relation_type_instance_dict[relation_type]["instance_list"].append(line)
print("Stat over.")
# %% save data by category
metainfo = {
"entity": [],
"relation": [],
}
print("Saving entity by category...")
for entity_type, data in tqdm(entity_type_instance_dict.items()):
type_id = data["type_id"]
instance_list = data["instance_list"]
current_metainfo = {
"entity_type": entity_type,
"type_id": type_id
}
metainfo["entity"].append(current_metainfo)
entity_type_file = os.path.join(target_entity_category_dir, str(type_id)+".json")
with open(entity_type_file, "w") as f:
for instance in instance_list:
f.write(instance)
print("Entity saved.")
print("Saving relation by category...")
for relation_type, data in tqdm(relation_type_instance_dict.items()):
type_id = data["type_id"]
instance_list = data["instance_list"]
current_metainfo = {
"relation_type": relation_type,
"type_id": type_id
}
metainfo["relation"].append(current_metainfo)
relation_type_file = os.path.join(target_relation_category_dir, str(type_id)+".json")
with open(relation_type_file, "w") as f:
for instance in instance_list:
f.write(instance)
print("Relation saved.")
print("Saving metainfo...")
with open(metainfo_file, "w") as f:
json.dump(metainfo, f)
print("Metainfo saved.")

View File

@ -0,0 +1,38 @@
import os
import json
import argparse
from tqdm import tqdm
in_files = [
'original_train.json',
]
out_files = [
'train.json',
]
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dir", default="output", type=str)
opt = parser.parse_args()
dir_path = opt.dir
for in_file, out_file in zip(in_files, out_files):
in_file_path = os.path.join(dir_path, in_file)
out_file_path = os.path.join(dir_path, out_file)
print(f"{in_file_path} -> {out_file_path}")
fin = open(in_file_path)
fout = open(out_file_path,'w')
for line in tqdm(fin):
obj = json.loads(line)
flag = 0
tmp_relation = []
for tmp_obj in obj['relation']:
tmp = json.dumps(tmp_obj)
tmp_relation.append(tmp)
obj['relation'] = tmp_relation
fout.write(json.dumps(obj)+"\n")
fin.close()
fout.close()

View File

@ -0,0 +1,29 @@
if [ ! -e out_clean.zip ];
then
echo "downloading out_clean ..."
wget -c http://url/to/dataset/out_clean.zip
else
echo "out_clean has been downloaded."
fi
if [ ! -d out_clean ];
then
echo "unziping out_clean"
unzip out_clean.zip
else
echo "out_clean has been unzipped"
fi
# preprocess
python explore.py --data_dir ./out_clean --output_dir ./output --max_instance_num -1
# fewshot sampling
python stat_category.py --source_dir ./output --output_dir ./output
python partition.py --source_dir ./output --output_dir ./output
python match.py --source_dir ./output --output_dir ./output --step 100
python rearrange_dataset.py --source_dir ./output --output_dir ./output
# generate dataset
python noise.py --output_dir ./output --all_file rearrange_all.json
python change_data_format_for_relation.py -d ./output
ln -s ../../../ours/output ../converted_data/text2spotasoc/relation/ours

View File

@ -0,0 +1,323 @@
import json
import os
import random
import argparse
from tqdm import tqdm
from nltk.tokenize import WordPunctTokenizer
word_tokenizer = WordPunctTokenizer()
import numpy as np
np.set_printoptions(suppress=True)
import pdb
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--data_dir", default="./final_data5/data_1", type=str)
parser.add_argument("-o", "--output_dir", default="./output", type=str)
parser.add_argument("-n", "--max_instance_num", default=-1, type=int)
opt = parser.parse_args()
data_dir = opt.data_dir
output_dir = opt.output_dir
max_instance_num = opt.max_instance_num
entity_schema_file = os.path.join(output_dir, "entity.schema")
relation_schema_file = os.path.join(output_dir, "relation.schema")
event_schema_file = os.path.join(output_dir, "event.schema")
record_schema_file = os.path.join(output_dir, "record.schema")
all_file = os.path.join(output_dir, "all.json")
train_file = os.path.join(output_dir, "original_train.json")
dev_file = os.path.join(output_dir, "original_val.json")
test_file = os.path.join(output_dir, "original_test.json")
ENTITY_SEARCH_RANGE = 0
ALL_ENTITY_CNT = 0
NOMATCH_ENTITY_CNT = 0
NON_OFFSET_ENTITY_CNT = 0
def word_tokenize(text):
return word_tokenizer.tokenize(text)
def record2instance(record):
instance = {
"text": None,
"tokens": None,
"record": None,
"entity": None,
"relation": None,
"event": [],
"spot": None,
"asoc": None,
"spot_asoc": None,
}
# create text field
text = record["sentence_value"]
instance["text"] = text
# create tokens field
tokens = word_tokenize(text)
text_length_list.append(len(tokens))
instance["tokens"] = tokens
# create entity field
entities = record["sentence_entities"]
instance_entity_list = []
for entity in entities:
entity_uri = entity["uri"]
entity_mention = entity["surfaceform"]
entity_type = entity["tag"]
entity_offset = entity["boundaries_token"]
if entity_type == "#dateTime":
entity_type = "date time"
elif entity_type == "#decimal":
entity_type = "decimal"
elif entity_type == "":
entity_type = "other"
if entity_mention == "":
continue
try:
start_index, end_index = entity_offset[0], entity_offset[-1]
except:
global NON_OFFSET_ENTITY_CNT
NON_OFFSET_ENTITY_CNT += 1
return None
current_mention = " ".join(tokens[start_index:end_index+1])
original_mention = " ".join(word_tokenize(entity_mention))
if current_mention != original_mention:
global NOMATCH_ENTITY_CNT
NOMATCH_ENTITY_CNT += 1
global ALL_ENTITY_CNT
ALL_ENTITY_CNT += 1
entity_offset = list(range(start_index, end_index+1))
instance_entity = {
"type": entity_type,
"offset": entity_offset,
"text": entity_mention,
"uri": entity_uri
}
instance_entity_list.append(instance_entity)
instance["entity"] = instance_entity_list
# create spot field
instance_entity_type_list = [i["type"] for i in instance_entity_list]
instance["spot"] = list(set(instance_entity_type_list))
entity_type_list.extend(instance_entity_type_list)
# create relation field
triples = record["sentence_triples"]
instance_relation_list = []
for triple in triples:
subj = triple["subject"]
obj = triple["object"]
predicate = triple["predicate"]
relation_type = predicate["surfaceform"]
try:
head_entity = [i for i in instance_entity_list if i["uri"] == subj["uri"]][0]
except IndexError:
continue
try:
tail_entity = [i for i in instance_entity_list if i["uri"] == obj["uri"]][0]
except IndexError:
continue
head_entity_type = head_entity["type"]
tail_entity_type = tail_entity["type"]
triple_type = (head_entity_type, relation_type, tail_entity_type)
triple_type_list.append(triple_type)
instance_relation = {
"type": relation_type,
"args": [
head_entity,
tail_entity
]
}
instance_relation_list.append(instance_relation)
instance["relation"] = instance_relation_list
# create asoc field
instance_asoc_list = [i["type"] for i in instance_relation_list]
instance["asoc"] = list(set(instance_asoc_list))
relation_list.extend(instance_asoc_list)
# create spot_asoc field
instance_spot_asoc_list = []
for entity in instance_entity_list:
instance_spot_asoc = {
"span": entity["text"],
"label": entity["type"],
"asoc": []
}
for triple in instance_relation_list:
if triple["args"][0]["uri"] == entity["uri"]:
asoc_record = [triple["type"], triple["args"][1]["text"]]
instance_spot_asoc["asoc"].append(asoc_record)
instance_spot_asoc_list.append(instance_spot_asoc)
instance["spot_asoc"] = instance_spot_asoc_list
# create record field
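# The record string is the UIE-style linearized extraction structure:
# <extra_id_0> opens a group, <extra_id_1> closes it, and <extra_id_5>
# separates a type name from its span, e.g.
# "<extra_id_0> <extra_id_0> person <extra_id_5> Steve <extra_id_0> work for <extra_id_5> Apple <extra_id_1> <extra_id_1> <extra_id_1>"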
instance_record = "<extra_id_0> "
for instance_spot_asoc in instance_spot_asoc_list:
instance_record += "<extra_id_0> "
instance_record += instance_spot_asoc["label"] + " "
instance_record += "<extra_id_5> "
instance_record += instance_spot_asoc["span"] + " "
if len(instance_spot_asoc["asoc"]) != 0:
for asoc in instance_spot_asoc["asoc"]:
instance_record += "<extra_id_0> "
instance_record += asoc[0] + " "
instance_record += "<extra_id_5> "
instance_record += asoc[1] + " "
instance_record += "<extra_id_1> "
instance_record += "<extra_id_1> "
instance_record += "<extra_id_1>"
instance["record"] = instance_record
return instance
# %% read data
file_list = os.listdir(data_dir)
text_length_list = []
record_cnt = 0
relation_list = []
entity_type_list = []
triple_type_list = []
json_str_length_list = []
instance_num = 0
with open(all_file, "w") as all:
for file_name in tqdm(file_list):
file_path = os.path.join(data_dir, file_name)
with open(file_path) as f:
for line in f:
if len(line) == 0:
continue
record = json.loads(line)
record_cnt += 1
instance = record2instance(record)
if instance is None:
continue
json_str = json.dumps(instance)
json_str_length_list.append(len(json_str))
all.write(json_str + "\n")
instance_num += 1
if max_instance_num != -1 and instance_num == max_instance_num:
break
if max_instance_num != -1 and instance_num == max_instance_num:
break
print(f"Total number of all entities: {ALL_ENTITY_CNT}")
print(f"Those entities non-match raw text: {NOMATCH_ENTITY_CNT}")
print(f"Non-match rate: {NOMATCH_ENTITY_CNT / ALL_ENTITY_CNT}")
print(f"Total number of all non-offset entities: {NON_OFFSET_ENTITY_CNT}")
print(f"Non-offset rate: {NON_OFFSET_ENTITY_CNT / ALL_ENTITY_CNT}")
print(f"Total record: {record_cnt}")
print(f"Total instance: {instance_num}")
print()
# %% stat of text length
max_len = max(text_length_list)
min_len = min(text_length_list)
print(f"Max length: {max_len}, Min length: {min_len}")
bins = 20
hist, bin_edges = np.histogram(text_length_list, bins=bins, density=False)
print("Hist:", hist)
print("Edge:", bin_edges)
satisfied_length_cnt = len([i for i in text_length_list if i <= 512])
print(f"Satisfied length cnt: {satisfied_length_cnt} ({satisfied_length_cnt/len(text_length_list)})")
print()
# %% stat of json string length
max_json_len = max(json_str_length_list)
min_json_len = min(json_str_length_list)
print(f"Max json length: {max_json_len}, Min json length: {min_json_len}")
bins = 20
json_hist, json_bin_edges = np.histogram(json_str_length_list, bins=bins, density=False)
print("Hist:", json_hist)
print("Edge:", json_bin_edges)
satisfied_json_length_cnt = len([i for i in json_str_length_list if i <= 4096])
print(f"Satisfied json length cnt: {satisfied_json_length_cnt} ({satisfied_json_length_cnt/len(json_str_length_list)})")
print()
# %% create schema
entity_type_list = list(set(entity_type_list))
relation_list = list(set(relation_list))
print(f"Num of entity type: {len(entity_type_list)}")
print(f"Num of relation type: {len(relation_list)}")
record_type_list = {}
for head_entity_type, relation_type, tail_entity_type in triple_type_list:
    if record_type_list.get(head_entity_type) is None:
        record_type_list[head_entity_type] = []
    record_type_list[head_entity_type].append(relation_type)
for head_entity_type, record_relation_list in record_type_list.items():
record_type_list[head_entity_type] = list(set(record_relation_list))
with open(entity_schema_file, "w") as f:
f.write(json.dumps(entity_type_list) + "\n")
f.write(json.dumps([]) + "\n")
f.write(json.dumps({}) + "\n")
print("entity.schema saved")
with open(relation_schema_file, "w") as f:
f.write(json.dumps(relation_list) + "\n")
f.write(json.dumps(entity_type_list) + "\n")
f.write(json.dumps({i: [] for i in relation_list}) + "\n")
print("relation.schema saved")
with open(event_schema_file, "w") as f:
f.write(json.dumps([]) + "\n")
f.write(json.dumps([]) + "\n")
f.write(json.dumps({}) + "\n")
print("event.schema saved")
with open(record_schema_file, "w") as f:
f.write(json.dumps(entity_type_list) + "\n")
f.write(json.dumps(relation_list) + "\n")
f.write(json.dumps(record_type_list) + "\n")
print("record.schema saved")
print()

View File

@ -0,0 +1,117 @@
import os
import json
import math
import time
import argparse
from tqdm import tqdm
import networkx as nx
import pdb
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--source_dir", default="./", type=str)
parser.add_argument("-o", "--output_dir", default="./", type=str)
parser.add_argument("--step", default=100, type=int)
opt = parser.parse_args()
source_dir = opt.source_dir
output_dir = opt.output_dir
step = opt.step
instance_label_file = os.path.join(output_dir, "instance_label.json")
partition_file = os.path.join(output_dir, "partition.json")
match_group_file = os.path.join(output_dir, "match_group.json")
# %%
print("Loading partition...")
partition = []
with open(partition_file) as f:
for line in f:
partition.append(json.loads(line))
# %%
print("Loading instance label list...")
instance_label_list = []
with open(instance_label_file) as f:
for line in tqdm(f):
instance_label = json.loads(line)
instance_label_list.append(instance_label)
instance_label_dict = {i: j for i, j in instance_label_list}
total = len(instance_label_dict)
# %%
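# score() measures label-set overlap between two instances, normalized by each
# side's own label count (plus a 1/|labels| coefficient that favors smaller
# label sets); the returned flag records which direction scored higher.
# Within every partition, instances are paired by maximum-weight matching on
# these scores, so matched pairs share as many spot/asoc types as possible.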
def score(x_label, y_label, add_coef=True):
x_label = set(x_label)
y_label = set(y_label)
y2x_score = len(x_label & y_label) / len(x_label)
if add_coef:
y2x_score += 1 / len(y_label)
x2y_score = len(x_label & y_label) / len(y_label)
if add_coef:
x2y_score += + 1 / len(x_label)
if x2y_score > y2x_score:
final_score = x2y_score
flag = True
else:
final_score = y2x_score
flag = False
return final_score, flag
# %%
print("Matching...")
match_group = []
for curr_partition in tqdm(partition):
type_name, category, instance_list = curr_partition
if len(instance_list) == 1:
match_group.append((instance_list[0], instance_list[0], 1.0))
else:
# pdb.set_trace()
total_epoch = math.ceil(len(instance_list) / step)
for epoch in tqdm(range(total_epoch), leave=False):
batch = instance_list[epoch*step:(epoch+1)*step]
edges = []
for i in range(len(batch)):
for j in range(i+1, len(batch)):
x_id, y_id = batch[i], batch[j]
x_label = instance_label_dict[x_id]
y_label = instance_label_dict[y_id]
edge_weight, _ = score(x_label, y_label)
edges.append((x_id, y_id, edge_weight))
G = nx.Graph()
G.add_weighted_edges_from(edges)
match_result = nx.max_weight_matching(G)
for edge in match_result:
x_id, y_id = edge
x_label = instance_label_dict[x_id]
y_label = instance_label_dict[y_id]
match_score, flag = score(x_label, y_label, add_coef=False)
if flag:
match_group.append((x_id, y_id, match_score))
else:
match_group.append((y_id, x_id, match_score))
scores = [i[-1] for i in match_group]
average_score = sum(scores) / len(scores)
print(f"Average match score: {average_score}")
print("Saving match group...")
with open(match_group_file, "w") as f:
for record in match_group:
f.write(json.dumps(record)+"\n")

View File

@ -0,0 +1,242 @@
import json
import os
import random
import argparse
from tqdm import tqdm
from copy import deepcopy
import numpy as np
import pdb
seed = 0
random.seed(seed)
np.random.seed(seed)
parser = argparse.ArgumentParser()
parser.add_argument("-o", "--output_dir", default="./output", type=str)
parser.add_argument("-a", "--all_file", default="all.json", type=str)
parser.add_argument("-n", "--noise", default=4, type=int)
opt = parser.parse_args()
output_dir = opt.output_dir
all_file = opt.all_file
noise = opt.noise
original_all_file = os.path.join(output_dir, all_file)
noised_all_file = os.path.join(output_dir, "noised_all.json")
train_file = os.path.join(output_dir, "original_train.json")
dev_file = os.path.join(output_dir, "original_val.json")
test_file = os.path.join(output_dir, "original_test.json")
# %% noise function
NOISE_NUM = noise
THRESHOLD = 0.8
TRIPLE_THRESHOLD = [0.6, 0.8]
DECAY_COEF = 0.8
NOISE_OFFSET_THRESHOLD = 3
NOISE_OFFSET_RANGE = list(range(NOISE_OFFSET_THRESHOLD))
NOISE_OFFSET_WEIGHT = np.exp(- DECAY_COEF * np.array(NOISE_OFFSET_RANGE))
NOISE_OFFSET_WEIGHT = NOISE_OFFSET_WEIGHT / NOISE_OFFSET_WEIGHT.sum()
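# Each clean record is perturbed NOISE_NUM times to produce noised variants:
# entity boundaries are jittered with exponentially decaying offsets
# (NOISE_OFFSET_WEIGHT), entity mentions and entity/relation types are randomly
# replaced with probability 1 - THRESHOLD, and each triple is kept, duplicated
# with a random tail entity, or dropped according to TRIPLE_THRESHOLD.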
def noise_entity_type(entity_list):
entity_type_list = []
for entity in entity_list:
entity_type_list.append(entity["type"])
entity_type_list = list(set(entity_type_list))
noised_entity_list = []
for entity in entity_list:
noised_entity = deepcopy(entity)
if np.random.rand() > THRESHOLD:
noised_entity_type = random.choice(entity_type_list)
noised_entity["type"] = noised_entity_type
noised_entity_list.append(noised_entity)
return noised_entity_list
def noise_entity_offset(entity_list, tokens):
noised_entity_list = []
for entity in entity_list:
noised_entity = deepcopy(entity)
entity_offset = noised_entity["offset"]
start_index, end_index = entity_offset[0], entity_offset[-1]
start_noise = np.random.choice(NOISE_OFFSET_RANGE, p=NOISE_OFFSET_WEIGHT)
end_noise = np.random.choice(NOISE_OFFSET_RANGE, p=NOISE_OFFSET_WEIGHT)
noised_start_index = max(start_index-start_noise, 0)
noised_end_index = min(end_index+end_noise, len(tokens)-1)
noised_entity_offset = list(range(noised_start_index, noised_end_index+1))
noised_entity_mention = " ".join(tokens[noised_start_index:noised_end_index+1])
noised_entity["offset"] = noised_entity_offset
noised_entity["text"] = noised_entity_mention
noised_entity_list.append(noised_entity)
return noised_entity_list
def noise_entity_with_other_entity(entity_list):
type_entity_mapping = {}
for entity in entity_list:
entity_type = entity["type"]
if entity_type not in type_entity_mapping:
type_entity_mapping[entity_type] = []
type_entity_mapping[entity_type].append(entity)
noised_entity_list = []
for entity in entity_list:
noised_entity = deepcopy(entity)
if np.random.rand() > THRESHOLD:
entity_type = noised_entity["type"]
other_entity = random.choice(type_entity_mapping[entity_type])
noised_entity["text"] = other_entity["text"]
noised_entity["offset"] = other_entity["offset"]
noised_entity_list.append(noised_entity)
return noised_entity_list
def noise_relation_type(triple_list):
relation_type_list = []
for triple in triple_list:
relation_type_list.append(triple["type"])
relation_type_list = list(set(relation_type_list))
noised_triple_list = []
for triple in triple_list:
noised_triple = deepcopy(triple)
if np.random.rand() > THRESHOLD:
noised_relation_type = random.choice(relation_type_list)
noised_triple["type"] = noised_relation_type
noised_triple_list.append(noised_triple)
return noised_triple_list
def noise_triple_num(triple_list, entity_list):
noised_triple_list = []
for triple in triple_list:
p = np.random.rand()
if p < TRIPLE_THRESHOLD[0]: # do nothing
noised_triple_list.append(triple)
elif p < TRIPLE_THRESHOLD[1]: # add noised triple
noised_triple_list.append(triple)
noised_triple = deepcopy(triple)
replaced_tail = random.choice(entity_list)
noised_triple["args"][1] = replaced_tail
noised_triple_list.append(noised_triple)
else: # remove triple
pass
return noised_triple_list
# %% utils
def build_entity_dict(entity_list):
entity_dict = {}
for entity in entity_list:
entity_uri = entity["uri"]
entity_dict[entity_uri] = entity
return entity_dict
def update_relation_triple_by_noised_entity(triple_list, noised_entity_dict):
noised_triple_list = []
for triple in triple_list:
noised_triple = deepcopy(triple)
head, tail = noised_triple["args"]
noised_head, noised_tail = noised_entity_dict[head["uri"]], noised_entity_dict[tail["uri"]]
noised_triple["args"] = [noised_head, noised_tail]
noised_triple_list.append(noised_triple)
return noised_triple_list
def create_spot_asoc_field(instance_entity_list, instance_triple_list):
instance_spot_asoc_list = []
for entity in instance_entity_list:
instance_spot_asoc = {
"span": entity["text"],
"label": entity["type"],
"asoc": []
}
for triple in instance_triple_list:
if triple["args"][0]["uri"] == entity["uri"]:
asoc_record = [triple["type"], triple["args"][1]["text"]]
instance_spot_asoc["asoc"].append(asoc_record)
instance_spot_asoc_list.append(instance_spot_asoc)
return instance_spot_asoc_list
def create_record_field(instance_spot_asoc_list):
instance_record = "<extra_id_0> "
for instance_spot_asoc in instance_spot_asoc_list:
instance_record += "<extra_id_0> "
instance_record += instance_spot_asoc["label"] + " "
instance_record += "<extra_id_5> "
instance_record += instance_spot_asoc["span"] + " "
if len(instance_spot_asoc["asoc"]) != 0:
for asoc in instance_spot_asoc["asoc"]:
instance_record += "<extra_id_0> "
instance_record += asoc[0] + " "
instance_record += "<extra_id_5> "
instance_record += asoc[1] + " "
instance_record += "<extra_id_1> "
instance_record += "<extra_id_1> "
instance_record += "<extra_id_1>"
return instance_record
# %% create noised record for all
with open(original_all_file) as src, open(noised_all_file, "w") as tgt:
for line in tqdm(src):
instance = json.loads(line)
tokens = instance["tokens"]
entity_list = instance["entity"]
triple_list = instance["relation"]
spot_asoc_list = instance["spot_asoc"]
record = instance["record"]
noised_record_list = []
for _ in range(NOISE_NUM):
# noise entity
noised_entity_list = noise_entity_offset(entity_list, tokens)
noised_entity_list = noise_entity_with_other_entity(noised_entity_list)
noised_entity_list = noise_entity_type(noised_entity_list)
noised_entity_dict = build_entity_dict(noised_entity_list)
# noise triple
noised_triple_list = update_relation_triple_by_noised_entity(triple_list, noised_entity_dict)
noised_triple_list = noise_relation_type(noised_triple_list)
noised_triple_list = noise_triple_num(noised_triple_list, noised_entity_list)
# create noised record
noised_spot_asoc_list = create_spot_asoc_field(noised_entity_list, noised_triple_list)
noised_record = create_record_field(noised_spot_asoc_list)
noised_record_list.append(noised_record)
# remove the uri field before writing
for entity in entity_list:
del entity["uri"]
instance["noised_record"] = noised_record_list
json_str = json.dumps(instance)
tgt.write(json_str + "\n")
# %% create train/dev/test data
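# Every instance goes into original_train.json; original_val.json and
# original_test.json are only created here and left empty.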
with open(noised_all_file) as all, open(train_file, "w") as train, open(dev_file, "w") as dev, open(test_file, "w") as test:
for i, line in tqdm(enumerate(all)):
train.write(line)
print("train/dev/test saved.")

View File

@ -0,0 +1,96 @@
import json
import os
import random
import argparse
from collections import OrderedDict
from tqdm import tqdm
import pdb
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--source_dir", default="./", type=str)
parser.add_argument("-o", "--output_dir", default="./", type=str)
opt = parser.parse_args()
source_dir = opt.source_dir
output_dir = opt.output_dir
all_file = os.path.join(source_dir, "all.json")
entity_stat_file = os.path.join(output_dir, "entity_stat.json")
relation_stat_file = os.path.join(output_dir, "relation_stat.json")
partition_file = os.path.join(output_dir, "partition.json")
entity_stat_list = []
relation_stat_list = []
with open(entity_stat_file) as f:
for line in f:
category = json.loads(line)
category[1]["type"] = "entity"
entity_stat_list.append(category)
with open(relation_stat_file) as f:
for line in f:
category = json.loads(line)
category[1]["type"] = "relation"
relation_stat_list.append(category)
all_stat_list = entity_stat_list + relation_stat_list
all_stat_list = sorted(all_stat_list, key=lambda x: len(x[1]["instance_id_list"]))
instance_type_dict = {}
for curr_type, curr_record in tqdm(all_stat_list):
instance_id_list = curr_record["instance_id_list"]
for instance_id in instance_id_list:
if instance_id not in instance_type_dict:
instance_type_dict[instance_id] = set()
instance_type_dict[instance_id].add(curr_type)
def get_visited_type(instance_id_list, instance_type_dict):
visited_type = set()
for i, instance_id in enumerate(instance_id_list):
if i == 0:
visited_type |= instance_type_dict[instance_id]
else:
visited_type &= instance_type_dict[instance_id]
return visited_type
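# Greedy partition, rarest category first: each category claims only the
# instances not taken by an earlier (rarer) category, so every instance lands
# in exactly one partition. A category whose instances were all claimed is
# skipped if its type is already covered, otherwise its full instance list is
# reused and counted as duplicated.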
print("Begining partition...")
visited_instance = set()
visited_type = set()
partition = []
empty_set_cnt = 0
duplicated_instance_cnt = 0
for curr_type, curr_record in tqdm(all_stat_list):
category_type = curr_record["type"]
instance_id_list = curr_record["instance_id_list"]
instance_id_set = set(instance_id_list)
instance_id_set = instance_id_set - visited_instance
curr_visited_type = get_visited_type(instance_id_list, instance_type_dict)
if len(instance_id_set) == 0:
if curr_type in visited_type:
continue
else:
non_visited_type = curr_visited_type - visited_type
instance_id_set = set(instance_id_list)
empty_set_cnt += 1
duplicated_instance_cnt += len(instance_id_list)
curr_partition = [curr_type, category_type, list(instance_id_set)]
partition.append(curr_partition)
visited_instance.update(instance_id_set)
visited_type.update(curr_visited_type)
print(f"Empty set rate: {empty_set_cnt / len(all_stat_list)}")
print(f"Duplication rate: {duplicated_instance_cnt / len(instance_type_dict)}")
print("Saving partition...")
with open(partition_file, "w") as f:
for record in partition:
f.write(json.dumps(record)+"\n")

View File

@ -0,0 +1,50 @@
import os
import json
import math
import time
import random
import argparse
from tqdm import tqdm
import pdb
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--source_dir", default="./", type=str)
parser.add_argument("-o", "--output_dir", default="./", type=str)
opt = parser.parse_args()
source_dir = opt.source_dir
output_dir = opt.output_dir
all_file = os.path.join(source_dir, "all.json")
match_group_file = os.path.join(output_dir, "match_group.json")
rearrange_all_file = os.path.join(output_dir, "rearrange_all.json")
# %%
print("Loading match group...")
match_group = []
with open(match_group_file) as f:
for line in tqdm(f):
match_group.append(json.loads(line))
# %%
print("Loading instance...")
instance_list = []
with open(all_file) as f:
for line in tqdm(f):
instance_list.append(line)
# %%
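# match_group holds (support_id, query_id, score) triples produced by match.py;
# writing each pair on consecutive lines makes every two adjacent lines of
# rearrange_all.json one matched (support, query) pair.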
print("Rearrange dataset...")
with open(rearrange_all_file, "w") as f:
for edge in tqdm(match_group):
support_id, query_id, _ = edge
support = instance_list[support_id]
query = instance_list[query_id]
f.write(support)
f.write(query)

View File

@ -0,0 +1,150 @@
import json
import os
import random
import argparse
from tqdm import tqdm
import pdb
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--data_dir", default="./final_data5/data_1", type=str)
parser.add_argument("-o", "--output_dir", default="./output_fewshot", type=str)
parser.add_argument("--entity_category_dir", default="entity_category", type=str)
parser.add_argument("--relation_category_dir", default="relation_category", type=str)
parser.add_argument("--task_num", default=10000, type=int)
parser.add_argument("--N", default=5, type=int)
parser.add_argument("--K", default=5, type=int)
parser.add_argument("--Q", default=5, type=int)
opt = parser.parse_args()
data_dir = opt.data_dir
output_dir = opt.output_dir
entity_category_dir = opt.entity_category_dir
relation_category_dir = opt.relation_category_dir
task_num = opt.task_num
N = opt.N
K = opt.K
Q = opt.Q
target_entity_category_dir = os.path.join(output_dir, entity_category_dir)
target_relation_category_dir = os.path.join(output_dir, relation_category_dir)
metainfo_file = os.path.join(output_dir, "metainfo.json")
task_file = os.path.join(output_dir, "sampled_task.json")
# %% read instance dict
print("Reading metainfo...")
with open(metainfo_file) as f:
metainfo = json.load(f)
print("Metainfo read.")
print("Loading entity instance dict...")
entity_type_instance_dict = {}
for current_metainfo in tqdm(metainfo["entity"]):
entity_type = current_metainfo["entity_type"]
type_id = current_metainfo["type_id"]
entity_type_file = os.path.join(target_entity_category_dir, str(type_id)+".json")
instance_list = []
with open(entity_type_file) as f:
for line in f:
instance_list.append(line)
entity_type_instance_dict[entity_type] = instance_list
entity_type_list = list(entity_type_instance_dict.keys())
print("Entity instance dict loaded")
print("Loading relation instance dict...")
relation_type_instance_dict = {}
for current_metainfo in tqdm(metainfo["relation"]):
relation_type = current_metainfo["relation_type"]
type_id = current_metainfo["type_id"]
relation_type_file = os.path.join(target_relation_category_dir, str(type_id)+".json")
instance_list = []
with open(relation_type_file) as f:
for line in f:
instance_list.append(line)
relation_type_instance_dict[relation_type] = instance_list
relation_type_list = list(relation_type_instance_dict.keys())
print("Relation instance dict loaded.")
# %% n-way-k-shot sampling
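# Alternate between entity tasks and relation tasks. For each task, sample N
# types, then K support and Q query instances per type; when a type has fewer
# than K+Q instances, sampling falls back to random.choices (with replacement).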
print("Sampling N-Way K-Shot task...")
task_list = []
for i in tqdm(range(task_num//2)):
# sample entity task
target_entity_type_list = random.sample(entity_type_list, N)
task = {
"target_entity_type_list": target_entity_type_list,
"target_relation_type_list": [],
"N": N,
"K": K,
"Q": Q,
"support": None,
"query": None
}
support = []
query = []
for entity_type in target_entity_type_list:
instance_candidates = entity_type_instance_dict[entity_type]
if len(instance_candidates) > K+Q:
sampled_instance_list = random.sample(instance_candidates, K+Q)
else:
sampled_instance_list = random.choices(instance_candidates, k=K+Q)
support.extend(sampled_instance_list[:K])
query.extend(sampled_instance_list[K:])
task["support"] = support
task["query"] = query
task_list.append(task)
# sample relation task
target_relation_type_list = random.sample(relation_type_list, N)
task = {
"target_entity_type_list": [],
"target_relation_type_list": target_relation_type_list,
"N": N,
"K": K,
"Q": Q,
"support": None,
"query": None
}
support = []
query = []
for relation_type in target_relation_type_list:
instance_candidates = relation_type_instance_dict[relation_type]
if len(instance_candidates) > K+Q:
sampled_instance_list = random.sample(instance_candidates, K+Q)
else:
sampled_instance_list = random.choices(instance_candidates, k=K+Q)
support.extend(sampled_instance_list[:K])
query.extend(sampled_instance_list[K:])
task["support"] = support
task["query"] = query
task_list.append(task)
print("Sampling over.")
print("Saving task...")
with open(task_file, "w") as f:
for task in tqdm(task_list):
f.write(json.dumps(task) + "\n")
print("Task saved.")

View File

@@ -0,0 +1,64 @@
import json
import os
import random
import argparse
from collections import OrderedDict
from tqdm import tqdm
import pdb
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--data_dir", default="./final_data5/data_1", type=str)
parser.add_argument("-s", "--source_dir", default="./output", type=str)
parser.add_argument("-o", "--output_dir", default="./output_fewshot", type=str)
opt = parser.parse_args()
data_dir = opt.data_dir
source_dir = opt.source_dir
output_dir = opt.output_dir
all_file = os.path.join(source_dir, "all.json")
entity_stat_file = os.path.join(output_dir, "entity_stat.json")
relation_stat_file = os.path.join(output_dir, "relation_stat.json")
# %% read data and stat
entity_stat_dict = {}
relation_stat_dict = {}
record_cnt = 0
with open(all_file) as all:
for line in tqdm(all):
if len(line.strip()) == 0:
continue
record = json.loads(line)
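# In the converted UIE format, "spot" holds the entity type names and "asoc" holds the relation type names of a record.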
entity_type_list = record["spot"]
relation_type_list = record["asoc"]
for entity_type in entity_type_list:
if entity_type not in entity_stat_dict:
entity_stat_dict[entity_type] = {
"type_id": len(entity_stat_dict),
"instance_id_list": []
}
entity_stat_dict[entity_type]["instance_id_list"].append(record_cnt)
for relation_type in relation_type_list:
if relation_type not in relation_stat_dict:
relation_stat_dict[relation_type] = {
"type_id": len(relation_stat_dict),
"instance_id_list": []
}
relation_stat_dict[relation_type]["instance_id_list"].append(record_cnt)
record_cnt += 1
with open(entity_stat_file, "w") as f:
json.dump(entity_stat_dict, f)
with open(relation_stat_file, "w") as f:
json.dump(relation_stat_dict, f)

View File

@@ -0,0 +1,79 @@
import json
import os
import random
import argparse
from collections import OrderedDict
from tqdm import tqdm
import pdb
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--source_dir", default="./", type=str)
parser.add_argument("-o", "--output_dir", default="./", type=str)
opt = parser.parse_args()
source_dir = opt.source_dir
output_dir = opt.output_dir
all_file = os.path.join(source_dir, "all.json")
entity_stat_file = os.path.join(output_dir, "entity_stat.json")
relation_stat_file = os.path.join(output_dir, "relation_stat.json")
instance_label_file = os.path.join(output_dir, "instance_label.json")
# %% read data and stat
instance_label_list = []
entity_stat_dict = {}
relation_stat_dict = {}
print("Stating label...")
record_cnt = 0
with open(all_file) as all:
for line in tqdm(all):
if len(line.strip()) == 0:
continue
record = json.loads(line)
entity_type_list = record["spot"]
relation_type_list = record["asoc"]
labels = entity_type_list + relation_type_list
instance_label_list.append((record_cnt, labels))
for entity_type in entity_type_list:
if entity_type not in entity_stat_dict:
entity_stat_dict[entity_type] = {
"type_id": len(entity_stat_dict),
"instance_id_list": []
}
entity_stat_dict[entity_type]["instance_id_list"].append(record_cnt)
for relation_type in relation_type_list:
if relation_type not in relation_stat_dict:
relation_stat_dict[relation_type] = {
"type_id": len(relation_stat_dict),
"instance_id_list": []
}
relation_stat_dict[relation_type]["instance_id_list"].append(record_cnt)
record_cnt += 1
print("Saving entity stat...")
with open(entity_stat_file, "w") as f:
for key, value in tqdm(entity_stat_dict.items()):
f.write(json.dumps([key, value])+"\n")
print("Saving relation stat...")
with open(relation_stat_file, "w") as f:
for key, value in tqdm(relation_stat_dict.items()):
f.write(json.dumps([key, value])+"\n")
print("Saving instance label stat...")
instance_label_list = sorted(instance_label_list, key=lambda x: len(x[1]), reverse=True)
with open(instance_label_file, "w") as f:
for instance_label in tqdm(instance_label_list):
f.write(json.dumps(instance_label)+"\n")

View File

@@ -0,0 +1,187 @@
import json
import os
import random
import argparse
from tqdm import tqdm
import pdb
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--data_dir", default="./final_data5/data_1", type=str)
opt = parser.parse_args()
data_dir = opt.data_dir
output_dir = opt.output_dir
task_file = os.path.join(output_dir, "sampled_task.json")
sampled_all_file = os.path.join(output_dir, "sampled_all.json")
# %% utils
def create_spot_asoc_field(instance_entity_list, instance_triple_list):
instance_spot_asoc_list = []
for entity in instance_entity_list:
instance_spot_asoc = {
"span": entity["text"],
"label": entity["type"],
"asoc": []
}
for triple in instance_triple_list:
if triple["args"][0]["uri"] == entity["uri"]:
asoc_record = [triple["type"], triple["args"][1]["text"]]
instance_spot_asoc["asoc"].append(asoc_record)
instance_spot_asoc_list.append(instance_spot_asoc)
return instance_spot_asoc_list
def create_record_field(instance_spot_asoc_list):
instance_record = "<extra_id_0> "
for instance_spot_asoc in instance_spot_asoc_list:
instance_record += "<extra_id_0> "
instance_record += instance_spot_asoc["label"] + " "
instance_record += "<extra_id_5> "
instance_record += instance_spot_asoc["span"] + " "
if len(instance_spot_asoc["asoc"]) != 0:
for asoc in instance_spot_asoc["asoc"]:
instance_record += "<extra_id_0> "
instance_record += asoc[0] + " "
instance_record += "<extra_id_5> "
instance_record += asoc[1] + " "
instance_record += "<extra_id_1> "
instance_record += "<extra_id_1> "
instance_record += "<extra_id_1>"
return instance_record
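# Illustrative example (hypothetical data): a single "person" spot "Trump" with one "work for" association "Apple"
# is linearized roughly as:
# <extra_id_0> <extra_id_0> person <extra_id_5> Trump <extra_id_0> work for <extra_id_5> Apple <extra_id_1> <extra_id_1> <extra_id_1>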
def filter_entity_by_entity_type(entity_list, target_entity_type_list):
'''
{"type": "rocket stage", "offset": [11, 12, 13], "text": "S-II", "uri": "Q1093699"}
'''
filtered_entity_list = [entity for entity in entity_list if entity["type"] in target_entity_type_list]
return filtered_entity_list
def filter_triple_by_entity_list(triple_list, filtered_entity_list):
'''
{"type": "part of", "args": [{"type": "rocket stage", "offset": [1, 2, 3], "text": "MS-II", "uri": "Q6717655"}, {"type": "rocket stage", "offset": [11, 12, 13], "text": "S-II", "uri": "Q1093699"}]}
'''
filtered_triple_list = []
for triple in triple_list:
head, tail = triple["args"]
if head in filtered_entity_list and tail in filtered_entity_list:
filtered_triple_list.append(triple)
return filtered_triple_list
def build_target_relation_type_list(filtered_triple_list):
target_relation_type_list = [triple["type"] for triple in filtered_triple_list]
target_relation_type_list = list(set(target_relation_type_list))
return target_relation_type_list
def filter_triple_by_relation_type(triple_list, target_relation_type_list):
'''
{"type": "part of", "args": [{"type": "rocket stage", "offset": [1, 2, 3], "text": "MS-II", "uri": "Q6717655"}, {"type": "rocket stage", "offset": [11, 12, 13], "text": "S-II", "uri": "Q1093699"}]}
'''
filtered_triple_list = [triple for triple in triple_list if triple["type"] in target_relation_type_list]
return filtered_triple_list
def filter_entity_by_triple_list(entity_list, filtered_triple_list):
filtered_entity_list = []
for triple in filtered_triple_list:
head, tail = triple["args"]
filtered_entity_list.append(head)
filtered_entity_list.append(tail)
entity_uri_set = set()
unique_filtered_entity_list = []
for entity in filtered_entity_list:
uri = entity["uri"]
if uri not in entity_uri_set:
entity_uri_set.add(uri)
unique_filtered_entity_list.append(entity)
return unique_filtered_entity_list
def build_target_entity_type_list(filtered_entity_list):
target_entity_type_list = [entity["type"] for entity in filtered_entity_list]
target_entity_type_list = list(set(target_entity_type_list))
return target_entity_type_list
def create_instance(instance_line, target_entity_type_list, target_relation_type_list):
instance = json.loads(instance_line)
entity_list = instance["entity"]
triple_list = instance["relation"]
spot_asoc_list = instance["spot_asoc"]
record = instance["record"]
if len(target_relation_type_list) == 0:
filtered_entity_list = filter_entity_by_entity_type(entity_list, target_entity_type_list)
filtered_triple_list = filter_triple_by_entity_list(triple_list, filtered_entity_list)
current_target_entity_type_list = target_entity_type_list
current_target_relation_type_list = build_target_relation_type_list(filtered_triple_list)
else:
filtered_triple_list = filter_triple_by_relation_type(triple_list, target_relation_type_list)
filtered_entity_list = filter_entity_by_triple_list(entity_list, filtered_triple_list)
current_target_entity_type_list = build_target_entity_type_list(filtered_entity_list)
current_target_relation_type_list = target_relation_type_list
filtered_spot_asoc_list = create_spot_asoc_field(filtered_entity_list, filtered_triple_list)
filtered_record = create_record_field(filtered_spot_asoc_list)
instance["entity"] = filtered_entity_list
instance["relation"] = filtered_triple_list
instance["spot"] = current_target_entity_type_list
instance["asoc"] = current_target_relation_type_list
instance["spot_asoc"] = filtered_spot_asoc_list
instance["record"] = filtered_record
return instance
# %% read task
print("Reading task...")
task_list = []
with open(task_file) as f:
for line in tqdm(f):
task_list.append(line)
print("Task read.")
# %% write to sampled all
print("Changing task format...")
with open(sampled_all_file, "w") as f:
for task_line in tqdm(task_list):
task = json.loads(task_line)
target_entity_type_list = task["target_entity_type_list"]
target_relation_type_list = task["target_relation_type_list"]
support = task["support"]
query = task["query"]
support_instance_list = []
for instance_line in support:
instance = create_instance(instance_line, target_entity_type_list, target_relation_type_list)
support_instance_list.append(instance)
query_instance_list = []
for instance_line in query:
instance = create_instance(instance_line, target_entity_type_list, target_relation_type_list)
query_instance_list.append(instance)
random.shuffle(support_instance_list)
random.shuffle(query_instance_list)
for instance in support_instance_list:
f.write(json.dumps(instance) + "\n")
for instance in query_instance_list:
f.write(json.dumps(instance) + "\n")
print("Task format changed.")

View File

@@ -0,0 +1,9 @@
#!/usr/bin/env bash
# -*- coding:utf-8 -*-
for data_format in entity relation event absa
do
python uie_convert.py -format spotasoc -config data_config/${data_format} -output ${data_format}
done
python scripts/data_statistics.py -data converted_data/text2spotasoc/

View File

@@ -0,0 +1,53 @@
#!/usr/bin/env bash
# -*- coding:utf-8 -*-
export PYTHONPATH="${PYTHONPATH}:./"
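# Build low-resource (ratio) and few-shot (shot) subsets of every converted dataset, using 10 random seeds each.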
for data_format in entity relation event absa
do
for dataset in $(ls converted_data/text2spotasoc/${data_format} | grep -v shot | grep -v ratio)
do
for seed in 1 2 3 4 5 6 7 8 9 10
do
rm -r converted_data/text2spotasoc/${data_format}/${dataset}_ratio/seed${seed}
echo "Convert" converted_data/text2spotasoc/${data_format}/${dataset} "To" converted_data/text2spotasoc/${data_format}/${dataset}_ratio/seed${seed}
python scripts/sample_data_ratio.py -seed ${seed} \
-src converted_data/text2spotasoc/${data_format}/${dataset} \
-tgt converted_data/text2spotasoc/${data_format}/${dataset}_ratio/seed${seed}
done
done
done
for data_format in entity relation event
do
for dataset in $(ls converted_data/text2spotasoc/${data_format} | grep -v shot | grep -v ratio)
do
for seed in 1 2 3 4 5 6 7 8 9 10
do
rm -r converted_data/text2spotasoc/${data_format}/${dataset}_shot/seed${seed}
echo "Convert" converted_data/text2spotasoc/${data_format}/${dataset} "To" converted_data/text2spotasoc/${data_format}/${dataset}_shot/seed${seed}
python scripts/sample_data_shot.py -seed ${seed} \
-src converted_data/text2spotasoc/${data_format}/${dataset} \
-tgt converted_data/text2spotasoc/${data_format}/${dataset}_shot/seed${seed} \
-task ${data_format}
done
done
done
for data_format in absa
do
for dataset in $(ls converted_data/text2spotasoc/${data_format} | grep -v shot | grep -v ratio)
do
for seed in 1 2 3 4 5 6 7 8 9 10
do
rm -r converted_data/text2spotasoc/${data_format}/${dataset}_shot/seed${seed}
echo "Convert" converted_data/text2spotasoc/${data_format}/${dataset} "To" converted_data/text2spotasoc/${data_format}/${dataset}_shot/seed${seed}
python scripts/sample_data_shot.py -seed ${seed} \
-src converted_data/text2spotasoc/${data_format}/${dataset} \
-tgt converted_data/text2spotasoc/${data_format}/${dataset}_shot/seed${seed} \
-task relation
done
done
done

View File

@@ -0,0 +1,95 @@
import json
import os
import sys
from collections import Counter
import tabulate
def count_line_in_file(filename):
return sum([1 for _ in open(filename)])
def count_record_in_file(filename, key):
counter = Counter()
for line in open(filename):
instance = json.loads(line)
counter.update([key + ' entity'] * len(instance['entity']))
counter.update([key + ' relation'] * len(instance['relation']))
counter.update([key + ' event'] * len(instance['event']))
for event in instance['event']:
counter.update([key + ' role'] * len(event['args']))
return counter
def count_folder(folder_name):
data_map = {
'train': 'train.json',
'val': 'val.json',
'test': 'test.json',
}
instance_counter = {'name': folder_name}
for key, name in data_map.items():
filename = f"{folder_name}/{name}"
if not os.path.exists(filename):
sys.stderr.write(f'[warn] {filename} not exists.\n')
continue
instance_counter[key] = count_line_in_file(filename)
instance_counter.update(count_record_in_file(filename, key))
for key in ['entity', 'relation', 'event']:
filename = f"{folder_name}/{key}.schema"
if not os.path.exists(filename):
sys.stderr.write(f'[warn] {filename} not exists.\n')
instance_counter[key] = 0
continue
instance_counter[key] = len(json.loads(open(filename).readline()))
return instance_counter
def walk_dir(folder_name):
for root, dirs, files in os.walk(folder_name):
for file in dirs:
folder_name = os.path.join(root, file)
if os.path.exists(f"{os.path.join(root, file)}/record.schema"):
yield os.path.join(root, file)
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-data')
parser.add_argument('-f', dest='format', default='simple')
options = parser.parse_args()
folder_list = list()
for folder_name in walk_dir(options.data):
if 'shot' in folder_name or 'ratio' in folder_name:
continue
folder_list += [count_folder(folder_name)]
col_name = ['name',
'entity', 'relation', 'event',
'train', 'val', 'test',
'train entity', 'train relation', 'train event', 'train role',
'val entity', 'val relation', 'val event', 'val role',
'test entity', 'test relation', 'test event', 'test role',
]
table = []
for data_info in folder_list:
row = [data_info.get(col, 0) for col in col_name]
table += [row]
table.sort()
print(
tabulate.tabulate(
tabular_data=table,
headers=col_name,
tablefmt=options.format,
)
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,52 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import os
import math
import shutil
import random
import argparse
def split_ratio_file(in_filename, out_filename, ratio=0.1, seed=None):
lines = open(in_filename).readlines()
if seed:
random.seed(seed)
random.shuffle(lines)
lines = lines[:math.ceil(len(lines) * ratio)]
with open(out_filename, 'w') as output:
for line in lines:
output.write(line.strip() + '\n')
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-src')
parser.add_argument('-tgt')
parser.add_argument('-seed')
options = parser.parse_args()
source_folder = options.src
target_folder = options.tgt
os.makedirs(target_folder, exist_ok=True)
for ratio in [0.01, 0.05, 0.1]:
ratio_folder = os.path.join(target_folder, "%s" % ratio)
os.makedirs(ratio_folder, exist_ok=True)
split_ratio_file(
in_filename=os.path.join(source_folder, 'train.json'),
out_filename=os.path.join(ratio_folder, 'train.json'),
ratio=ratio,
seed=options.seed,
)
for filename in os.listdir(source_folder):
if filename != 'train.json':
shutil.copy(
os.path.join(source_folder, filename),
os.path.join(ratio_folder, filename),
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,101 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import os
import shutil
import random
import argparse
from collections import defaultdict
import json
import sys
from universal_ie.record_schema import RecordSchema
def n_shot_sample(source_filename, target_filename, record_schema,
spot_asoc_key='spot', num_shot=5, min_len=None, seed=None):
train_data = [json.loads(line.strip()) for line in open(source_filename)]
if seed:
random.seed(seed)
random.shuffle(train_data)
# Record, for every sentence, which types it contains
type_to_sentence_dict = defaultdict(list)
for index, instance in enumerate(train_data):
for spot in instance[spot_asoc_key]:
if spot not in record_schema.type_list:
continue
if min_len is not None and len(instance['tokens']) < min_len:
continue
type_to_sentence_dict[spot] += [index]
sampled_data = list()
for entity in type_to_sentence_dict:
if len(type_to_sentence_dict[entity]) < num_shot:
sys.stderr.write(
f'[WARN] {entity} in {source_filename} has fewer instances than the shot number {num_shot}\n'
)
sampled = type_to_sentence_dict[entity]
else:
sampled = random.sample(type_to_sentence_dict[entity], num_shot)
sampled_data += [train_data[index] for index in sampled]
with open(target_filename, 'w') as output:
for instance in sampled_data:
output.write(json.dumps(instance) + '\n')
return sampled_data
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-src', help='Source Folder Name', required=True)
parser.add_argument('-tgt', help='Target Folder Name, n shot sampled',
required=True)
parser.add_argument('-task', help='N-Shot Task name', required=True,
choices=['entity', 'relation', 'event'])
parser.add_argument('-seed', help='Random seed. Default is None (no fixed seed)')
parser.add_argument('-min_len', dest='min_len', help='Minimum sentence length (in tokens) to sample. Default is None', type=int)
options = parser.parse_args()
source_folder = options.src
target_folder = options.tgt
task_name = options.task
if task_name in ['relation']:
spot_asoc_key = 'asoc'
else:
spot_asoc_key = 'spot'
os.makedirs(target_folder, exist_ok=True)
for shot in [1, 5, 10]:
shot_folder = os.path.join(target_folder, "%sshot" % shot)
os.makedirs(shot_folder, exist_ok=True)
n_shot_sample(
source_filename=os.path.join(source_folder, 'train.json'),
target_filename=os.path.join(shot_folder, 'train.json'),
record_schema=RecordSchema.read_from_file(
os.path.join(source_folder, f'{task_name}.schema'),
),
spot_asoc_key=spot_asoc_key,
num_shot=shot,
seed=options.seed,
min_len=options.min_len
)
for filename in os.listdir(source_folder):
if filename != 'train.json':
shutil.copy(
os.path.join(source_folder, filename),
os.path.join(shot_folder, filename),
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,86 @@
from transformers import AutoTokenizer
import json
import argparse
import tabulate
from universal_ie.record_schema import RecordSchema
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', default='t5-base')
parser.add_argument('-d', '--data', required=True)
parser.add_argument('-s', '--schema', default='event')
options = parser.parse_args()
if "chinese_t5_pegasus" in options.model:
tokenizer = T5PegasusTokenizer.from_pretrained(options.model)
tokenizer.bos_token = tokenizer.cls_token
tokenizer.eos_token = tokenizer.sep_token
else:
tokenizer = AutoTokenizer.from_pretrained(
options.model,
use_fast=False,
mirror='tuna',
)
tokenizer.add_special_tokens(
{"additional_special_tokens": ["<extra_id_0>", "<extra_id_1>"]}
)
folder_path = options.data
schema_file = f"{folder_path}/{options.schema}.schema"
event_schema = RecordSchema.read_from_file(schema_file)
table = list()
for typename in event_schema.type_list:
typename = typename.replace('_', ' ')
after_tokenzied = tokenizer.encode(typename, add_special_tokens=False)
table += [[typename,
after_tokenzied,
tokenizer.convert_ids_to_tokens(after_tokenzied)]]
print(tokenizer)
print(type(tokenizer))
print("===============Event Schema=================")
print(tabulate.tabulate(
table,
headers=['type', 'token id', 'tokenized'],
tablefmt='grid',
))
print("===============Instance=================")
table = list()
for index, instance in enumerate(open(folder_path + "/val.json").readlines()[:10]):
instance = json.loads(instance)
table += [["Text %s" % index] + [instance['text']]]
table += [["Token %s" % index] +
['|'.join(tokenizer.tokenize(instance['text']))]]
if 'entity' in instance:
table += [["Entity %s" % index] +
['|'.join(tokenizer.tokenize(instance['entity']))]]
if 'relation' in instance:
table += [["Relation %s" % index] +
['|'.join(tokenizer.tokenize(instance['relation']))]]
if 'event' in instance:
table += [["Event %s" % index] +
['|'.join(tokenizer.tokenize(instance['event']))]]
print(tabulate.tabulate(table, headers=['text', 'event'], tablefmt='grid'))
print("===============Specical Symbol=================")
table = list()
for name in ['<extra_id_0>', '<extra_id_1>']:
table += [[name, tokenizer.encode(name), tokenizer.tokenize(name)]]
print(tabulate.tabulate(
table,
headers=['special symbol', 'token id', 'tokenized'],
tablefmt='grid'
))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,216 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from collections import Counter
import os
import json
from typing import Dict, List
from tqdm import tqdm
from universal_ie.generation_format.generation_format import GenerationFormat
from universal_ie.generation_format import generation_format_dict
from universal_ie.generation_format.structure_marker import BaseStructureMarker
from universal_ie.dataset import Dataset
from universal_ie.ie_format import Sentence
def convert_graph(
generation_class: GenerationFormat,
output_folder: str,
datasets: Dict[str, List[Sentence]],
language: str = "en",
label_mapper: Dict = None,
):
convertor = generation_class(
structure_maker=BaseStructureMarker(),
language=language,
label_mapper=label_mapper,
)
counter = Counter()
os.makedirs(output_folder, exist_ok=True)
schema_counter = {
"entity": list(),
"relation": list(),
"event": list(),
}
for data_type, instance_list in datasets.items():
with open(os.path.join(output_folder, f"{data_type}.json"), "w") as output:
for instance in tqdm(instance_list):
counter.update([f"{data_type} sent"])
converted_graph = convertor.annonote_graph(
tokens=instance.tokens,
entities=instance.entities,
relations=instance.relations,
events=instance.events,
)
src, tgt, spot_labels, asoc_labels = converted_graph[:4]
spot_asoc = converted_graph[4]
schema_counter["entity"] += instance.entities
schema_counter["relation"] += instance.relations
schema_counter["event"] += instance.events
output.write(
"%s\n"
% json.dumps(
{
"text": src,
"tokens": instance.tokens,
"record": tgt,
"entity": [
entity.to_offset(label_mapper)
for entity in instance.entities
],
"relation": [
relation.to_offset(
ent_label_mapper=label_mapper,
rel_label_mapper=label_mapper,
)
for relation in instance.relations
],
"event": [
event.to_offset(evt_label_mapper=label_mapper)
for event in instance.events
],
"spot": list(spot_labels),
"asoc": list(asoc_labels),
"spot_asoc": spot_asoc,
},
ensure_ascii=False,
)
)
convertor.output_schema(os.path.join(output_folder, "record.schema"))
convertor.get_entity_schema(schema_counter["entity"]).write_to_file(
os.path.join(output_folder, f"entity.schema")
)
convertor.get_relation_schema(schema_counter["relation"]).write_to_file(
os.path.join(output_folder, f"relation.schema")
)
convertor.get_event_schema(schema_counter["event"]).write_to_file(
os.path.join(output_folder, f"event.schema")
)
print(counter)
print(output_folder)
print("==========================")
def convert_to_oneie(output_folder: str, datasets: Dict[str, List[Sentence]]):
os.makedirs(output_folder, exist_ok=True)
counter = Counter()
for data_type, instance_list in datasets.items():
with open(
os.path.join(output_folder, f"{data_type}.oneie.json"), "w"
) as output:
for instance in tqdm(instance_list):
counter.update([f"{data_type} sent"])
entity_mentions = [
{
"id": entity.record_id,
"entity_type": str(entity.label),
"text": entity.span.text,
"start": entity.span.indexes[0],
"end": entity.span.indexes[-1] + 1,
}
for entity in instance.entities
]
relation_mentions = [
{
"id": relation.record_id,
"relation_type": str(relation.label),
"argument": [
{
"entity_id": relation.arg1.record_id,
"text": relation.arg1.span.text,
"role": "Arg-1",
},
{
"entity_id": relation.arg2.record_id,
"text": relation.arg2.span.text,
"role": "Arg-2",
},
],
}
for relation in instance.relations
]
event_mentions = [
{
"id": event.record_id,
"event_type": str(event.label),
"trigger": {
"text": event.span.text,
"start": event.span.indexes[0],
"end": event.span.indexes[-1] + 1,
},
"argument": [
{
"id": arg[1].record_id,
"text": arg[1].span.text,
"role": str(arg[0]),
}
for arg in event.args
],
}
for event in instance.events
]
instance_dict = {
"tokens": instance.tokens,
"sent_id": instance.text_id,
"entity_mentions": entity_mentions,
"relation_mentions": relation_mentions,
"event_mentions": event_mentions,
}
instance_str = json.dumps(instance_dict, ensure_ascii=False)
output.write(f"{instance_str}\n")
print(counter)
print(output_folder)
print("==========================")
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-format", dest="generation_format", default="spotasoc")
parser.add_argument("-config", dest="config", default="data_config/relation")
parser.add_argument("-output", dest="output", default="relation")
options = parser.parse_args()
generation_class = generation_format_dict.get(options.generation_format)
if os.path.isfile(options.config):
config_list = [options.config]
else:
config_list = [
os.path.join(options.config, x) for x in os.listdir(options.config)
]
for filename in config_list:
dataset = Dataset.load_yaml_file(filename)
datasets = dataset.load_dataset()
label_mapper = dataset.mapper
print(label_mapper)
output_name = (
f"converted_data/text2{options.generation_format}/{options.output}/"
+ dataset.name
)
if generation_class:
convert_graph(
generation_class,
output_name,
datasets=datasets,
language=dataset.language,
label_mapper=label_mapper,
)
elif options.generation_format == "oneie":
convert_to_oneie(output_name, datasets=datasets)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,49 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from universal_ie.utils import label_format
import yaml
import os
from typing import Dict
import universal_ie.task_format as task_format
class Dataset:
def __init__(self, name: str, path: str, data_class: task_format.TaskFormat, split_dict: Dict, language: str, mapper: Dict, other: Dict = None) -> None:
self.name = name
self.path = path
self.data_class = data_class
self.split_dict = split_dict
self.language = language
self.mapper = mapper
self.other = other
def load_dataset(self):
datasets = {}
for split_name, filename in self.split_dict.items():
datasets[split_name] = self.data_class.load_from_file(
filename=os.path.join(self.path, filename),
language=self.language,
**self.other,
)
return datasets
@staticmethod
def load_yaml_file(yaml_file):
dataset_config = yaml.load(open(yaml_file), Loader=yaml.FullLoader)
if 'mapper' in dataset_config:
mapper = dataset_config['mapper']
for key in mapper:
mapper[key] = label_format(mapper[key])
else:
print(f"{dataset_config['name']} without label mapper.")
mapper = None
return Dataset(
name=dataset_config['name'],  # Name of the dataset
path=dataset_config['path'],  # Path of the dataset
data_class=getattr(task_format, dataset_config['data_class']),  # Task-format class used to load the raw data
split_dict=dataset_config['split'],  # File paths of the data splits
language=dataset_config['language'],  # Dataset language
mapper=mapper,
other=dataset_config.get('other', {}),
)

View File

@@ -0,0 +1,8 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from universal_ie.generation_format.text2spotasoc import Text2SpotAsoc
generation_format_dict = {
'spotasoc': Text2SpotAsoc
}

View File

@@ -0,0 +1,111 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from typing import List, Dict, Union
from collections import defaultdict
from universal_ie.record_schema import RecordSchema
from universal_ie.generation_format.structure_marker import StructureMarker
from universal_ie.ie_format import Entity, Relation, Event, Label
import abc
class GenerationFormat:
__metaclass__ = abc.ABCMeta
def __init__(self,
structure_maker: StructureMarker,
label_mapper: Dict = None,
language: str = 'en') -> None:
self.structure_maker = structure_maker
self.language = language
self.label_mapper = {} if label_mapper is None else label_mapper
# Used to collect schema statistics from the data
self.record_role_map = defaultdict(set)
def get_label_str(self, label: Label):
return self.label_mapper.get(label.__repr__(), label.__repr__())
@abc.abstractmethod
def annotate_entities(
self, tokens: List[str], entities: List[Entity]): pass
@abc.abstractmethod
def annotate_given_entities(self, tokens: List[str], entities: Union[List[Entity], Entity]): pass
@abc.abstractmethod
def annotate_events(self, tokens: List[str], events: List[Event]): pass
@abc.abstractmethod
def annotate_event_given_predicate(self, tokens: List[str], event: Event): pass
@abc.abstractmethod
def annotate_relation_extraction(self, tokens: List[str],
relations: List[Relation]): pass
def output_schema(self, filename: str):
"""自动导出 Schema 文件
每个 Schema 文件包含三行
- 第一行为 Record 的类别名称列表
- 第二行为 Role 的类别名称列表
- 第三行为 Record-Role 映射关系字典
Args:
filename (str): [description]
"""
record_list = list(self.record_role_map.keys())
role_set = set()
for record in self.record_role_map:
role_set.update(self.record_role_map[record])
self.record_role_map[record] = list(self.record_role_map[record])
role_list = list(role_set)
record_schema = RecordSchema(type_list=record_list,
role_list=role_list,
type_role_dict=self.record_role_map
)
record_schema.write_to_file(filename)
def get_entity_schema(self, entities: List[Entity]):
schema_role_map = set()
for entity in entities:
schema_role_map.add(self.get_label_str(entity.label))
return RecordSchema(
type_list=list(schema_role_map),
role_list=list(),
type_role_dict=dict()
)
def get_relation_schema(self, relations: List[Relation]):
record_role_map = defaultdict(set)
role_set = set()
for relation in relations:
record_role_map[self.get_label_str(relation.label)].add(self.get_label_str(relation.arg1.label))
record_role_map[self.get_label_str(relation.label)].add(self.get_label_str(relation.arg2.label))
for record in record_role_map:
role_set.update(record_role_map[record])
record_role_map[record] = list(record_role_map[record])
return RecordSchema(
type_list=list(record_role_map.keys()),
role_list=list(role_set),
type_role_dict=record_role_map
)
def get_event_schema(self, events: List[Event]):
record_role_map = defaultdict(set)
role_set = set()
for event in events:
for role, _ in event.args:
record_role_map[self.get_label_str(event.label)].add(self.get_label_str(role))
for record in record_role_map:
role_set.update(record_role_map[record])
record_role_map[record] = list(record_role_map[record])
return RecordSchema(
type_list=list(record_role_map.keys()),
role_list=list(role_set),
type_role_dict=record_role_map
)

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Structure marker symbols
class StructureMarker:
def __init__(self) -> None:
pass
class BaseStructureMarker(StructureMarker):
def __init__(self) -> None:
super().__init__()
self.sent_start = '<extra_id_0>'
self.sent_end = '<extra_id_1>'
self.record_start = '<extra_id_0>'
self.record_end = '<extra_id_1>'
self.span_start = '<extra_id_0>'
self.span_end = '<extra_id_1>'
self.sep_marker = '<extra_id_2>'
self.source_span_start = '<extra_id_3>'
self.source_span_end = '<extra_id_4>'
self.target_span_start = '<extra_id_5>'
class VisualStructureMarker(StructureMarker):
def __init__(self) -> None:
super().__init__()
self.sent_start = '{'
self.sent_end = '}'
self.record_start = '['
self.record_end = ']'
self.span_start = '('
self.span_end = ')'
self.source_span_start = '<'
self.source_span_end = '>'
self.target_span_start = ':'
self.sep_marker = ':'

View File

@@ -0,0 +1,258 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from collections import defaultdict
from typing import List, Dict
from universal_ie.utils import tokens_to_str
from universal_ie.generation_format.generation_format import GenerationFormat, StructureMarker
from universal_ie.ie_format import Entity, Event, Label, Relation, Span
def convert_spot_asoc(spot_asoc_instance, structure_maker):
spot_instance_str_rep_list = list()
for spot in spot_asoc_instance:
spot_str_rep = [
spot['label'],
structure_maker.target_span_start,
spot['span'],
]
for asoc_label, asoc_span in spot.get('asoc', list()):
asoc_str_rep = [
structure_maker.span_start,
asoc_label,
structure_maker.target_span_start,
asoc_span,
structure_maker.span_end,
]
spot_str_rep += [' '.join(asoc_str_rep)]
spot_instance_str_rep_list += [' '.join([
structure_maker.record_start,
' '.join(spot_str_rep),
structure_maker.record_end,
])]
target_text = ' '.join([
structure_maker.sent_start,
' '.join(spot_instance_str_rep_list),
structure_maker.sent_end,
])
return target_text
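# Illustrative example (hypothetical data): with BaseStructureMarker, the spot_asoc_instance
# [{"span": "Trump", "label": "person", "asoc": [["work for", "Apple"]]}]
# is rendered as "<extra_id_0> <extra_id_0> person <extra_id_5> Trump <extra_id_0> work for <extra_id_5> Apple <extra_id_1> <extra_id_1> <extra_id_1>".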
class Text2SpotAsoc(GenerationFormat):
def __init__(self, structure_maker: StructureMarker, label_mapper: Dict = None, language: str = 'en') -> None:
super().__init__(
structure_maker=structure_maker,
label_mapper=label_mapper,
language=language
)
def annotate_entities(self, tokens: List[str], entities: List[Entity]):
""" Convert Entities
Args:
tokens (List[str]): ['Trump', 'visits', 'China', '.']
entities (List[Entity]): entity annotations over the tokens
Returns:
source (str): Trump visits China.
target (str): { [ Person : Trump ] [ Geo-political : China ] }
"""
return self.annonote_graph(tokens=tokens, entities=entities)[:2]
def augment_source_span(self, tokens: List[str], span: Span):
"""[summary]
Args:
tokens (List[str]):
['Trump', 'visits', 'China', '.']
span (Span):
Trump
Returns:
List[str]: the tokens with source-span markers inserted, e.g.
['(', 'Trump', ')', 'visits', 'China', '.']
"""
return tokens[:span.indexes[0]] \
+ [self.structure_maker.source_span_start] \
+ tokens[span.indexes[0]:span.indexes[-1] + 1] \
+ [self.structure_maker.source_span_end] \
+ tokens[span.indexes[-1] + 1:]
def annotate_given_entities(self, tokens: List[str], entities):
"""
entities is a List
:param tokens:
['Trump', 'visits', 'China', '.']
:param entities:
['Trump', 'China']
:return:
source (str): ( Trump ) ( China ) : Trump visits China .
target (str): { [ Person : Trump ] [ Geo-political : China ] }
entities is an Entity
:param tokens:
['Trump', 'visits', 'China', '.']
:param entities:
'Trump'
:return:
source (str): < Trump > visits China .
target (str): { [ Person : Trump ] }
"""
if isinstance(entities, list):
entitytokens = []
for entity in entities:
entitytokens += [self.structure_maker.span_start]
entitytokens += entity.span.tokens
entitytokens += [self.structure_maker.span_end]
source_text = tokens_to_str(
entitytokens + [self.structure_maker.sep_marker] + tokens,
language=self.language,
)
_, target_text = self.annonote_graph(tokens=tokens, entities=entities)[:2]
elif isinstance(entities, Entity):
marked_tokens = self.augment_source_span(tokens=tokens, span=entities.span)
source_text = tokens_to_str(marked_tokens, language=self.language)
_, target_text = self.annonote_graph(tokens=tokens, entities=[entities])[:2]
return source_text, target_text
def annotate_events(self, tokens: List[str], events: List[Event]):
"""
:param tokens:
['Trump', 'visits', 'China', '.']
:param events:
:return:
source (str): Trump visits China.
target (str): { [ Visit : visits ( Person : Trump ) ( Location : China ) ] }
"""
return self.annonote_graph(tokens=tokens, events=events)[:2]
def annotate_event_given_predicate(self, tokens: List[str], event: Event):
"""Annotate Event Given Predicate
Args:
tokens (List[str]):
['Trump', 'visits', 'China', '.']
event (Event): Given Predicate
Returns:
[type]: [description]
"""
marked_tokens = self.augment_source_span(tokens=tokens, span=event.span)
source_text = tokens_to_str(marked_tokens, language=self.language)
_, target_text = self.annonote_graph(tokens=tokens, events=[event])[:2]
return source_text, target_text
def annotate_relation_extraction(self,
tokens: List[str],
relations: List[Relation]):
"""
:param tokens:
['Trump', 'visits', 'China', '.']
:param relations:
:return:
source (str): Trump visits China.
target (str): { [ Person : Trump ( Visit : China ) ] }
"""
return self.annonote_graph(tokens=tokens, relations=relations)[:2]
def annotate_entities_and_relation_extraction(self,
tokens: List[str],
entities: List[Entity],
relations: List[Relation]):
"""
:param tokens:
['Trump', 'visits', 'China', '.']
:param relations:
:return:
source (str): Trump visits China.
target (str): { [ Person : Trump ( Visit : China ) ] [ Geo-political : China ] }
"""
return self.annonote_graph(tokens=tokens, entities=entities, relations=relations)[:2]
def annonote_graph(self,
tokens: List[str],
entities: List[Entity] = [],
relations: List[Relation] = [],
events: List[Event] = []):
"""Convert Entity Relation Event to Spot-Assocation Graph
Args:
tokens (List[str]): Token List
entities (List[Entity], optional): Entity List. Defaults to [].
relations (List[Relation], optional): Relation List. Defaults to [].
events (List[Event], optional): Event List. Defaults to [].
Returns:
str: source text
str: linearized target record, e.g.
{
[ Person : Trump ( Visit : China ) ]
[ Visit : visits ( Person : Trump ) ( Location : China ) ]
[ Geo-political : China ]
}
set: Set of Spot
set: Set of Asoc
"""
spot_dict = dict()
asoc_dict = defaultdict(list)
spot_str_rep_list = list()
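# A spot is keyed by (tuple of its span token indexes, label string); its associations are grouped under that key.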
def add_spot(spot):
spot_key = (tuple(spot.span.indexes), self.get_label_str(spot.label))
spot_dict[spot_key] = spot
if self.get_label_str(spot.label) not in self.record_role_map:
self.record_role_map[self.get_label_str(spot.label)] = set()
def add_asoc(spot, asoc: Label, tail):
spot_key = (tuple(spot.span.indexes), self.get_label_str(spot.label))
asoc_dict[spot_key] += [(tail.span.indexes, tail, self.get_label_str(asoc))]
self.record_role_map[self.get_label_str(spot.label)].add(self.get_label_str(asoc))
for entity in entities:
add_spot(spot=entity)
for relation in relations:
add_spot(spot=relation.arg1)
add_asoc(spot=relation.arg1, asoc=relation.label, tail=relation.arg2)
for event in events:
add_spot(spot=event)
for arg_role, argument in event.args:
add_asoc(spot=event, asoc=arg_role, tail=argument)
spot_asoc_instance = list()
for spot_key in sorted(spot_dict.keys()):
offset, label = spot_key
if spot_dict[spot_key].span.is_empty_span():
continue
spot_instance = {'span': spot_dict[spot_key].span.text,
'label': label,
'asoc': list(),
}
for _, tail, asoc in sorted(asoc_dict.get(spot_key, [])):
if tail.span.is_empty_span():
continue
spot_instance['asoc'] += [(asoc, tail.span.text)]
spot_asoc_instance += [spot_instance]
target_text = convert_spot_asoc(
spot_asoc_instance,
structure_maker=self.structure_maker,
)
source_text = tokens_to_str(tokens, language=self.language)
spot_labels = set([label for _, label in spot_dict.keys()])
asoc_labels = set()
for _, asoc_list in asoc_dict.items():
for _, _, asoc in asoc_list:
asoc_labels.add(asoc)
return source_text, target_text, spot_labels, asoc_labels, spot_asoc_instance

View File

@@ -0,0 +1,211 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from abc import abstractmethod
from collections import defaultdict
from typing import List, Union, Tuple
from universal_ie.utils import change_name_using_label_mapper
# All Entity / Relation / Event annotations are structured records.
# They all have the attributes text_id and record_id.
class Record:
def __init__(self,
text_id: Union[str, None] = None,
record_id: Union[str, None] = None,
) -> None:
self.text_id = text_id
self.record_id = record_id
@abstractmethod
def to_offset(self):
pass
# Text span
# A contiguous or non-contiguous block of text
class Span:
def __init__(self,
tokens: List[str],
indexes: List[int],
text: str,
text_id: Union[str, None] = None,
) -> None:
self.tokens = tokens
self.indexes = indexes
self.text = text
self.text_id = text_id
def __repr__(self) -> str:
return "[%s](%s)" % (self.text, self.indexes)
@staticmethod
def get_empty_span(text_id: Union[str, None] = None,):
return Span(
tokens=list(),
indexes=list(),
text="",
text_id=text_id
)
def is_empty_span(self):
"""Check is empty span.
Returns:
bool: True, Empty Span; False Non-Empty Span
"""
return len(self.tokens) == 0 and len(self.indexes) == 0
# Label Name
class Label:
def __init__(self, label_name: Union[str, List[str]]) -> None:
self.label_name = label_name
def __repr__(self) -> str:
return self.label_name
def __lt__(self, other):
if not isinstance(other, Label):
return NotImplemented
return self.label_name < other.label_name
# Entity, Span
# Entity: a unary structure centered on a text span
class Entity(Record):
def __init__(self,
span: Span,
label: Label,
text_id: Union[str, None] = None,
record_id: Union[str, None] = None,
) -> None:
super().__init__(text_id=text_id, record_id=record_id)
self.span = span
self.label = label
def __lt__(self, other):
if not isinstance(other, Entity):
return NotImplemented
return self.span.indexes < other.span.indexes
def __repr__(self) -> str:
return self.span.__repr__() + self.label.__repr__()
def to_offset(self, ent_label_mapper=None):
if self.span.is_empty_span():
# If span is empty, skip entity
return {}
return {'type': change_name_using_label_mapper(self.label.label_name,
ent_label_mapper),
'offset': self.span.indexes,
'text': self.span.text}
# Relation Span Pair
# Relation: a binary structure centered on a pair of text spans
class Relation(Record):
def __init__(self,
arg1: Entity,
arg2: Entity,
label: Label,
text_id: Union[str, None] = None,
record_id: Union[str, None] = None,
) -> None:
super().__init__(text_id=text_id, record_id=record_id)
self.arg1 = arg1
self.arg2 = arg2
self.label = label
def __repr__(self) -> str:
return self.arg1.__repr__() + self.label.__repr__() + self.arg2.__repr__()
def to_offset(self, rel_label_mapper=None, ent_label_mapper=None):
if self.arg1.span.is_empty_span() or self.arg2.span.is_empty_span():
# If span is empty, skip relation
return {}
return {'type': change_name_using_label_mapper(self.label.label_name,
rel_label_mapper),
'args': [self.arg1.to_offset(ent_label_mapper=ent_label_mapper),
self.arg2.to_offset(ent_label_mapper=ent_label_mapper),
],
}
# Event, Trigger-Mult-Argument
# Event: a multi-argument (predicate-argument) structure centered on a trigger word
class Event(Record):
def __init__(self,
span: Span,
label: Label,
args: List[Tuple[Label, Entity]],
text_id: Union[str, None] = None,
record_id: Union[str, None] = None,
) -> None:
super().__init__(text_id=text_id, record_id=record_id)
self.span = span
self.label = label
self.args = args
def __repr__(self) -> str:
return self.span.__repr__() + self.label.__repr__()
def to_offset(self, evt_label_mapper=None):
if self.span.is_empty_span():
# If span is empty, skip event
return {}
args = list()
for role, arg in self.args:
if arg.span.is_empty_span():
continue
args += [{
'type': change_name_using_label_mapper(
role.label_name,
evt_label_mapper,
),
'offset': arg.span.indexes,
'text': arg.span.text
}]
return {'type': change_name_using_label_mapper(self.label.label_name,
evt_label_mapper),
'offset': self.span.indexes,
'text': self.span.text,
'args': args}
class Sentence:
def __init__(self,
tokens: List[str],
entities: List[Entity] = None,
relations: List[Relation] = None,
events: List[Event] = None,
text_id: Union[str, None] = None,
) -> None:
self.tokens = tokens
self.entities = entities or list()
self.relations = relations or list()
self.events = events or list()
self.text_id = text_id
def count_entity_without_relation(self):
entity_set = set()
entity_counter = defaultdict(int)
for entity in self.entities:
entity_set.add((tuple(entity.span.indexes), entity.label.label_name))
for relation in self.relations:
entity1 = (tuple(relation.arg1.span.indexes), relation.arg1.label.label_name)
entity2 = (tuple(relation.arg2.span.indexes), relation.arg2.label.label_name)
entity_counter[entity1] += 1
entity_counter[entity2] += 1
entity_set.remove(entity1) if entity1 in entity_set else None
entity_set.remove(entity2) if entity2 in entity_set else None
overlap_entity = sum([1 if v > 1 else 0 for k, v in entity_counter.items()])
return {'entity': len(self.entities),
'entity_without_relation': len(entity_set),
'overlap_entity': overlap_entity,
}

View File

@@ -0,0 +1,51 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import json
from collections import defaultdict
from typing import List
class RecordSchema:
def __init__(self, type_list, role_list, type_role_dict):
self.type_list = type_list
self.role_list = role_list
self.type_role_dict = type_role_dict
@staticmethod
def read_from_file(filename):
lines = open(filename).readlines()
type_list = json.loads(lines[0])
role_list = json.loads(lines[1])
type_role_dict = json.loads(lines[2])
return RecordSchema(type_list, role_list, type_role_dict)
def write_to_file(self, filename):
with open(filename, 'w') as output:
output.write(json.dumps(self.type_list, ensure_ascii=False) + '\n')
output.write(json.dumps(self.role_list, ensure_ascii=False) + '\n')
output.write(json.dumps(self.type_role_dict, ensure_ascii=False) + '\n')
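# Illustrative (hypothetical) record.schema contents, one JSON value per line:
# ["person", "organization"]
# ["work for"]
# {"person": ["work for"], "organization": []}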
def merge_schema(schema_list: List[RecordSchema]):
type_set = set()
role_set = set()
type_role_dict = defaultdict(list)
for schema in schema_list:
for type_name in schema.type_list:
type_set.add(type_name)
for role_name in schema.role_list:
role_set.add(role_name)
for type_name in schema.type_role_dict:
type_role_dict[type_name] += schema.type_role_dict[type_name]
for type_name in type_role_dict:
type_role_dict[type_name] = list(set(type_role_dict[type_name]))
return RecordSchema(type_list=list(type_set),
role_list=list(role_set),
type_role_dict=type_role_dict,
)

View File

@@ -0,0 +1,16 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from universal_ie.task_format.task_format import TaskFormat
from universal_ie.task_format.oneie import OneIEEvent
from universal_ie.task_format.jointer import JointER
from universal_ie.task_format.mrc_ner import MRCNER
from universal_ie.task_format.absa import ABSA
from universal_ie.task_format.spannet import Spannet
from universal_ie.task_format.casie import CASIE
from universal_ie.task_format.cols import (
TokenTagCols,
I2b2Conll,
TagTokenCols,
TokenTagJson,
CoNLL03,
)

View File

@@ -0,0 +1,82 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import json
from typing import List
from universal_ie.utils import tokens_to_str, change_ptb_token_back
from universal_ie.ie_format import Entity, Label, Relation, Sentence, Span
from universal_ie.task_format.task_format import TaskFormat
class ABSA(TaskFormat):
""" Aspect-Based Sentiment Analysis Data format at https://github.com/yhcc/BARTABSA."""
def __init__(self, sentence_json, language='en'):
super().__init__(
language=language
)
self.tokens = sentence_json['words']
if self.tokens is None:
print('[sentence without tokens]:', sentence_json)
exit(1)
for index in range(len(self.tokens)):
self.tokens[index] = change_ptb_token_back(self.tokens[index])
self.aspects = sentence_json['aspects']
self.opinions = sentence_json['opinions']
def generate_instance(self):
entities = dict()
relations = list()
for aspect, opinion in zip(self.aspects, self.opinions):
aspect_span = (aspect['from'], aspect['to'])
opinion_span = (opinion['from'], opinion['to'])
if aspect_span not in entities:
tokens = self.tokens[aspect_span[0]:aspect_span[1]]
entities[aspect_span] = Entity(
span=Span(
tokens=tokens,
indexes=list(range(aspect_span[0], aspect_span[1])),
text=tokens_to_str(tokens, language=self.language),
),
label=Label('aspect')
)
if opinion_span not in entities:
tokens = self.tokens[opinion_span[0]:opinion_span[1]]
entities[opinion_span] = Entity(
span=Span(
tokens=tokens,
indexes=list(range(opinion_span[0], opinion_span[1])),
text=tokens_to_str(tokens, language=self.language),
),
label=Label('opinion')
)
relations += [Relation(
arg1=entities[aspect_span],
arg2=entities[opinion_span],
label=Label(aspect['polarity']),
)]
return Sentence(
tokens=self.tokens,
entities=entities.values(),
relations=relations,
)
@staticmethod
def load_from_file(filename, language='en') -> List[Sentence]:
sentence_list = list()
raw_instance_list = json.load(open(filename))
print(f"{filename}: {len(raw_instance_list)}")
for instance in raw_instance_list:
instance = ABSA(
sentence_json=instance,
language=language
).generate_instance()
sentence_list += [instance]
return sentence_list

File diff suppressed because it is too large

View File

@@ -0,0 +1,505 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from collections import Counter
import json
from typing import List, Optional, Tuple, Set
from tqdm import tqdm
from universal_ie.task_format.task_format import TaskFormat
from universal_ie.utils import tokens_to_str
from universal_ie.ie_format import Entity, Label, Sentence, Span
# https://github.com/allenai/allennlp/blob/main/allennlp/data/dataset_readers/dataset_utils/span_utils.py
# ### Start Code
def bio_tags_to_spans(
tag_sequence: List[str], classes_to_ignore: List[str] = None
) -> List[Tuple[str, Tuple[int, int]]]:
"""
Given a sequence corresponding to BIO tags, extracts spans.
Spans are inclusive and can be of zero length, representing a single word span.
Ill-formed spans are also included (i.e those which do not start with a "B-LABEL"),
as otherwise it is possible to get a perfect precision score whilst still predicting
ill-formed spans in addition to the correct spans. This function works properly when
the spans are unlabeled (i.e., your labels are simply "B", "I", and "O").
# Parameters
tag_sequence : `List[str]`, required.
The integer class labels for a sequence.
classes_to_ignore : `List[str]`, optional (default = `None`).
A list of string class labels `excluding` the bio tag
which should be ignored when extracting spans.
# Returns
spans : `List[TypedStringSpan]`
The typed, extracted spans from the sequence, in the format (label, (span_start, span_end)).
Note that the label `does not` contain any BIO tag prefixes.
"""
classes_to_ignore = classes_to_ignore or []
spans: Set[Tuple[str, Tuple[int, int]]] = set()
span_start = 0
span_end = 0
active_conll_tag = None
for index, string_tag in enumerate(tag_sequence):
# Actual BIO tag.
bio_tag = string_tag[0]
if bio_tag not in ["B", "I", "O"]:
raise RuntimeError('Invalid tag sequence %s' % tag_sequence)
conll_tag = string_tag[2:]
if bio_tag == "O" or conll_tag in classes_to_ignore:
# The span has ended.
if active_conll_tag is not None:
spans.add((active_conll_tag, (span_start, span_end)))
active_conll_tag = None
# We don't care about tags we are
# told to ignore, so we do nothing.
continue
elif bio_tag == "B":
# We are entering a new span; reset indices
# and active tag to new span.
if active_conll_tag is not None:
spans.add((active_conll_tag, (span_start, span_end)))
active_conll_tag = conll_tag
span_start = index
span_end = index
elif bio_tag == "I" and conll_tag == active_conll_tag:
# We're inside a span.
span_end += 1
else:
# This is the case the bio label is an "I", but either:
# 1) the span hasn't started - i.e. an ill formed span.
# 2) The span is an I tag for a different conll annotation.
# We'll process the previous span if it exists, but also
# include this span. This is important, because otherwise,
# a model may get a perfect F1 score whilst still including
# false positive ill-formed spans.
if active_conll_tag is not None:
spans.add((active_conll_tag, (span_start, span_end)))
active_conll_tag = conll_tag
span_start = index
span_end = index
# Last token might have been a part of a valid span.
if active_conll_tag is not None:
spans.add((active_conll_tag, (span_start, span_end)))
return list(spans)
def _iob1_start_of_chunk(
prev_bio_tag: Optional[str],
prev_conll_tag: Optional[str],
curr_bio_tag: str,
curr_conll_tag: str,
) -> bool:
if curr_bio_tag == "B":
return True
if curr_bio_tag == "I" and prev_bio_tag == "O":
return True
if curr_bio_tag != "O" and prev_conll_tag != curr_conll_tag:
return True
return False
def iob1_tags_to_spans(
tag_sequence: List[str], classes_to_ignore: List[str] = None
) -> List[Tuple[str, Tuple[int, int]]]:
"""
Given a sequence corresponding to IOB1 tags, extracts spans.
Spans are inclusive and can be of zero length, representing a single word span.
Ill-formed spans are also included (i.e., those where "B-LABEL" is not preceded
by "I-LABEL" or "B-LABEL").
# Parameters
tag_sequence : `List[str]`, required.
The integer class labels for a sequence.
classes_to_ignore : `List[str]`, optional (default = `None`).
A list of string class labels `excluding` the bio tag
which should be ignored when extracting spans.
# Returns
spans : `List[TypedStringSpan]`
The typed, extracted spans from the sequence, in the format (label, (span_start, span_end)).
Note that the label `does not` contain any BIO tag prefixes.
"""
classes_to_ignore = classes_to_ignore or []
spans: Set[Tuple[str, Tuple[int, int]]] = set()
span_start = 0
span_end = 0
active_conll_tag = None
prev_bio_tag = None
prev_conll_tag = None
for index, string_tag in enumerate(tag_sequence):
curr_bio_tag = string_tag[0]
curr_conll_tag = string_tag[2:]
if curr_bio_tag not in ["B", "I", "O"]:
raise RuntimeError('Invalid tag sequence %s' % tag_sequence)
if curr_bio_tag == "O" or curr_conll_tag in classes_to_ignore:
# The span has ended.
if active_conll_tag is not None:
spans.add((active_conll_tag, (span_start, span_end)))
active_conll_tag = None
elif _iob1_start_of_chunk(prev_bio_tag, prev_conll_tag, curr_bio_tag, curr_conll_tag):
# We are entering a new span; reset indices
# and active tag to new span.
if active_conll_tag is not None:
spans.add((active_conll_tag, (span_start, span_end)))
active_conll_tag = curr_conll_tag
span_start = index
span_end = index
else:
# bio_tag == "I" and curr_conll_tag == active_conll_tag
# We're continuing a span.
span_end += 1
prev_bio_tag = string_tag[0]
prev_conll_tag = string_tag[2:]
# Last token might have been a part of a valid span.
if active_conll_tag is not None:
spans.add((active_conll_tag, (span_start, span_end)))
return list(spans)
def bmes_tags_to_spans(
tag_sequence: List[str], classes_to_ignore: List[str] = None
) -> List[Tuple[str, Tuple[int, int]]]:
"""
Given a sequence corresponding to BMES tags, extracts spans.
Spans are inclusive and can be of zero length, representing a single word span.
Ill-formed spans are also included (i.e those which do not start with a "B-LABEL"),
as otherwise it is possible to get a perfect precision score whilst still predicting
ill-formed spans in addition to the correct spans.
This function works properly when the spans are unlabeled (i.e., your labels are
simply "B", "M", "E" and "S").
# Parameters
tag_sequence : `List[str]`, required.
The integer class labels for a sequence.
classes_to_ignore : `List[str]`, optional (default = `None`).
A list of string class labels `excluding` the bio tag
which should be ignored when extracting spans.
# Returns
spans : `List[TypedStringSpan]`
The typed, extracted spans from the sequence, in the format (label, (span_start, span_end)).
Note that the label `does not` contain any BIO tag prefixes.
"""
def extract_bmes_tag_label(text):
bmes_tag = text[0]
label = text[2:]
return bmes_tag, label
spans: List[Tuple[str, List[int]]] = []
prev_bmes_tag: Optional[str] = None
for index, tag in enumerate(tag_sequence):
bmes_tag, label = extract_bmes_tag_label(tag)
if bmes_tag in ("B", "S"):
# Regardless of tag, we start a new span when reaching B & S.
spans.append((label, [index, index]))
elif bmes_tag in ("M", "E") and prev_bmes_tag in ("B", "M") and spans[-1][0] == label:
# Only expand the span if
# 1. Valid transition: B/M -> M/E.
# 2. Matched label.
spans[-1][1][1] = index
else:
# Best effort split for invalid span.
spans.append((label, [index, index]))
# update previous BMES tag.
prev_bmes_tag = bmes_tag
classes_to_ignore = classes_to_ignore or []
return [
# to tuple.
(span[0], (span[1][0], span[1][1]))
for span in spans
if span[0] not in classes_to_ignore
]
def bioul_tags_to_spans(
tag_sequence: List[str], classes_to_ignore: List[str] = None
) -> List[Tuple[str, Tuple[int, int]]]:
"""
Given a sequence corresponding to BIOUL tags, extracts spans.
Spans are inclusive and can be of zero length, representing a single word span.
Ill-formed spans are not allowed and will raise `InvalidTagSequence`.
This function works properly when the spans are unlabeled (i.e., your labels are
simply "B", "I", "O", "U", and "L").
# Parameters
tag_sequence : `List[str]`, required.
The tag sequence encoded in BIOUL, e.g. ["B-PER", "L-PER", "O"].
classes_to_ignore : `List[str]`, optional (default = `None`).
A list of string class labels `excluding` the bio tag
which should be ignored when extracting spans.
# Returns
spans : `List[TypedStringSpan]`
The typed, extracted spans from the sequence, in the format (label, (span_start, span_end)).
"""
spans = []
classes_to_ignore = classes_to_ignore or []
index = 0
while index < len(tag_sequence):
label = tag_sequence[index]
if label[0] == "U":
spans.append((label.partition("-")[2], (index, index)))
elif label[0] == "B":
start = index
while label[0] != "L":
index += 1
if index >= len(tag_sequence):
raise RuntimeError('Invalid tag sequence %s' % tag_sequence)
# raise InvalidTagSequence(tag_sequence)
label = tag_sequence[index]
if not (label[0] == "I" or label[0] == "L"):
raise RuntimeError('Invalid tag sequence %s' % tag_sequence)
# raise InvalidTagSequence(tag_sequence)
spans.append((label.partition("-")[2], (start, index)))
else:
if label != "O":
raise RuntimeError('Invalid tag sequence %s' % tag_sequence)
# raise InvalidTagSequence(tag_sequence)
index += 1
return [span for span in spans if span[0] not in classes_to_ignore]
def bmeso_tags_to_spans(
tag_sequence: List[str], classes_to_ignore: List[str] = None
) -> List[Tuple[str, Tuple[int, int]]]:
"""
bmeso -> bioul
B = Beginning
I/M = Inside / Middle
L/E = Last / End
O = Outside
U/W/S = Unit-length / Whole / Singleton
"""
new_tag = list()
for label in tag_sequence:
if label[0] == 'M':
new_tag += ['I-' + label.partition("-")[2]]
elif label[0] == 'E':
new_tag += ['L-' + label.partition("-")[2]]
elif label[0] == 'S':
new_tag += ['U-' + label.partition("-")[2]]
else:
new_tag += [label]
return bioul_tags_to_spans(tag_sequence=new_tag, classes_to_ignore=classes_to_ignore)
def bieso_tags_to_spans(
tag_sequence: List[str], classes_to_ignore: List[str] = None
) -> List[Tuple[str, Tuple[int, int]]]:
"""
bieso -> bioul
B = Beginning
I = Inside
L/E = Last / End
O = Outside
U/S = Unit-length / Singleton
"""
new_tag = list()
for label in tag_sequence:
if label[0] == 'E':
new_tag += ['L-' + label.partition("-")[2]]
elif label[0] == 'S':
new_tag += ['U-' + label.partition("-")[2]]
else:
new_tag += [label]
return bioul_tags_to_spans(tag_sequence=new_tag, classes_to_ignore=classes_to_ignore)
# ### End Code
_tagging_span_function = {
'bioul': bioul_tags_to_spans,
'bmes': bmes_tags_to_spans,
'bio': bio_tags_to_spans,
'iob1': iob1_tags_to_spans,
'bmeso': bmeso_tags_to_spans,
'bieso': bieso_tags_to_spans,
}
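# Hypothetical usage sketch (not part of the original file): pick a tagging
# scheme from `_tagging_span_function` and convert a tag sequence into typed,
# inclusive (label, (start, end)) spans. Result order may vary, since the
# BIO/IOB1 converters collect spans in a set before returning a list.
#
#     tags = ['B-PER', 'I-PER', 'O', 'B-LOC']
#     spans = _tagging_span_function['bio'](tags)
#     # expected content: [('PER', (0, 1)), ('LOC', (3, 3))]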
class Cols(TaskFormat):
def __init__(self, tokens: List[str], spans: List[Tuple[Tuple[int, int], str]], language='en', instance_id=None) -> None:
super().__init__(
language=language
)
self.instance_id = instance_id
self.tokens = tokens
self.spans = spans
def generate_instance(self):
entities = list()
for span_index, span in enumerate(self.spans):
tokens = self.tokens[span['start']: span['end'] + 1]
indexes = list(range(span['start'], span['end'] + 1))
entities += [
Entity(
span=Span(
tokens=tokens,
indexes=indexes,
text=tokens_to_str(tokens, language=self.language),
text_id=self.instance_id
),
label=Label(span['type']),
text_id=self.instance_id,
record_id=self.instance_id + "#%s" % span_index if self.instance_id else None)
]
return Sentence(tokens=self.tokens,
entities=entities,
text_id=self.instance_id)
@staticmethod
def generate_sentence(filename):
sentence = list()
with open(filename) as fin:
for line in fin:
if line.strip() == '':
if len(sentence) != 0:
yield sentence
sentence = list()
else:
sentence += [line.strip().split()]
if len(sentence) != 0:
yield sentence
class TokenTagCols(Cols):
@staticmethod
def load_from_file(filename, language='en', tagging='bio') -> List[Sentence]:
sentence_list = list()
counter = Counter()
for rows in tqdm(Cols.generate_sentence(filename)):
tokens = [token[0] for token in rows]
ner = [token[1] for token in rows]
spans = _tagging_span_function[tagging](ner)
spans = list(filter(lambda x: x[0] != "", spans))
spans = [
{'start': span[1][0], 'end': span[1][1], 'type': span[0]}
for span in spans
]
sentence = Cols(
tokens=tokens,
spans=spans,
language=language,
)
counter.update(['token'] * len(tokens))
counter.update(['sentence'])
counter.update(['span'] * len(spans))
sentence_list += [sentence.generate_instance()]
print(filename, counter)
return sentence_list
class TagTokenCols(Cols):
@staticmethod
def load_from_file(filename, language='en', tagging='bio') -> List[Sentence]:
sentence_list = list()
counter = Counter()
for rows in tqdm(Cols.generate_sentence(filename)):
tokens = [token[1] for token in rows]
ner = [token[0] for token in rows]
spans = _tagging_span_function[tagging](ner)
spans = [
{'start': span[1][0], 'end': span[1][1], 'type': span[0]}
for span in spans
]
sentence = Cols(
tokens=tokens,
spans=spans,
language=language,
)
counter.update(['token'] * len(tokens))
counter.update(['sentence'])
counter.update(['span'] * len(spans))
sentence_list += [sentence.generate_instance()]
print(filename, counter)
return sentence_list
class TokenTagJson(Cols):
@staticmethod
def load_from_file(filename, language='en', tagging='bio') -> List[Sentence]:
sentence_list = list()
counter = Counter()
for line in open(filename):
instance = json.loads(line.strip())
tokens = instance['tokens']
ner = instance['ner_tags']
spans = _tagging_span_function[tagging](ner)
spans = list(filter(lambda x: x[0] != "", spans))
spans = [
{'start': span[1][0], 'end': span[1][1], 'type': span[0]}
for span in spans
]
sentence = Cols(
tokens=tokens,
spans=spans,
language=language,
)
counter.update(['token'] * len(tokens))
counter.update(['sentence'])
counter.update(['span'] * len(spans))
sentence_list += [sentence.generate_instance()]
print(filename, counter)
return sentence_list
class I2b2Conll(Cols):
@staticmethod
def load_from_file(filename, language='en') -> List[Sentence]:
sentence_list = list()
counter = Counter()
for rows in tqdm(Cols.generate_sentence(filename)):
tokens = [token[0] for token in rows]
ner = [token[4] for token in rows]
spans = bio_tags_to_spans(ner)
spans = [
{'start': span[1][0], 'end': span[1][1], 'type': span[0]}
for span in spans
]
sentence = Cols(
tokens=tokens,
spans=spans,
language=language,
)
counter.update(['token'] * len(tokens))
counter.update(['sentence'])
counter.update(['span'] * len(spans))
sentence_list += [sentence.generate_instance()]
print(filename, counter)
return sentence_list
class CoNLL03(Cols):
@staticmethod
def load_from_file(filename, language='en') -> List[Sentence]:
sentence_list = list()
counter = Counter()
for rows in tqdm(Cols.generate_sentence(filename)):
if rows[0][0] == '-DOCSTART-':
continue
tokens = [token[0] for token in rows]
ner = [token[3] for token in rows]
spans = iob1_tags_to_spans(ner)
spans = [
{'start': span[1][0], 'end': span[1][1], 'type': span[0]}
for span in spans
]
sentence = Cols(
tokens=tokens,
spans=spans,
language=language,
)
counter.update(['token'] * len(tokens))
counter.update(['sentence'])
counter.update(['span'] * len(spans))
sentence_list += [sentence.generate_instance()]
print(filename, counter)
return sentence_list
if __name__ == "__main__":
pass

View File

@ -0,0 +1,84 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import json
from typing import List
from universal_ie.utils import tokens_to_str, change_ptb_token_back
from universal_ie.ie_format import Entity, Label, Relation, Sentence, Span
from universal_ie.task_format.task_format import TaskFormat
class JointER(TaskFormat):
""" Joint Entity Relation Data format at https://github.com/yubowen-ph/JointER"""
def __init__(self, sentence_json, language='en'):
super().__init__(
language=language
)
self.tokens = sentence_json['tokens']
if self.tokens is None:
print('[sentence without tokens]:', sentence_json)
exit(1)
for index in range(len(self.tokens)):
self.tokens[index] = change_ptb_token_back(self.tokens[index])
self.spo_list = sentence_json['spo_list']
self.spo_details = sentence_json['spo_details']
self.pos_tags = sentence_json['pos_tags']
def generate_instance(self):
entities = dict()
relations = dict()
for spo_index, spo in enumerate(self.spo_details):
s_s, s_e, s_t = spo[0], spo[1], spo[2]
tokens = self.tokens[s_s: s_e]
indexes = list(range(s_s, s_e))
if (s_s, s_e, s_t) not in entities:
entities[(s_s, s_e, s_t)] = Entity(
span=Span(
tokens=tokens,
indexes=indexes,
text=tokens_to_str(tokens, language=self.language),
),
label=Label(s_t)
)
o_s, o_e, o_t = spo[4], spo[5], spo[6]
tokens = self.tokens[o_s: o_e]
indexes = list(range(o_s, o_e))
if (o_s, o_e, o_t) not in entities:
entities[(o_s, o_e, o_t)] = Entity(
span=Span(
tokens=tokens,
indexes=indexes,
text=tokens_to_str(tokens, language=self.language),
),
label=Label(o_t)
)
relations[spo_index] = Relation(
arg1=entities[(s_s, s_e, s_t)],
arg2=entities[(o_s, o_e, o_t)],
label=Label(spo[3]),
)
return Sentence(
tokens=self.tokens,
entities=entities.values(),
relations=relations.values(),
)
@staticmethod
def load_from_file(filename, language='en') -> List[Sentence]:
sentence_list = list()
raw_instance_list = json.load(open(filename))
print(f"{filename}: {len(raw_instance_list)}")
for instance in raw_instance_list:
instance = JointER(
sentence_json=instance,
language=language
).generate_instance()
sentence_list += [instance]
return sentence_list

View File

@ -0,0 +1,55 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import json
from collections import Counter, defaultdict
from typing import Dict, List
from universal_ie.task_format.spannet import Spannet
from universal_ie.ie_format import Sentence
class MRCNER(Spannet):
""" MRC NER format at https://github.com/ShannonAI/mrc-for-flat-nested-ner"""
id_template = "%s#%s"
def __init__(self, instance_json: Dict, language='en'):
super().__init__(
instance_json=instance_json,
language=language
)
@ staticmethod
def load_from_file(filename, language='en') -> List[Sentence]:
counter = Counter()
dataset = defaultdict(dict)
with open(filename) as fin:
for instance in json.load(fin):
counter.update(['label sentence'])
key, _ = instance['qas_id'].split('.')
dataset[key]['tokens'] = instance['context'].split()
if 'spans' not in dataset[key]:
dataset[key]['spans'] = list()
for start, end in zip(instance['start_position'],
instance['end_position']):
dataset[key]['spans'] += [{
'start': start,
'end': end,
'type': instance['entity_label']
}]
counter.update(['span'])
sentence_list = list()
for sentence_id, sentence in dataset.items():
counter.update(['sentence'])
mrc_instance = MRCNER(
instance_json={
'tokens': sentence['tokens'],
'span_list': sentence['spans'],
'id': sentence_id
},
language=language
)
sentence_list += [mrc_instance.generate_instance()]
print(filename, counter)
return sentence_list

View File

@ -0,0 +1,128 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import json
from typing import List
from universal_ie.task_format.task_format import TaskFormat
from universal_ie.utils import tokens_to_str
from universal_ie.ie_format import Entity, Event, Label, Sentence, Span
"""
{
"doc_id": "AFP_ENG_20030427.0118",
"sent_id": "AFP_ENG_20030427.0118-1",
"tokens": ["A", "Pakistani", "court", "in", "central", "Punjab", "province", "has", "sentenced", "a", "Christian", "man", "to", "life", "imprisonment", "for", "a", "blasphemy", "conviction", ",", "police", "said", "Sunday", "."], "pieces": ["A", "Pakistani", "court", "in", "central", "Punjab", "province", "has", "sentenced", "a", "Christian", "man", "to", "life", "imprisonment", "for", "a", "b", "##lasp", "##hem", "##y", "conviction", ",", "police", "said", "Sunday", "."],
"token_lens": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1],
"sentence": "A Pakistani court in central Punjab province has sentenced a Christian man to life imprisonment for a blasphemy conviction, police said Sunday.",
"entity_mentions": [
{"id": "AFP_ENG_20030427.0118-E15-53", "text": "Pakistani", "entity_type": "GPE", "mention_type": "NAM", "entity_subtype": "Nation", "start": 1, "end": 2},
{"id": "AFP_ENG_20030427.0118-E35-52", "text": "court", "entity_type": "ORG", "mention_type": "NOM", "entity_subtype": "Government", "start": 2, "end": 3},
{"id": "AFP_ENG_20030427.0118-E37-54", "text": "province", "entity_type": "LOC", "mention_type": "NOM", "entity_subtype": "Region-General", "start": 6, "end": 7},
{"id": "AFP_ENG_20030427.0118-E27-48", "text": "Christian", "entity_type": "PER", "mention_type": "NOM", "entity_subtype": "Group", "start": 10, "end": 11},
{"id": "AFP_ENG_20030427.0118-E38-55", "text": "man", "entity_type": "PER", "mention_type": "NOM", "entity_subtype": "Individual", "start": 11, "end": 12},
{"id": "AFP_ENG_20030427.0118-E39-56", "text": "police", "entity_type": "PER", "mention_type": "NOM", "entity_subtype": "Group", "start": 20, "end": 21}],
"relation_mentions": [
{"id": "AFP_ENG_20030427.0118-R1-1", "relation_type": "GEN-AFF", "relation_subtype": "GEN-AFF:Citizen-Resident-Religion-Ethnicity",
"arguments": [
{"entity_id": "AFP_ENG_20030427.0118-E38-55", "text": "man", "role": "Arg-1"},
{"entity_id": "AFP_ENG_20030427.0118-E27-48", "text": "Christian", "role": "Arg-2"}
]
},
{"id": "AFP_ENG_20030427.0118-R3-1", "relation_type": "PART-WHOLE", "relation_subtype": "PART-WHOLE:Subsidiary",
"arguments": [
{"entity_id": "AFP_ENG_20030427.0118-E35-52", "text": "court", "role": "Arg-1"},
{"entity_id": "AFP_ENG_20030427.0118-E15-53", "text": "Pakistani", "role": "Arg-2"}
]
},
{"id": "AFP_ENG_20030427.0118-R4-1", "relation_type": "GEN-AFF", "relation_subtype": "GEN-AFF:Org-Location",
"arguments": [
{"entity_id": "AFP_ENG_20030427.0118-E35-52", "text": "court", "role": "Arg-1"},
{"entity_id": "AFP_ENG_20030427.0118-E37-54", "text": "province", "role": "Arg-2"}
]
}
],
"event_mentions": [
{"id": "AFP_ENG_20030427.0118-EV1-1", "event_type": "Justice:Sentence",
"trigger": {"text": "sentenced", "start": 8, "end": 9},
"arguments": [
{"entity_id": "AFP_ENG_20030427.0118-E35-52", "text": "court", "role": "Adjudicator"},
{"entity_id": "AFP_ENG_20030427.0118-E38-55", "text": "man", "role": "Defendant"},
{"entity_id": "AFP_ENG_20030427.0118-E37-54", "text": "province", "role": "Place"}
]},
{"id": "AFP_ENG_20030427.0118-EV2-1", "event_type": "Justice:Convict",
"trigger": {"text": "conviction", "start": 18, "end": 19},
"arguments": [{"entity_id": "AFP_ENG_20030427.0118-E38-55", "text": "man", "role": "Defendant"}
]}
]}
"""
class OneIEEvent(TaskFormat):
def __init__(self, doc_json, language='en'):
super().__init__(
language=language
)
self.doc_id = doc_json['doc_id']
self.sent_id = doc_json['sent_id']
self.tokens = doc_json['tokens']
self.entities = doc_json['entity_mentions']
self.relations = doc_json['relation_mentions']
self.events = doc_json['event_mentions']
def generate_instance(self):
events = dict()
entities = dict()
for span_index, span in enumerate(self.entities):
tokens = self.tokens[span['start']: span['end']]
indexes = list(range(span['start'], span['end']))
entities[span['id']] = Entity(
span=Span(
tokens=tokens,
indexes=indexes,
text=tokens_to_str(tokens, language=self.language),
text_id=self.sent_id
),
label=Label(span['entity_type']),
text_id=self.sent_id,
record_id=span['id']
)
for event_index, event in enumerate(self.events):
start = event['trigger']['start']
end = event['trigger']['end']
tokens = self.tokens[start:end]
indexes = list(range(start, end))
events[event['id']] = Event(
span=Span(
tokens=tokens,
indexes=indexes,
text=tokens_to_str(tokens, language=self.language),
text_id=self.sent_id
),
label=Label(event['event_type']),
args=[(Label(x['role']), entities[x['entity_id']])
for x in event['arguments']],
text_id=self.sent_id,
record_id=event['id']
)
return Sentence(
tokens=self.tokens,
entities=list(),
relations=list(),
events=events.values(),
text_id=self.sent_id
)
@staticmethod
def load_from_file(filename, language='en') -> List[Sentence]:
sentence_list = list()
with open(filename) as fin:
for line in fin:
instance = OneIEEvent(
json.loads(line.strip()),
language=language
).generate_instance()
sentence_list += [instance]
return sentence_list

View File

@ -0,0 +1,87 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from collections import Counter
import json
from typing import List, Dict
from universal_ie.task_format.task_format import TaskFormat
from universal_ie.utils import change_ptb_token_back, tokens_to_str
from universal_ie.ie_format import Entity, Label, Relation, Sentence, Span
from tqdm import tqdm
class Spannet(TaskFormat):
"""
{
"tokens": ["An", "art", "exhibit", "at", "the", "Hakawati", "Theatre",
"in", "Arab", "east", "Jerusalem", "was", "a", "series",
"of", "portraits", "of", "Palestinians", "killed", "in",
"the", "rebellion", "."],
"span_pair_list": [
{"type": "OrgBased_In", "head": 0, "tail": 2}
],
"span_list": [
{"type": "Org", "start": 5, "end": 6},
{"type": "Other", "start": 8, "end": 8},
{"type": "Loc", "start": 10, "end": 10},
{"type": "Other", "start": 17, "end": 17}
]
}
"""
def __init__(self, instance_json: Dict, language='en') -> None:
super().__init__(
language=language
)
self.tokens = change_ptb_token_back(instance_json['tokens'])
self.span_list = instance_json.get('span_list', [])
self.span_pair_list = instance_json.get('span_pair_list', [])
self.instance_id = instance_json.get('id', None)
def generate_instance(self):
entities = list()
relations = list()
for span_index, span in enumerate(self.span_list):
tokens = self.tokens[span['start']: span['end'] + 1]
indexes = list(range(span['start'], span['end'] + 1))
entities += [
Entity(
span=Span(
tokens=tokens,
indexes=indexes,
text=tokens_to_str(tokens, language=self.language),
text_id=self.instance_id
),
label=Label(span['type']),
text_id=self.instance_id,
record_id=self.instance_id + "#%s" % span_index if self.instance_id else None)
]
for spanpair_index, span_pair in enumerate(self.span_pair_list):
relations += [
Relation(
arg1=entities[span_pair['head']],
arg2=entities[span_pair['tail']],
label=Label(span_pair['type']),
text_id=self.instance_id,
record_id=self.instance_id + "##%s" % spanpair_index if self.instance_id else None
)
]
return Sentence(tokens=self.tokens,
entities=entities,
relations=relations,
text_id=self.instance_id)
@staticmethod
def load_from_file(filename, language='en') -> List[Sentence]:
sentence_list = list()
counter = Counter()
with open(filename) as fin:
for line in tqdm(fin):
spannet = Spannet(
json.loads(line.strip()),
language=language
)
instance = spannet.generate_instance()
sentence_list += [instance]
counter.update(['sentence'])
counter.update(['span'] * len(spannet.span_list))
print(filename, counter)
return sentence_list

View File

@ -0,0 +1,20 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import abc
class TaskFormat:
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def __init__(self, language='en'):
self.language = language
@abc.abstractmethod
def generate_instance(self):
pass
@staticmethod
@abc.abstractmethod
def load_from_file(filename, language='en'):
pass

View File

@ -0,0 +1,83 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from typing import List
import os
import sys
global_mislabel_log = set()
def tokens_to_str(tokens: List[str], language: str = 'en') -> str:
if language == 'en':
return ' '.join(tokens)
elif language == 'zh':
return ''.join(tokens)
else:
raise NotImplementedError('Language %s not supported' % language)
def label_format(s):
import re
def uncamelize(s):
re_outer = re.compile(r'([^A-Z ])([A-Z])')
re_inner = re.compile(r'\b[A-Z]+(?=[A-Z][a-z])')
sub = re_inner.sub(r'\g<0> ', re_outer.sub(r'\1 \2', s)).lower()
return sub
def remove(s):
return s.replace("_", " ").replace("-", " ").replace(".", " ")
s = remove(uncamelize(s)).split()
if len(s) > 1 and s[0] == s[1]:
s = s[1:]
return " ".join(s)
def load_dict_ini_file(filename):
print("Warning: `load_dict_ini_file` is deprecated.")
if not os.path.exists(filename):
sys.stderr.write(f'[warning] cannot load label mapper from {filename}\n')
return {}
mapper = dict()
for line in open(filename):
key, value = line.strip().split('=')
mapper[key] = label_format(value)
return mapper
def change_ptb_token_back(token):
"""将 PTBTokenized 的 Token 转换会原始字符串
Args:
token (str): PTBTokenize 后的 Token 字符串
Returns:
str: 原始 Token 字符串
"""
ptb_token_map = {
'``': '"',
"''": '"',
'-LRB-': '(',
'-RRB-': ')',
'-LSB-': '[',
'-RSB-': ']',
'-LCB-': '{',
'-RCB-': '}',
}
for ptb_token, raw_token in ptb_token_map.items():
if token == ptb_token:
return raw_token
return token
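# Illustration (not in the original file): change_ptb_token_back('-LRB-') returns '(';
# any token that is not a PTB escape is returned unchanged.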
def change_name_using_label_mapper(label_name, label_mapper):
if label_mapper is None or len(label_mapper) == 0:
return label_name
if label_name not in label_mapper:
print(f"{label_name} not found in mapper")
global global_mislabel_log
if label_name not in global_mislabel_log:
global_mislabel_log.add(label_name)
return label_mapper.get(label_name, label_name)

View File

@ -0,0 +1,29 @@
FROM nvidia/cuda:11.0.3-cudnn8-devel-ubuntu18.04
LABEL maintainer="Yaojie Lu"
LABEL repository="uie"
RUN apt update && \
apt install -y bash \
build-essential \
git \
curl \
ca-certificates \
python3 \
python3-pip && \
rm -rf /var/lib/apt/lists
WORKDIR /pre_env
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
python3 -m pip install --no-cache-dir mkl && \
python3 -m pip install --no-cache-dir torch==1.7.1+cu110 -f https://download.pytorch.org/whl/torch_stable.html
RUN git clone https://github.com/NVIDIA/apex
RUN cd apex && \
python3 setup.py install && \
python3 -m pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
COPY ./requirements.txt .
RUN python3 -m pip install -r ./requirements.txt
CMD ["/bin/bash"]

View File

@ -0,0 +1,53 @@
# Data Description
``` json
{
"text": "MULTAN , Pakistan , April 27 ( AFP )",
"tokens": ["MULTAN", ",", "Pakistan", ",", "April", "27", "(", "AFP", ")"],
"record": "<extra_id_0> <extra_id_0> geographical social political <extra_id_5> MULTAN <extra_id_0> part whole <extra_id_5> Pakistan <extra_id_1> <extra_id_1> <extra_id_0> geographical social political <extra_id_5> Pakistan <extra_id_1> <extra_id_0> organization <extra_id_5> AFP <extra_id_1> <extra_id_1>",
"entity": [
{"type": "geographical social political", "offset": [0], "text": "MULTAN"},
{"type": "geographical social political", "offset": [2], "text": "Pakistan"},
{"type": "organization", "offset": [7], "text": "AFP"}
],
"relation": [
{
"type": "part whole",
"args": [
{"type": "geographical social political", "offset": [0], "text": "MULTAN"},
{"type": "geographical social political", "offset": [2], "text": "Pakistan"}
]
}
],
"event": [],
"spot": ["geographical social political", "organization"],
"asoc": ["part whole"],
"spot_asoc": [
{
"span": "MULTAN",
"label": "geographical social political",
"asoc": [["part whole", "Pakistan"]]
},
{
"span": "Pakistan",
"label": "geographical social political", "asoc": []
},
{
"span": "AFP", "label": "organization", "asoc": []
}
],
"task": 'record'
}
```
- task: `seq`, `record`, `t5mlm`
- mlm only requires Text
- seq only requires Record
- record requires paired Text-record data
- if `task` is absent, the instance defaults to Text-record data
- spot, asoc
- the positive (gold) types that appear in the text
- spot_asoc
- the structured representation of the record
- entity, relation, event
- gold answers with offsets, used for model evaluation
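Below is a minimal Python sketch (not part of the repository) of how records in this format can be loaded and sanity-checked. It assumes one JSON object per line, English whitespace tokenization, and a hypothetical file name `train.json`.
``` python
import json

def read_records(path):
    """Yield one UIE-style record dict per line of a JSONL file."""
    with open(path) as fin:
        for line in fin:
            yield json.loads(line)

def check_record(record):
    """Verify that every entity offset points at the tokens of its mention text."""
    tokens = record["tokens"]
    for entity in record.get("entity", []):
        mention = " ".join(tokens[i] for i in entity["offset"])
        assert mention == entity["text"], (mention, entity["text"])

if __name__ == "__main__":
    for record in read_records("train.json"):  # hypothetical path
        check_record(record)
        print(record["spot"], record["asoc"])
```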

View File

@ -0,0 +1,41 @@
# Tools for UIE
### Evaluate Model Performance
Evaluate model performance (eval_extraction.py)
```text
$ python scripts/eval_extraction.py -h
usage: eval_extraction.py [-h] [-g GOLD_FOLDER] [-p PRED_FOLDER [PRED_FOLDER ...]] [-v] [-w] [-m] [-case]
optional arguments:
-h, --help show this help message and exit
-g GOLD_FOLDER Golden Dataset folder
-p PRED_FOLDER [PRED_FOLDER ...]
Predicted model folder
-v Show more information during running
-w Write evaluation results to predicted folder
-m Match predicted result multiple times
-case Show case study
```
### Check Offset Mapping Performance
Check the performance of offset re-mapping (check_offset_map_gold_as_pred.bash)
``` bash
bash scripts/check_offset_map_gold_as_pred.bash <data-folder> <map-config>
```
### Convert SEL to Record
Convert structured expressions (SEL) into Record structures (sel2record.py)
``` text
$ python scripts/sel2record.py -h
usage: sel2record.py [-h] [-g GOLD_FOLDER] [-p PRED_FOLDER [PRED_FOLDER ...]] [-c MAP_CONFIG] [-d DECODING] [-v]
optional arguments:
-h, --help show this help message and exit
-g GOLD_FOLDER Gold (reference) folder
-p PRED_FOLDER [PRED_FOLDER ...]
One or more prediction (Pred) folders
-c MAP_CONFIG, --config MAP_CONFIG
Config file for the offset matching strategy
-d DECODING Use the SpotAsoc parser to decode structured expressions
-v, --verbose Print more detailed log information
```

Binary file not shown.

Binary file not shown.


163
metaretriever/inference.py Normal file
View File

@ -0,0 +1,163 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import json
import re
from tqdm import tqdm
import transformers as huggingface_transformers
from uie.extraction.record_schema import RecordSchema
from uie.sel2record.record import MapConfig
from uie.extraction.scorer import *
from uie.sel2record.sel2record import SEL2Record
import math
import os
split_bracket = re.compile(r"\s*<extra_id_\d>\s*")
special_to_remove = {'<pad>', '</s>'}
def read_json_file(file_name):
return [json.loads(line) for line in open(file_name)]
def schema_to_ssi(schema: RecordSchema):
ssi = "<spot> " + "<spot> ".join(sorted(schema.type_list))
ssi += "<asoc> " + "<asoc> ".join(sorted(schema.role_list))
ssi += "<extra_id_2> "
return ssi
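# Illustration only (hypothetical labels): for a schema with spots
# ['organization', 'person'] and asocs ['work for'], the SSI string is
# "<spot> organization<spot> person<asoc> work for<extra_id_2> ",
# and predict() below prepends it to every input text.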
def post_processing(x):
for special in special_to_remove:
x = x.replace(special, '')
return x.strip()
class HuggingfacePredictor:
def __init__(self, model_path, schema_file, max_source_length=256, max_target_length=192) -> None:
self._tokenizer = huggingface_transformers.T5TokenizerFast.from_pretrained(
model_path)
self._model = huggingface_transformers.T5ForConditionalGeneration.from_pretrained(
model_path)
self._model.cuda()
self._schema = RecordSchema.read_from_file(schema_file)
self._ssi = schema_to_ssi(self._schema)
self._max_source_length = max_source_length
self._max_target_length = max_target_length
def predict(self, text):
text = [self._ssi + x for x in text]
inputs = self._tokenizer(
text, padding=True, return_tensors='pt').to(self._model.device)
inputs['input_ids'] = inputs['input_ids'][:, :self._max_source_length]
inputs['attention_mask'] = inputs['attention_mask'][:,
:self._max_source_length]
result = self._model.generate(
input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask'],
max_length=self._max_target_length,
)
return self._tokenizer.batch_decode(result, skip_special_tokens=False, clean_up_tokenization_spaces=False)
task_dict = {
'entity': EntityScorer,
'relation': RelationScorer,
'event': EventScorer,
}
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
'--data', '-d', default='data/text2spotasoc/absa/14lap')
parser.add_argument(
'--model', '-m', default='./models/uie_n10_21_50w_absa_14lap')
parser.add_argument('--max_source_length', default=256, type=int)
parser.add_argument('--max_target_length', default=192, type=int)
parser.add_argument('--batch_size', default=16, type=int)
parser.add_argument('-c', '--config', dest='map_config',
help='Offset Re-mapping Config',
default='config/offset_map/closest_offset_en.yaml')
parser.add_argument('--decoding', default='spotasoc')
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--match_mode', default='normal',
choices=['set', 'normal', 'multimatch'])
options = parser.parse_args()
data_folder = options.data
model_path = options.model
predictor = HuggingfacePredictor(
model_path=model_path,
schema_file=f"{data_folder}/record.schema",
max_source_length=options.max_source_length,
max_target_length=options.max_target_length,
)
map_config = MapConfig.load_from_yaml(options.map_config)
schema_dict = SEL2Record.load_schema_dict(data_folder)
sel2record = SEL2Record(
schema_dict=schema_dict,
decoding_schema=options.decoding,
map_config=map_config,
)
for split, split_name in [('val', 'eval'), ('test', 'test')]:
gold_filename = f"{data_folder}/{split}.json"
text_list = [x['text'] for x in read_json_file(gold_filename)]
token_list = [x['tokens'] for x in read_json_file(gold_filename)]
batch_num = math.ceil(len(text_list) / options.batch_size)
predict = list()
for index in tqdm(range(batch_num)):
start = index * options.batch_size
end = index * options.batch_size + options.batch_size
pred_seq2seq = predictor.predict(text_list[start: end])
pred_seq2seq = [post_processing(x) for x in pred_seq2seq]
predict += pred_seq2seq
records = list()
for p, text, tokens in zip(predict, text_list, token_list):
r = sel2record.sel2record(pred=p, text=text, tokens=tokens)
records += [r]
results = dict()
for task, scorer in task_dict.items():
gold_list = [x[task] for x in read_json_file(gold_filename)]
pred_list = [x[task] for x in records]
gold_instance_list = scorer.load_gold_list(gold_list)
pred_instance_list = scorer.load_pred_list(pred_list)
sub_results = scorer.eval_instance_list(
gold_instance_list=gold_instance_list,
pred_instance_list=pred_instance_list,
verbose=options.verbose,
match_mode=options.match_mode,
)
results.update(sub_results)
with open(os.path.join(options.model, f'{split_name}_preds_record.txt'), 'w') as output:
for record in records:
output.write(f'{json.dumps(record)}\n')
with open(os.path.join(options.model, f'{split_name}_preds_seq2seq.txt'), 'w') as output:
for pred in predict:
output.write(f'{pred}\n')
with open(os.path.join(options.model, f'{split_name}_results.txt'), 'w') as output:
for key, value in results.items():
output.write(f'{split_name}_{key}={value}\n')
if __name__ == "__main__":
main()

View File

View File

@ -0,0 +1,46 @@
from tensorboard.backend.event_processing import event_accumulator
import matplotlib.pyplot as plt
def read_tensorboard_data(tensorboard_log_path, val_name):
ea = event_accumulator.EventAccumulator(tensorboard_log_path)
ea.Reload()
print("All scalers:")
print(ea.scalars.Keys())
val = ea.scalars.Items(val_name)
return val
def plot(vals, val_names, max_step=None):
plt.figure()
for val, val_name in zip(vals, val_names):
x = [i.step for i in val]
y = [i.value for i in val]
if max_step is not None:
x = [i for i in x if i < max_step]
y = y[:len(x)]
plt.plot(x, y, label=val_name)
plt.xlabel("step")
plt.ylabel("loss")
plt.legend()
plt.show()
if __name__ == "__main__":
refine_uie_tensorboard_log_path = "tensorboard_logs/events.out.tfevents.1654419004.dsw32050-7df697f45c-6bwkm.44438.0"
refine_t5_tensorboard_log_path = "tensorboard_logs/events.out.tfevents.1654361305.g64h07153.cloud.sqa.nt12.129194.0"
uie_t5_tensorboard_log_path = "tensorboard_logs/events.out.tfevents.1654275965.eflops-common033255085104.NT12.106708.0"
val_name = "train/loss"
refine_uie_val = read_tensorboard_data(refine_uie_tensorboard_log_path, val_name)
refine_t5_val = read_tensorboard_data(refine_t5_tensorboard_log_path, val_name)
uie_t5_val = read_tensorboard_data(uie_t5_tensorboard_log_path, val_name)
vals = [refine_uie_val, refine_t5_val, uie_t5_val]
val_names = ["refine_uie_loss", "refine_t5_loss", "uie_t5_loss"]
max_step = 20000
plot(vals, val_names, max_step=max_step)

View File

@ -0,0 +1,173 @@
absl-py==1.0.0
altair==4.2.0
anyio==3.5.0
anytree==2.8.0
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
asgiref==3.5.0
astor==0.8.1
asttokens==2.0.5
attrs==21.4.0
autopep8==1.6.0
backcall==0.2.0
backports.zoneinfo==0.2.1
base58==2.1.1
black==21.12b0
bleach==4.1.0
blinker==1.4
cachetools==4.2.4
certifi==2021.10.8
cffi==1.15.0
charset-normalizer==2.0.10
click==8.0.3
colorama==0.4.4
colorlog==6.6.0
commonmark==0.9.1
conllu==4.4.1
cycler==0.11.0
dataclasses==0.6
datasets==1.9.0
debugpy==1.5.1
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.4
elasticsearch==7.16.3
entrypoints==0.3
executing==0.8.2
faiss-cpu==1.7.2
fastapi==0.74.1
filelock==3.0.12
fire==0.4.0
fonttools==4.28.5
fsspec==2022.1.0
future==0.18.2
git-python==1.0.3
gitdb==4.0.9
GitPython==3.1.26
google-auth==2.3.3
google-auth-oauthlib==0.4.6
googleapis-common-protos==1.54.0
grpcio==1.43.0
h11==0.13.0
h5py==3.6.0
huggingface-hub==0.0.8
idna==3.3
importlib-metadata==4.10.1
importlib-resources==5.4.0
iniconfig==1.1.1
ipykernel==6.7.0
ipython
ipython-genutils==0.2.0
ipywidgets==7.6.5
jedi==0.18.1
jieba==0.42.1
Jinja2==3.0.3
joblib==1.1.0
jsonschema==4.4.0
jupyter-client==7.1.1
jupyter-core==4.9.1
jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.0.2
kiwisolver==1.3.2
Markdown==3.3.6
MarkupSafe==2.0.1
matplotlib==3.5.1
matplotlib-inline==0.1.3
mistune==0.8.4
multiprocess==0.70.12.2
mypy-extensions==0.4.3
nbclient==0.5.10
nbconvert==6.4.0
nbformat==5.1.3
nest-asyncio==1.5.4
nltk==3.6.7
notebook==6.4.7
numpy==1.19.5
oauthlib==3.1.1
packaging==21.3
pandas==1.3.5
pandocfilters==1.5.0
parso==0.8.3
pathspec==0.9.0
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.0.0
platformdirs==2.4.1
pluggy==1.0.0
portalocker==2.3.2
prometheus-client==0.12.0
promise==2.3
prompt-toolkit==3.0.24
protobuf==3.19.3
psutil==5.9.0
ptyprocess==0.7.0
pure-eval==0.2.1
py==1.11.0
pyarrow==4.0.1
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycodestyle==2.8.0
pycparser==2.21
pydantic==1.9.0
pydeck==0.7.1
Pygments==2.11.2
Pympler==1.0.1
pyparsing==3.0.6
pyrsistent==0.18.1
pytest==6.2.5
python-dateutil==2.8.2
pytz==2021.3
pytz-deprecation-shim==0.1.0.post0
PyYAML==6.0
pyzmq==22.3.0
regex==2022.1.18
requests==2.27.1
requests-oauthlib==1.3.0
rich==9.8.2
rouge-score==0.0.4
rsa==4.8
sacrebleu==1.4.14
sacremoses==0.0.47
scikit-learn==1.0.2
scipy==1.7.3
Send2Trash==1.8.0
sentencepiece==0.1.96
seqeval==1.2.2
six==1.16.0
smmap==5.0.0
sniffio==1.2.0
stack-data==0.1.4
starlette==0.17.1
streamlit==1.4.0
tabulate==0.8.9
tensorboard==2.7.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow-datasets==4.4.0
tensorflow-metadata==1.6.0
termcolor==1.1.0
terminado==0.12.1
testpath==0.5.0
threadpoolctl==3.0.0
tokenizers==0.10.3
toml==0.10.2
tomli==1.2.3
toolz==0.11.2
tornado==6.1
tqdm==4.62.3
traitlets==5.1.1
transformers==4.6.1
typing-extensions==3.10.0.2
tzdata==2021.5
tzlocal==4.1
urllib3==1.26.8
uvicorn==0.17.5
validators==0.18.2
watchdog==2.1.6
wcwidth==0.2.5
webencodings==0.5.1
Werkzeug==2.0.2
widgetsnbextension==3.5.2
xxhash==2.0.2
zipp==3.7.0
learn2learn

View File

@ -0,0 +1,87 @@
device="0"
model_path=""
data_folder=data/text2spotasoc/absa/14lap
task_name="meta"
batch=16
decoding_format='spotasoc'
beam_size=1
map_config=config/offset_map/closest_offset_en.yaml
export PYTHONPATH="${PYTHONPATH}:./"
OPTS=$(getopt -o b:d:m:i:t:co:f:e: --long batch:,device:,model:,data:,task:,constraint_decoding,output:,format:,beam:,map_config:,extra_cmd:, -n 'parse-options' -- "$@")
if [ $? != 0 ]; then
echo "Failed parsing options." >&2
exit 1
fi
eval set -- "$OPTS"
while true; do
case "$1" in
-b | --batch) batch="$2"
shift 2 ;;
-d | --device) device="$2"
shift 2 ;;
-m | --model) model_path="$2"
shift 2 ;;
-i | --data) data_folder="$2"
shift 2 ;;
-t | --task) task_name="$2"
shift 2 ;;
-c | --constraint_decoding) constraint_decoding="--constraint_decoding"
shift ;;
-o | --output) output_dir="$2"
shift 2 ;;
-f | --format) decoding_format="$2"
shift 2 ;;
-e | --extra_cmd) extra_cmd="$2"
shift 2 ;;
--beam) beam_size="$2"
shift 2 ;;
--map_config) map_config="$2"
shift 2 ;;
--)
shift
break
;;
*)
echo "$1" not recognize.
exit
;;
esac
done
echo "Extra CMD: " "${extra_cmd}"
if [[ ${output_dir} == "" ]]
then
output_dir=${model_path}_eval
if [[ ${constraint_decoding} != "" ]]
then
output_dir=${output_dir}_CD
fi
fi
CUDA_VISIBLE_DEVICES=${device} python3 run_seq2seq.py \
--use_fast_tokenizer=True \
--max_source_length=${max_source_length:-"256"} \
--max_target_length=${max_target_length:-"192"} \
--do_eval --do_predict --task=record --predict_with_generate \
--validation_file=${data_folder}/val.json \
--test_file=${data_folder}/test.json \
--record_schema=${data_folder}/record.schema \
--model_name_or_path=${model_path} \
--output_dir=${output_dir} \
--source_prefix="${task_name}: " \
--no_remove_unused_columns \
--num_beams=${beam_size} \
${constraint_decoding} ${extra_cmd} \
--per_device_eval_batch_size=${batch} \
--decoding_format ${decoding_format}
python3 scripts/sel2record.py -p ${output_dir} -g ${data_folder} -v -d ${decoding_format} -c ${map_config}
python3 scripts/eval_extraction.py -p ${output_dir} -g ${data_folder} -w -m ${eval_match_mode:-"normal"}
echo "Output Dir:" ${output_dir}

View File

@ -0,0 +1,779 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
import logging
import os
import sys
from dataclasses import dataclass, field
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from typing import Optional
import numpy as np
from datasets import load_dataset
import transformers
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
default_data_collator,
set_seed
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from uie.extraction import constants
from uie.extraction.record_schema import RecordSchema
from uie.extraction.predict_parser import decoding_format_dict
from uie.extraction.extraction_metrics import get_extract_metrics
from uie.extraction.noiser.spot_asoc_noiser import SpotAsocNoiser
from uie.extraction.dataset_processer import PrefixGenerator
from uie.seq2seq.constrained_seq2seq import (
ConstraintSeq2SeqTrainingArguments,
ConstraintSeq2SeqTrainer,
OriginalConstraintSeq2SeqTrainer,
UIEPretrainConstraintSeq2SeqTrainer,
UIEFinetuneConstraintSeq2SeqTrainer,
MetaPretrainConstraintSeq2SeqTrainer,
MetaFinetuneConstraintSeq2SeqTrainer,
)
from uie.seq2seq.data_collator import (
DataCollatorForMetaSeq2Seq,
DynamicSSIGenerator,
)
from uie.seq2seq.features import RecordFeature
from uie.seq2seq.model import PromptSeq2SeqTransformer
from uie.seq2seq.noise_record import create_noised_record
import pdb
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=False,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
"with private models)."
},
)
from_checkpoint: bool = field(
default=False, metadata={"help": "Whether load from checkpoint to continue learning"}
)
load_config_only: bool = field(
default=False, metadata={"help": "Whether load model config only from checkpoint"}
)
use_prompt_tuning_model: bool = field(
default=False, metadata={"help": "Whether use prompt tuning model"}
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
task: str = field(
default="summarization",
metadata={
"help": "The name of the task, should be summarization (or summarization_{dataset} for evaluating "
"pegasus) or translation (or translation_{xx}_to_{yy})."
},
)
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
text_column: Optional[str] = field(
default='text',
metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
)
record_column: Optional[str] = field(
default='record',
metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
)
train_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
)
validation_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input evaluation data file to evaluate the metrics (rouge/sacreblue) on "
"(a jsonlines or csv file)."
},
)
test_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input test data file to evaluate the metrics (rouge/sacreblue) on "
"(a jsonlines or csv file)."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_source_length: Optional[int] = field(
default=1024,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
max_prefix_length: Optional[int] = field(
default=None,
metadata={
"help": "The maximum prefix length."
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
"during ``evaluate`` and ``predict``."
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
"help": "Whether to pad all samples to model maximum sentence length. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
"efficient on GPU but very bad for TPU."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_val_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"value if set."
},
)
max_test_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"value if set."
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
"which is used during ``evaluate`` and ``predict``."
},
)
ignore_pad_token_for_loss: bool = field(
default=True,
metadata={
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
},
)
source_prefix: Optional[str] = field(
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)
meta_negative: int = field(
default=-1, metadata={"help": "Negative Schema Number in Training."}
)
ordered_prompt: bool = field(
default=True,
metadata={
"help": "Whether to sort the spot prompt and asoc prompt or not."
},
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
if self.val_max_target_length is None:
self.val_max_target_length = self.max_target_length
decoding_format: str = field(
default='tree',
metadata={"help": "Decoding Format, valid in %s" % decoding_format_dict.keys()}
)
record_schema: str = field(
default=None, metadata={"help": "The input event schema file."}
)
spot_noise: float = field(
default=0., metadata={"help": "The noise rate of null spot."}
)
asoc_noise: float = field(
default=0., metadata={"help": "The noise rate of null asoc."}
)
meta_positive_rate: float = field(
default=1., metadata={"help": "The keep rate of positive spot."}
)
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ConstraintSeq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
elif last_checkpoint is not None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
logger.info("Options:")
logger.info(model_args)
logger.info(data_args)
logger.info(training_args)
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
# Set the verbosity to info of the Transformers logger (on main process only):
if is_main_process(training_args.local_rank):
transformers.utils.logging.set_verbosity_info()
logger.info("Training/evaluation parameters %s", training_args)
# Set seed before initializing model.
set_seed(training_args.seed)
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files in the summarization task, this script will use the first column for the full texts and the
# second column for the summaries (unless you specify column names for this with the `text_column` and
# `record_column` arguments).
# For translation, only JSON files are supported, with one field named "translation" containing two keys for the
# source and target languages (unless you adapt what follows).
#
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
# download the dataset.
if data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
else:
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
extension = data_args.train_file.split(".")[-1]
if training_args.do_eval and data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.validation_file.split(".")[-1]
if training_args.do_predict and data_args.test_file is not None:
data_files["test"] = data_args.test_file
extension = data_args.test_file.split(".")[-1]
logger.info(data_files)
datasets = load_dataset("uie_json.py", data_files=data_files, block_size=(10<<22))
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
logger.info(datasets)
# Load pretrained model and tokenizer
#
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
logger.info("Load Config: %s" % model_args.config_name if model_args.config_name else model_args.model_name_or_path)
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
config.max_length = data_args.max_target_length
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
to_remove_token_list = list()
if tokenizer.bos_token:
to_remove_token_list += [tokenizer.bos_token]
if tokenizer.eos_token:
to_remove_token_list += [tokenizer.eos_token]
if tokenizer.pad_token:
to_remove_token_list += [tokenizer.pad_token]
if model_args.use_prompt_tuning_model:
MODEL = PromptSeq2SeqTransformer
else:
MODEL = AutoModelForSeq2SeqLM
if model_args.load_config_only:
model = MODEL.from_config(config)
else:
model = MODEL.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
mirror='tuna',
)
if training_args.do_train:
to_add_special_token = list()
for special_token in [constants.type_start, constants.type_end, constants.text_start, constants.span_start, constants.spot_prompt, constants.asoc_prompt]:
if special_token not in tokenizer.get_vocab():
to_add_special_token += [special_token]
tokenizer.add_special_tokens(
{"additional_special_tokens": tokenizer.special_tokens_map_extended['additional_special_tokens'] + to_add_special_token}
)
model.resize_token_embeddings(len(tokenizer))
logger.info(tokenizer)
# Set decoder_start_token_id
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
if data_args.record_schema and os.path.exists(data_args.record_schema):
record_schema = RecordSchema.read_from_file(data_args.record_schema)
else:
record_schema = None
if data_args.source_prefix is not None:
if data_args.source_prefix == 'schema':
prefix = PrefixGenerator.get_schema_prefix(schema=record_schema)
elif data_args.source_prefix.startswith('meta'):
prefix = ""
else:
prefix = data_args.source_prefix
else:
prefix = ""
logger.info(f"Prefix: {prefix}")
logger.info(f"Prefix Length: {len(tokenizer.tokenize(prefix))}")
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
if training_args.do_train:
column_names = datasets["train"].column_names
elif training_args.do_eval:
column_names = datasets["validation"].column_names
elif training_args.do_predict:
column_names = datasets["test"].column_names
else:
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
return
# To serialize preprocess_function below, each of those four variables needs to be defined (even if we won't use
# them all).
text_column = data_args.text_column
record_column = data_args.record_column
logger.info('Using src: %s and tgt: %s' % (text_column, record_column))
# Temporarily set max_target_length for training.
max_target_length = data_args.max_target_length
padding = "max_length" if data_args.pad_to_max_length else False
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
logger.error(
"label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
)
def preprocess_function(examples):
inputs = examples[text_column]
targets = examples[record_column]
inputs = [prefix + inp for inp in inputs]
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
model_inputs["text"] = inputs
# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(_label if _label != tokenizer.pad_token_id else -100) for _label in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["input_ids"]
# set noised record inputs
noised_record_list = []
for idx, noised_record in enumerate(examples["noised_record"]):
if noised_record is None:
tokens = examples["tokens"][idx]
entity_list = examples["entity"][idx]
triple_list = examples["relation"][idx]
event_list = examples["event"][idx]
noised_record = create_noised_record(tokens, entity_list, triple_list, event_list)
noised_record_list.append(noised_record)
model_inputs["noised_record"] = noised_record_list
# model_inputs["noised_record"] = examples["noised_record"]
# others
model_inputs['sample_prompt'] = [False] * len(model_inputs['input_ids'])
if data_args.source_prefix is not None and data_args.source_prefix.startswith('meta'):
model_inputs['spots'] = examples['spot']
model_inputs['asocs'] = examples['asoc']
model_inputs['spot_asoc'] = examples['spot_asoc']
# sample_prompt=True for Finetune and Pretrain
model_inputs['sample_prompt'] = [True] * len(model_inputs['input_ids'])
return model_inputs
def preprocess_function_eval(examples):
model_inputs = preprocess_function(examples)
# sample_prompt=False for evaluation
model_inputs['sample_prompt'] = [False] * len(model_inputs['input_ids'])
return model_inputs
def postprocess_text(x_str):
# Clean `bos` `eos` `pad` for cleaned text
for to_remove_token in to_remove_token_list:
x_str = x_str.replace(to_remove_token, '')
return x_str.strip()
logger.info("Start Data Preprocessing ...")
if training_args.do_train:
train_dataset = datasets["train"]
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
features=RecordFeature,
)
if training_args.do_eval:
max_target_length = data_args.val_max_target_length
eval_dataset = datasets["validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
eval_dataset = eval_dataset.map(
preprocess_function_eval,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
features=RecordFeature,
)
if training_args.do_predict:
max_target_length = data_args.val_max_target_length
test_dataset = datasets["test"]
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
test_dataset = test_dataset.map(
preprocess_function_eval,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
features=RecordFeature,
)
logger.info("End Data Preprocessing ...")
# Data collator
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
if data_args.pad_to_max_length:
data_collator = default_data_collator
elif data_args.source_prefix.startswith('meta'):
if data_args.spot_noise > 0 or data_args.asoc_noise > 0:
if data_args.decoding_format == 'spotasoc':
spot_asoc_nosier = SpotAsocNoiser(
spot_noise_ratio=data_args.spot_noise,
asoc_noise_ratio=data_args.asoc_noise,
null_span=constants.null_span,
)
else:
raise NotImplementedError(
f"decoding_format `{data_args.decoding_format}` is not implemented."
)
else:
spot_asoc_nosier = None
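# The meta collator assembles the spot/asoc prompt per batch through the
# DynamicSSIGenerator negative sampler and, when a noiser is configured,
# injects rejection noise into the target records via SpotAsocNoiser.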
data_collator = DataCollatorForMetaSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=8 if training_args.fp16 else None,
max_length=data_args.max_source_length,
max_prefix_length=data_args.max_prefix_length,
max_target_length=data_args.max_target_length,
negative_sampler=DynamicSSIGenerator(
tokenizer=tokenizer,
schema=record_schema,
positive_rate=data_args.meta_positive_rate,
negative=data_args.meta_negative,
ordered_prompt=data_args.ordered_prompt,
),
spot_asoc_nosier=spot_asoc_nosier,
decoding_format=data_args.decoding_format,
)
else:
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=8 if training_args.fp16 else None,
)
def compute_metrics(eval_preds):
preds, labels = eval_preds
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=False, clean_up_tokenization_spaces=False)
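# Special tokens are kept during decoding; postprocess_text() later strips
# only bos/eos/pad so that the remaining structure markers stay parseable.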
if data_args.ignore_pad_token_for_loss:
# Replace -100 in the labels as we can't decode them.
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=False, clean_up_tokenization_spaces=False)
decoded_preds = [postprocess_text(x) for x in decoded_preds]
decoded_labels = [postprocess_text(x) for x in decoded_labels]
result = get_extract_metrics(
pred_lns=decoded_preds,
tgt_lns=decoded_labels,
label_constraint=record_schema,
decoding_format=data_args.decoding_format,
)
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
result["gen_len"] = np.mean(prediction_lens)
result = {k: round(v, 4) for k, v in result.items()}
return result
# Initialize our Trainer
if training_args.trainer_type == "uie_pretrain":
TRAINER = UIEPretrainConstraintSeq2SeqTrainer
elif training_args.trainer_type == "uie_finetune":
TRAINER = UIEFinetuneConstraintSeq2SeqTrainer
elif training_args.trainer_type == "meta_pretrain":
TRAINER = MetaPretrainConstraintSeq2SeqTrainer
elif training_args.trainer_type == "meta_finetune":
TRAINER = MetaFinetuneConstraintSeq2SeqTrainer
else:
TRAINER = OriginalConstraintSeq2SeqTrainer
trainer = TRAINER(
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
decoding_type_schema=record_schema,
decoding_format=data_args.decoding_format,
source_prefix=prefix,
task=data_args.task,
)
# Training
if training_args.do_train:
if model_args.from_checkpoint:
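# Resume priority: the last checkpoint auto-detected in output_dir,
# then model_name_or_path if it is a local directory, else no resume.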
if last_checkpoint is not None:
checkpoint = last_checkpoint
elif os.path.isdir(model_args.model_name_or_path):
checkpoint = model_args.model_name_or_path
else:
checkpoint = None
else:
checkpoint = None
train_result = trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
if trainer.is_world_process_zero():
with open(output_train_file, "w") as writer:
logger.info("***** Train results *****")
for key, value in sorted(train_result.metrics.items()):
logger.info(f" {key} = {value}")
writer.write(f"{key} = {value}\n")
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
# Evaluation
results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")
results = trainer.evaluate(max_length=data_args.val_max_target_length, num_beams=data_args.num_beams)
results = {k: round(v, 4) for k, v in results.items()}
eval_results = trainer.predict(
eval_dataset,
metric_key_prefix="eval",
max_length=data_args.val_max_target_length,
num_beams=data_args.num_beams,
)
output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt")
if trainer.is_world_process_zero():
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key, value in sorted(results.items()):
logger.info(f" {key} = {value}")
writer.write(f"{key} = {value}\n")
if training_args.predict_with_generate:
eval_preds = tokenizer.batch_decode(
eval_results.predictions, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
eval_preds = [postprocess_text(pred) for pred in eval_preds]
output_test_preds_file = os.path.join(training_args.output_dir, "eval_preds_seq2seq.txt")
with open(output_test_preds_file, "w") as writer:
writer.write("\n".join(eval_preds))
if training_args.do_predict:
logger.info("*** Test ***")
test_results = trainer.predict(
test_dataset,
metric_key_prefix="test",
max_length=data_args.val_max_target_length,
num_beams=data_args.num_beams,
)
test_metrics = test_results.metrics
test_metrics["test_loss"] = round(test_metrics["test_loss"], 4)
output_test_result_file = os.path.join(training_args.output_dir, "test_results_seq2seq.txt")
if trainer.is_world_process_zero():
with open(output_test_result_file, "w") as writer:
logger.info("***** Test results *****")
for key, value in sorted(test_metrics.items()):
logger.info(f" {key} = {value}")
writer.write(f"{key} = {value}\n")
if training_args.predict_with_generate:
test_preds = tokenizer.batch_decode(
test_results.predictions, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
test_preds = [postprocess_text(pred) for pred in test_preds]
output_test_preds_file = os.path.join(training_args.output_dir, "test_preds_seq2seq.txt")
with open(output_test_preds_file, "w") as writer:
writer.write("\n".join(test_preds))
return results
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()


@ -0,0 +1,104 @@
#!/usr/bin/env bash
# -*- coding:utf-8 -*-
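# Training-only launcher: no eval/predict pass, a checkpoint is saved every
# 10000 steps and training resumes from the latest checkpoint if one exists.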
export batch_size="16"
export model_name=uie-base-en
export data_name=absa/14lap
export task_name="meta"
export decoding_format='spotasoc'
source scripts/function_code.bash
for index in $(seq 1 ${run_time}); do
if [[ ! ${output_dir} ]]
then
output_dir=${model_folder}_run${index}
echo "output_dir is not provided, so it is created automatically: ${output_dir}"
else
echo "output_dir is provided: ${output_dir}"
fi
if [[ ${verbose} == true ]]
then
stdout_file=/dev/stdout
stderr_file=/dev/stderr
disable_tqdm=False
else
stdout_file=${output_dir}.log
stderr_file=${output_dir}.err
disable_tqdm=True
fi
# CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} gdb --args ${run_command} run_seq2seq.py \
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} ${run_command} run_seq2seq.py \
--do_train ${constraint_decoding} ${fp16} \
--trainer_type=${trainer_type} \
--load_config_only=False \
--use_fast_tokenizer=True \
--ddp_find_unused_parameters=False \
--predict_with_generate \
--evaluation_strategy="no" \
--metric_for_best_model eval_overall-F1 \
--save_strategy="steps" \
--save_steps=10000 \
--save_total_limit 9999999 \
--load_best_model_at_end=False \
--max_source_length="128" \
--max_prefix_length="-1" \
--max_target_length="128" \
--num_train_epochs=${epoch} \
--task=${task_name} \
--train_file=${data_folder}/train.json \
--validation_file=${data_folder}/val.json \
--test_file=${data_folder}/test.json \
--record_schema=${data_folder}/record.schema \
--per_device_train_batch_size=${batch_size} \
--per_device_eval_batch_size=$((batch_size * 4)) \
--output_dir=${output_dir} \
--from_checkpoint=True \
--logging_dir=${output_dir}_log \
--logging_strategy="steps" \
--logging_first_step=True \
--logging_steps=100 \
--model_name_or_path=${model_name} \
--learning_rate=${lr} \
--source_prefix="${task_name}: " \
--lr_scheduler_type=${lr_scheduler} \
--label_smoothing_factor=${label_smoothing} \
--eval_steps ${eval_steps} \
--decoding_format ${decoding_format} \
--warmup_ratio ${warmup_ratio} \
--preprocessing_num_workers=32 \
--dataloader_num_workers=32 \
--meta_negative=10 \
--meta_positive_rate=${positive} \
--skip_memory_metrics \
--no_remove_unused_columns \
--ordered_prompt=${ordered_prompt} \
--save_better_checkpoint=False \
--start_eval_step=${start_eval_step:-"0"} \
--spot_noise=${spot_noise} \
--asoc_noise=${asoc_noise} \
--seed=${seed}${index} --disable_tqdm=${disable_tqdm} >${stdout_file} 2>${stderr_file}
echo "exit code:" $?
# --max_source_length=${max_source_length:-"128"} \
# --max_prefix_length=${max_prefix_length:-"-1"} \
# --max_target_length=${max_target_length:-"128"} \
# --save_strategy=${evaluation_strategy} \
# --save_total_limit 1 \
# --load_best_model_at_end \
if [[ ${verbose} != true ]]
then
tail -n 200 ${stderr_file}
fi
# echo "Map Config" ${map_config}
# python3 scripts/sel2record.py -p ${output_dir} -g ${data_folder} -v -d ${decoding_format} -c ${map_config}
# python3 scripts/eval_extraction.py -p ${output_dir} -g ${data_folder} -w -m ${eval_match_mode:-"normal"}
# delete all optimizer.pt for saving disk
find ${output_dir}/ | grep -P "optimizer.pt" | xargs rm -rf
done


@ -0,0 +1,107 @@
#!/usr/bin/env bash
# -*- coding:utf-8 -*-
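# Single-run fine-tuning launcher: trains, evaluates and predicts on one
# dataset, keeps only the best checkpoint by eval_overall-F1, then converts
# the generated SEL strings to records and scores them.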
export batch_size="16"
export model_name=uie-base-en
export data_name=absa/14lap
export task_name="meta"
export decoding_format='spotasoc'
source scripts/function_code.bash
for index in $(seq 1 ${run_time}); do
if [[ ! ${output_dir} ]]
then
output_dir=${model_folder}_run${index}
echo "output_dir is not provided, so it is created automatically: ${output_dir}"
else
echo "output_dir is provided: ${output_dir}"
fi
# output_dir=${model_folder}_run${index}
if [[ ${verbose} == true ]]
then
stdout_file=/dev/stdout
stderr_file=/dev/stderr
disable_tqdm=False
else
stdout_file=${output_dir}.log
stderr_file=${output_dir}.err
disable_tqdm=True
fi
# CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} gdb --args ${run_command} run_seq2seq.py \
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} ${run_command} run_seq2seq.py \
--do_train --do_eval --do_predict ${constraint_decoding} ${fp16} \
--use_prompt_tuning_model=${use_prompt_tuning_model} \
--trainer_type=${trainer_type} \
--load_config_only=False \
--use_fast_tokenizer=True \
--ddp_find_unused_parameters=False \
--predict_with_generate \
--evaluation_strategy=${evaluation_strategy} \
--save_strategy=${evaluation_strategy} \
--metric_for_best_model eval_overall-F1 \
--save_total_limit 1 \
--load_best_model_at_end \
--max_source_length=${max_source_length:-"256"} \
--max_prefix_length=${max_prefix_length:-"-1"} \
--max_target_length=${max_target_length:-"192"} \
--num_train_epochs=${epoch} \
--task=${task_name} \
--train_file=${data_folder}/train.json \
--validation_file=${data_folder}/val.json \
--test_file=${data_folder}/test.json \
--record_schema=${data_folder}/record.schema \
--per_device_train_batch_size=${batch_size} \
--per_device_eval_batch_size=$((batch_size * 4)) \
--output_dir=${output_dir} \
--logging_dir=${output_dir}_log \
--logging_strategy="steps" \
--logging_first_step=True \
--logging_steps=100 \
--model_name_or_path=${model_name} \
--learning_rate=${lr} \
--source_prefix="${task_name}: " \
--lr_scheduler_type=${lr_scheduler} \
--label_smoothing_factor=${label_smoothing} \
--eval_steps ${eval_steps} \
--decoding_format ${decoding_format} \
--warmup_ratio ${warmup_ratio} \
--preprocessing_num_workers=32 \
--dataloader_num_workers=32 \
--meta_negative=${negative} \
--meta_positive_rate=${positive} \
--skip_memory_metrics \
--no_remove_unused_columns \
--ordered_prompt=${ordered_prompt} \
--save_better_checkpoint=False \
--start_eval_step=${start_eval_step:-"0"} \
--spot_noise=${spot_noise} \
--asoc_noise=${asoc_noise} \
--seed=${seed}${index} --disable_tqdm=${disable_tqdm} >${stdout_file} 2>${stderr_file}
echo "exit code:" $?
# --save_strategy=${evaluation_strategy} \
# --save_total_limit 1 \
# --load_best_model_at_end \
# --save_strategy="steps" \
# --save_steps=5000 \
# --save_total_limit 9999999 \
# --load_best_model_at_end=True \
if [[ ${verbose} != true ]]
then
tail -n 200 ${stderr_file}
fi
echo "Map Config" ${map_config}
python3 scripts/sel2record.py -p ${output_dir} -g ${data_folder} -v -d ${decoding_format} -c ${map_config}
python3 scripts/eval_extraction.py -p ${output_dir} -g ${data_folder} -w -m ${eval_match_mode:-"normal"}
# delete all optimizer.pt for saving disk
# find ${output_dir}/ | grep -P "optimizer.pt" | xargs rm -rf
done


@ -0,0 +1,96 @@
#!/usr/bin/env bash
# -*- coding:utf-8 -*-
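# Low-resource launcher: loops over the subsampled splits under
# ${data_folder}_ratio/seed${index}, running one fine-tuning job per ratio and
# scaling eval_steps to the size of each training file before scoring the run.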
export batch_size="16"
export model_name=uie-base-en
export data_name=absa/14lap
export task_name="meta"
export decoding_format='spotasoc'
source scripts/function_code.bash
for index in $(seq 1 ${run_time}); do
output_dir=${model_folder}_run${index}
if [[ ${verbose} == true ]]
then
stdout_file=/dev/stdout
stderr_file=/dev/stderr
disable_tqdm=False
else
stdout_file=${output_dir}.log
stderr_file=${output_dir}.err
disable_tqdm=True
fi
ratio_data_folder=${data_folder}_ratio/seed${index}
for ratio in $(ls ${ratio_data_folder})
do
run_data_folder=${ratio_data_folder}/${ratio}
run_output_folder=${output_dir}_${ratio}
if [[ ${max_prefix_length} == 0 ]]
then
run_output_folder=${run_output_folder}_noprefix
fi
eval_steps=$(python scripts/get_eval_batch_num.py ${run_data_folder}/train.json ${batch_size} 20)
echo "Evaluate every ${eval_steps} training steps"
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} ${run_command} run_seq2seq.py \
--do_train --do_eval --do_predict ${constraint_decoding} ${fp16} \
--trainer_type=${trainer_type} \
--load_config_only=False \
--use_fast_tokenizer=True \
--ddp_find_unused_parameters=False \
--predict_with_generate \
--evaluation_strategy=steps \
--save_strategy=steps \
--load_best_model_at_end \
--metric_for_best_model eval_overall-F1 \
--save_total_limit 1 \
--max_source_length=${max_source_length:-"256"} \
--max_prefix_length=${max_prefix_length:-"-1"} \
--max_target_length=${max_target_length:-"192"} \
--num_train_epochs=${epoch} \
--task=${task_name} \
--train_file=${run_data_folder}/train.json \
--validation_file=${run_data_folder}/val.json \
--test_file=${run_data_folder}/test.json \
--record_schema=${run_data_folder}/record.schema \
--per_device_train_batch_size=${batch_size} \
--per_device_eval_batch_size=$((batch_size * 4)) \
--output_dir=${run_output_folder} \
--logging_dir=${run_output_folder}_log \
--model_name_or_path=${model_name} \
--learning_rate=${lr} \
--source_prefix="${task_name}: " \
--lr_scheduler_type=${lr_scheduler} \
--label_smoothing_factor=${label_smoothing} \
--eval_steps ${eval_steps} \
--decoding_format ${decoding_format} \
--warmup_ratio ${warmup_ratio} \
--preprocessing_num_workers=4 \
--dataloader_num_workers=0 \
--meta_negative=${negative} \
--meta_positive_rate=${positive} \
--skip_memory_metrics \
--no_remove_unused_columns \
--ordered_prompt=${ordered_prompt} \
--save_better_checkpoint=True \
--spot_noise=${spot_noise} \
--asoc_noise=${asoc_noise} \
--seed=${seed} --disable_tqdm=${disable_tqdm} >${stdout_file} 2>${stderr_file}
echo "Map Config" ${map_config}
python3 scripts/sel2record.py -p ${run_output_folder} -g ${run_data_folder} -v -d ${decoding_format} -c ${map_config}
python3 scripts/eval_extraction.py -p ${run_output_folder} -g ${run_data_folder} -w -m ${eval_match_mode:-"normal"}
# delete all pytorch_model.bin of checkpoints in low-resource exps for saving disk
# find ${run_output_folder}/ | grep -P "checkpoint-\d+/pytorch_model.bin" | xargs rm -rf
# delete all optimizer.pt in low-resource exps for saving disk
# find ${run_output_folder}/ | grep -P "optimizer.pt" | xargs rm -rf
done
done


@ -0,0 +1,99 @@
#!/usr/bin/env bash
# -*- coding:utf-8 -*-
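# Few-shot launcher: loops over the N-shot splits under
# ${data_folder}_shot/seed${index}, running one fine-tuning job per shot size
# and scaling eval_steps to the size of each training file before scoring.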
export batch_size="16"
export model_name=uie-base-en
export data_name=absa/14lap
export task_name="meta"
export decoding_format='spotasoc'
source scripts/function_code.bash
for index in $(seq 1 ${run_time}); do
output_dir=${model_folder}_run${index}
if [[ ${verbose} == true ]]
then
stdout_file=/dev/stdout
stderr_file=/dev/stderr
disable_tqdm=False
else
stdout_file=${output_dir}.log
stderr_file=${output_dir}.err
disable_tqdm=True
fi
shot_data_folder=${data_folder}_shot/seed${index}
for shot in $(ls ${shot_data_folder})
do
run_data_folder=${shot_data_folder}/${shot}
run_output_folder=${output_dir}_${shot}
if [[ ${max_prefix_length} == 0 ]]
then
run_output_folder=${run_output_folder}_noprefix
fi
echo ${run_data_folder}
eval_steps=$(python scripts/get_eval_batch_num.py ${run_data_folder}/train.json ${batch_size} 20)
echo "Evaluate every ${eval_steps} training steps"
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} ${run_command} run_seq2seq.py \
--do_train --do_eval --do_predict ${constraint_decoding} ${fp16} \
--trainer_type=${trainer_type} \
--load_config_only=False \
--use_fast_tokenizer=True \
--ddp_find_unused_parameters=False \
--predict_with_generate \
--evaluation_strategy=steps \
--save_strategy=steps \
--load_best_model_at_end \
--metric_for_best_model eval_overall-F1 \
--save_total_limit 1 \
--max_source_length=${max_source_length:-"256"} \
--max_prefix_length=${max_prefix_length:-"-1"} \
--max_target_length=${max_target_length:-"192"} \
--num_train_epochs=${epoch} \
--task=${task_name} \
--train_file=${run_data_folder}/train.json \
--validation_file=${run_data_folder}/val.json \
--test_file=${run_data_folder}/test.json \
--record_schema=${run_data_folder}/record.schema \
--per_device_train_batch_size=${batch_size} \
--per_device_eval_batch_size=$((batch_size * 4)) \
--output_dir=${run_output_folder} \
--logging_dir=${run_output_folder}_log \
--model_name_or_path=${model_name} \
--learning_rate=${lr} \
--source_prefix="${task_name}: " \
--lr_scheduler_type=${lr_scheduler} \
--label_smoothing_factor=${label_smoothing} \
--eval_steps ${eval_steps} \
--decoding_format ${decoding_format} \
--warmup_ratio ${warmup_ratio} \
--preprocessing_num_workers=4 \
--dataloader_num_workers=0 \
--meta_negative=${negative} \
--meta_positive_rate=${positive} \
--skip_memory_metrics \
--no_remove_unused_columns \
--ordered_prompt=${ordered_prompt} \
--save_better_checkpoint=True \
--spot_noise=${spot_noise} \
--asoc_noise=${asoc_noise} \
--seed=${seed} --disable_tqdm=${disable_tqdm} >${stdout_file} 2>${stderr_file}
echo "Map Config" ${map_config}
python3 scripts/sel2record.py -p ${run_output_folder} -g ${run_data_folder} -v -d ${decoding_format} -c ${map_config}
python3 scripts/eval_extraction.py -p ${run_output_folder} -g ${run_data_folder} -w -m ${eval_match_mode:-"normal"}
# delete all pytorch_model.bin of checkpoints in low-resource exps for saving disk
# find ${run_output_folder}/ | grep -P "checkpoint-\d+/pytorch_model.bin" | xargs rm -rf
# delete all optimizer.pt in low-resource exps for saving disk
# find ${run_output_folder}/ | grep -P "optimizer.pt" | xargs rm -rf
done
done


@ -0,0 +1,25 @@
#!/usr/bin/env bash
# -*- coding:utf-8 -*-
# Check Offset Mapping Performance
# Used to verify the accuracy of different SEL2Record offset-mapping (back-annotation) strategies
# bash scripts/check_offset_map_gold_as_pred.bash data/text2spotasocname/absa/14lap config/offset_map/closest_offset_en.yaml spotasocname
folder_name=$1
config_name=$2
parser_format=$3
cat ${folder_name}/val.json | python -c "import json, sys
for line in sys.stdin:
print(json.loads(line.strip())['record'])
" > ${folder_name}/eval_preds_seq2seq.txt
python scripts/sel2record.py \
-c ${config_name} \
-g ${folder_name} \
-p ${folder_name} \
-d ${parser_format}
python scripts/eval_extraction.py \
-g ${folder_name} \
-p ${folder_name} -w


@ -0,0 +1,135 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import argparse
import json
import os
import sys
import numpy as np
from pprint import pprint
from uie.extraction.scorer import EntityScorer, RelationScorer, EventScorer
def read_file(file_name):
with open(file_name) as reader:
return reader.readlines()
def write_to_file(result, output_filename, prefix=None):
with open(output_filename, 'w') as output:
for key, value in result.items():
if prefix:
key = '%s_%s' % (prefix, key)
output.write("%s=%s\n" % (key, value))
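# Example (hypothetical values):
#   write_to_file({'entity-F1': 80.1}, 'out/test_results.txt', prefix='test')
#   writes the line "test_entity-F1=80.1" to out/test_results.txt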
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-g', dest='gold_folder', help="Golden Dataset folder")
parser.add_argument('-p', dest='pred_folder', nargs='+', help="Predicted model folder")
parser.add_argument('-v', dest='verbose', action='store_true', help='Show more information during running')
parser.add_argument('-w', dest='write_to_file', action='store_true', help="Write evaluation results to predicted folder")
parser.add_argument('-m', dest='match_mode', default='normal', choices=['set', 'normal', 'multimatch'])
parser.add_argument('-case', dest='case', action='store_true', help='Show case study')
options = parser.parse_args()
data_dict = {
'eval': ['eval_preds_record.txt', 'val.json'],
'test': ['test_preds_record.txt', 'test.json'],
}
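# Each split maps to its (prediction record file, gold file) pair.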
task_dict = {
'entity': EntityScorer,
'relation': RelationScorer,
'event': EventScorer,
}
result_list = {'eval': list(), 'test': list()}
for pred_folder in options.pred_folder:
gold_folder = options.gold_folder
for data_key, (generation, gold_file) in data_dict.items():
gold_filename = os.path.join(gold_folder, gold_file)
pred_filename = os.path.join(pred_folder, generation)
if not os.path.exists(pred_filename):
sys.stderr.write("%s not found.\n" % pred_filename)
continue
print("pred:", pred_filename)
print("gold:", gold_filename)
if options.case:
for pred_line, gold_line in zip(read_file(pred_filename), read_file(gold_filename)):
gold_instance = json.loads(gold_line)
pred_instance = json.loads(pred_line)
print('=========================')
print(gold_instance['text'])
for task in task_dict:
scorer = task_dict[task]
gold = scorer.load_gold_list([gold_instance[task]])[0]
pred = scorer.load_pred_list([pred_instance[task]])[0]
# Skip this task in the case study when neither gold nor pred has any span
max_span_num = max(
len(gold['string']),
len(pred['string']),
len(gold.get('string_trigger', [])),
len(pred.get('string_trigger', [])),
len(gold.get('string_role', [])),
len(pred.get('string_role', [])),
)
if max_span_num == 0:
continue
if task == 'entity':
print("Entity Gold:", sorted(gold['string']))
print("Entity Pred:", sorted(pred['string']))
if task == 'relation':
print("Relation Gold:", sorted(gold['string']))
print("Relation Pred:", sorted(pred['string']))
if task == 'event':
print("Event Gold Trigger:", sorted(gold['string_trigger']))
print("Event Pred Trigger:", sorted(pred['string_trigger']))
print("Event Gold Role :", sorted(gold['string_role']))
print("Event Pred Role :", sorted(pred['string_role']))
results = dict()
for task in task_dict:
if task not in json.loads(read_file(pred_filename)[0]):
continue
scorer = task_dict[task]
gold_list = [json.loads(line)[task] for line in read_file(gold_filename)]
pred_list = [json.loads(line)[task] for line in read_file(pred_filename)]
assert len(pred_list) == len(gold_list)
gold_instance_list = scorer.load_gold_list(gold_list)
pred_instance_list = scorer.load_pred_list(pred_list)
assert len(pred_instance_list) == len(gold_instance_list)
sub_results = scorer.eval_instance_list(
gold_instance_list=gold_instance_list,
pred_instance_list=pred_instance_list,
verbose=options.verbose,
match_mode=options.match_mode,
)
results.update(sub_results)
pprint(results)
result_list[data_key] += [results]
if options.write_to_file:
output_filename = "%s/%s_results.txt" % (pred_folder, data_key)
write_to_file(
result=results,
output_filename=output_filename,
prefix=data_key,
)
print("===========> AVG <===========")
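# Average every metric over all prediction folders passed via -p.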
for data_key in data_dict:
if len(result_list[data_key]) < 1:
continue
for key in result_list[data_key][0]:
ave = np.mean([result[key] for result in result_list[data_key]])
print(data_key, key, ave)
if __name__ == "__main__":
main()
