lammps/examples/mliap/jax
rohskopf 16751b06b3 Add os env vars to allow more MPI procs 2023-05-26 12:52:51 -06:00

Running JAX from LAMMPS

Getting started

First, create a Python environment with the required dependencies:

conda create --name jax python=3.10
conda activate jax
# Upgrade pip
python -m pip install --upgrade pip
# Install JAX:
python -m pip install --upgrade "jax[cpu]"
# Install other dependencies:
python -m pip install numpy scipy torch scikit-learn virtualenv psutil tabulate mpi4py Cython
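Before building LAMMPS, it can help to confirm that every package installed above is importable from the active environment. The script below is a minimal sketch, not part of this example directory. It checks importability only, not version compatibility, and the names REQUIRED_MODULES and missing_modules are hypothetical.

```python
import importlib.util

# Import names for the pip packages above. Note that "scikit-learn"
# is imported as "sklearn" and Cython as "Cython".
REQUIRED_MODULES = [
    "jax", "numpy", "scipy", "torch", "sklearn",
    "virtualenv", "psutil", "tabulate", "mpi4py", "Cython",
]


def missing_modules(names=REQUIRED_MODULES):
    """Return the subset of names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All dependencies found.")
```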

Install LAMMPS:

cd /path/to/lammps
mkdir build-jax; cd build-jax
cmake ../cmake -DLAMMPS_EXCEPTIONS=yes \
               -DBUILD_SHARED_LIBS=yes \
               -DMLIAP_ENABLE_PYTHON=yes \
               -DPKG_PYTHON=yes \
               -DPKG_ML-SNAP=yes \
               -DPKG_ML-IAP=yes \
               -DPYTHON_EXECUTABLE:FILEPATH=`which python`
make -j4
make install-python
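After make install-python, one way to confirm that the lammps Python module loads and that the shared library was built with the packages this example needs is to query it from Python. This sketch relies on the lammps module's installed_packages property; the helper name lammps_packages is hypothetical.

```python
def lammps_packages():
    """Return the packages compiled into the LAMMPS shared library, or None
    if the lammps Python module or its shared library cannot be loaded."""
    try:
        from lammps import lammps
        # Suppress the log file and screen output for this quick check.
        lmp = lammps(cmdargs=["-log", "none", "-screen", "none"])
    except (ImportError, OSError):
        return None
    try:
        return list(lmp.installed_packages)
    finally:
        lmp.close()


if __name__ == "__main__":
    pkgs = lammps_packages()
    if pkgs is None:
        print("lammps module not found; did make install-python succeed?")
    else:
        for needed in ("PYTHON", "ML-IAP", "ML-SNAP"):
            print(needed, "OK" if needed in pkgs else "MISSING")
```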

Kokkos install

Use the same Python dependencies as above, with a few additions:

  1. Install the CuPy package that matches your CUDA version, e.g. for CUDA 12.x:

     python -m pip install cupy-cuda12x

  2. Install JAX with CUDA support:

     python -m pip install --trusted-host storage.googleapis.com --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

  3. Install cuDNN, which the cuda12_local JAX build expects to find on the system: https://developer.nvidia.com/cudnn
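Before building the Kokkos version, this sketch checks whether CuPy and JAX can actually see a GPU. It is not part of the example; the function name gpu_status is hypothetical, and a None entry means the package is not installed at all.

```python
def gpu_status():
    """Return {"cupy": ..., "jax": ...}: True if that package sees a GPU,
    False if it is installed but sees none, None if it is not installed."""
    status = {"cupy": None, "jax": None}
    try:
        import cupy
    except ImportError:
        pass
    else:
        try:
            status["cupy"] = cupy.cuda.runtime.getDeviceCount() > 0
        except Exception:  # e.g. no CUDA driver or no visible device
            status["cupy"] = False
    try:
        import jax
    except ImportError:
        pass
    else:
        # default_backend() reports "gpu" when JAX found a usable CUDA device.
        status["jax"] = jax.default_backend() == "gpu"
    return status


if __name__ == "__main__":
    print(gpu_status())
```

If JAX reports False while CuPy reports True, the JAX CUDA wheel or cuDNN installation is a likely place to look.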

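Build and install LAMMPS with Kokkos. Change the Kokkos_ARCH_* flag to match your GPU (TURING75 targets Turing cards such as the RTX 20 series), and adjust the nvcc_wrapper path if your LAMMPS checkout is not in ${HOME}/lammps: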

cmake ../cmake -DLAMMPS_EXCEPTIONS=yes \
              -DBUILD_SHARED_LIBS=yes \
              -DPKG_PYTHON=yes \
              -DPKG_ML-SNAP=yes \
              -DPKG_ML-IAP=yes \
              -DMLIAP_ENABLE_PYTHON=yes \
              -DPKG_KOKKOS=yes \
              -DKokkos_ARCH_TURING75=yes \
              -DKokkos_ENABLE_CUDA=yes \
              -DKokkos_ENABLE_OPENMP=yes \
              -DCMAKE_CXX_COMPILER=${HOME}/lammps/lib/kokkos/bin/nvcc_wrapper \
              -DPYTHON_EXECUTABLE:FILEPATH=`which python`
make -j
make install-python
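Since only the architecture flag and a few paths change between machines, the cmake arguments can be generated from those inputs. The sketch below is a hypothetical helper, not part of the example; it reproduces the flags listed above, takes the Kokkos architecture name as a parameter, and uses sys.executable (the current interpreter) in place of `which python`.

```python
import sys


def kokkos_cmake_args(arch, lammps_dir, python_exe=sys.executable):
    """Return the cmake arguments from the Kokkos build above.

    arch is a Kokkos architecture suffix such as "VOLTA70", "TURING75",
    or "AMPERE80"; lammps_dir is the root of the LAMMPS checkout.
    """
    return [
        f"{lammps_dir}/cmake",  # source directory, so this works from any build dir
        "-DLAMMPS_EXCEPTIONS=yes",
        "-DBUILD_SHARED_LIBS=yes",
        "-DPKG_PYTHON=yes",
        "-DPKG_ML-SNAP=yes",
        "-DPKG_ML-IAP=yes",
        "-DMLIAP_ENABLE_PYTHON=yes",
        "-DPKG_KOKKOS=yes",
        f"-DKokkos_ARCH_{arch}=yes",  # the flag that must match your GPU
        "-DKokkos_ENABLE_CUDA=yes",
        "-DKokkos_ENABLE_OPENMP=yes",
        f"-DCMAKE_CXX_COMPILER={lammps_dir}/lib/kokkos/bin/nvcc_wrapper",
        f"-DPYTHON_EXECUTABLE:FILEPATH={python_exe}",
    ]
```

For example, from an empty build directory: subprocess.run(["cmake", *kokkos_cmake_args("AMPERE80", "/path/to/lammps")], check=True). The helper does not validate the architecture name; cmake will report an unknown Kokkos_ARCH_* value.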

Run the example on one GPU:

mpirun -np 1 lmp -k on g 1 -sf kk -pk kokkos newton on -in in.run
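The run command combines several Kokkos command-line switches: -k on g 1 enables Kokkos with one GPU per node, -sf kk selects the Kokkos (kk) variants of supported styles, and -pk kokkos newton on sets the Kokkos package option that turns on Newton's third law. The hypothetical helper below builds either this Kokkos form or the plain CPU form of the command as an argument list, e.g. for subprocess.run.

```python
def lmp_command(input_file="in.run", nprocs=1, kokkos_gpus=0, lmp="lmp"):
    """Build an mpirun command line for LAMMPS as a list of arguments.
    kokkos_gpus=0 gives the plain CPU run; >0 adds the Kokkos switches."""
    cmd = ["mpirun", "-np", str(nprocs), lmp]
    if kokkos_gpus > 0:
        cmd += ["-k", "on", "g", str(kokkos_gpus)]  # enable Kokkos, GPUs per node
        cmd += ["-sf", "kk"]                        # use the /kk variants of styles
        cmd += ["-pk", "kokkos", "newton", "on"]    # Kokkos package option
    cmd += ["-in", input_file]
    return cmd
```

lmp_command(kokkos_gpus=1) reproduces the command above, and lmp_command(nprocs=4) gives the CPU form with four MPI processes.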

Deploying JAX models on CPU

Run deploy_script.py, which wraps the model with write_unified_jax:

python deploy_script.py

This creates a .pkl file that LAMMPS loads through the ML-IAP Unified interface.
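A .pkl file stores a pickled Python object. As a rough illustration (not the code in deploy_script.py or write_unified.py, and not the ML-IAP Unified interface itself), the hypothetical class below holds Lennard-Jones parameters and computes per-pair energies and forces from pair displacement vectors; save_model and load_model show the pickle round trip.

```python
import pickle

import numpy as np


class LJPairModel:
    """Hypothetical stand-in for a deployed pair model.
    Holds Lennard-Jones parameters epsilon and sigma."""

    def __init__(self, epsilon, sigma):
        self.epsilon = epsilon
        self.sigma = sigma

    def pair_energy_forces(self, rij):
        """Per-pair LJ energies and force vectors for displacement vectors
        rij of shape (npairs, 3). Forces are (-dE/dr) * rij / |rij|."""
        r2 = np.sum(rij * rij, axis=1)
        sr6 = (self.sigma * self.sigma / r2) ** 3
        eij = 4.0 * self.epsilon * (sr6 * sr6 - sr6)
        # -dE/dr / r = 24 * eps * (2 * (sigma/r)^12 - (sigma/r)^6) / r^2
        fscale = 24.0 * self.epsilon * (2.0 * sr6 * sr6 - sr6) / r2
        fij = fscale[:, np.newaxis] * rij
        return eij, fij


def save_model(model, path):
    """Pickle the model object to path."""
    with open(path, "wb") as f:
        pickle.dump(model, f)


def load_model(path):
    """Load a pickled model. Only unpickle files you trust."""
    with open(path, "rb") as f:
        return pickle.load(f)
```

Pickle records a reference to the class rather than its code, so the module that defines the model class must be importable wherever the file is loaded. That is presumably why the model source (mliap_unified_jax.py) sits alongside the .pkl files in this directory.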

Run LAMMPS with the model on P MPI processes:

mpirun -np P lmp -in in.run

Deploying JAX models in Kokkos

Run deploy_script_kokkos.py, which wraps the model with write_unified_jax_kokkos:

python deploy_script_kokkos.py

This creates a .pkl file that LAMMPS loads through the ML-IAP Unified interface.
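Because the Kokkos build requires CuPy, the model may receive arrays that live on the GPU rather than NumPy arrays; this is an assumption, so check mliap_unified_jax_kokkos.py for what the example actually does. One common pattern for code that must accept either kind of array is cupy.get_array_module, which picks the matching array library so the computation stays on the device that holds the data. The helper names below are hypothetical.

```python
import numpy as np

try:
    import cupy
except ImportError:  # CPU-only environment
    cupy = None


def array_module(x):
    """Return cupy for CuPy arrays and numpy for everything else."""
    if cupy is not None:
        return cupy.get_array_module(x)
    return np


def pair_distances(rij):
    """|rij| for each pair, computed on whichever device holds rij."""
    xp = array_module(rij)
    return xp.sqrt(xp.sum(rij * rij, axis=1))
```

Keeping the math on the input's device avoids a host copy on every step; on a CPU-only machine the same code falls back to NumPy.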

Run LAMMPS with the model:

mpirun -np 1 lmp -k on g 1 -sf kk -pk kokkos newton on -in in.run