PyO3: Add CI (#1135)

* Add PyO3 ci

* Update python.yml

* Format `bert.py`
This commit is contained in:
Lukas Kreussel 2023-10-20 20:05:14 +02:00 committed by GitHub
parent 7366aeac21
commit cfb423ab76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 3 deletions

62
.github/workflows/python.yml vendored Normal file
View File

@ -0,0 +1,62 @@
name: PyO3-CI
on:
workflow_dispatch:
push:
branches:
- main
paths:
- candle-pyo3/**
pull_request:
paths:
- candle-pyo3/**
jobs:
build_and_test:
name: Check everything builds & tests
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest] # For now, only test on Linux
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
toolchain: stable
- name: Install Python
uses: actions/setup-python@v4
with:
python-version: 3.11
architecture: "x64"
- name: Cache Cargo Registry
uses: actions/cache@v1
with:
path: ~/.cargo/registry
key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
- name: Install
working-directory: ./candle-pyo3
run: |
python -m venv .env
source .env/bin/activate
pip install -U pip
pip install pytest maturin black
python -m maturin develop -r
- name: Check style
working-directory: ./candle-pyo3
run: |
source .env/bin/activate
python stub.py --check
black --check .
- name: Run tests
working-directory: ./candle-pyo3
run: |
source .env/bin/activate
python -m pytest -s -v tests

View File

@ -59,8 +59,7 @@ class BertSelfAttention(Module):
attention_scores = attention_scores / float(self.attention_head_size) ** 0.5
if attention_mask is not None:
b_size, _, _, last_dim = attention_scores.shape
attention_scores = attention_scores.broadcast_add(
attention_mask.reshape((b_size, 1, 1, last_dim)))
attention_scores = attention_scores.broadcast_add(attention_mask.reshape((b_size, 1, 1, last_dim)))
attention_probs = F.softmax(attention_scores, dim=-1)
context_layer = attention_probs.matmul(value)
@ -198,7 +197,9 @@ class BertModel(Module):
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config) if add_pooling_layer else None
def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None) -> Tuple[Tensor, Optional[Tensor]]:
def forward(
self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None
) -> Tuple[Tensor, Optional[Tensor]]:
if attention_mask is not None:
# Replace 0s with -inf, and 1s with 0s.
attention_mask = masked_fill(float("-inf"), attention_mask, 1.0)