Installation

Prerequisites

Step-by-step Installation

  1. Create a conda virtual environment and activate it

    $ conda create -n nstm python=3.9
    $ conda activate nstm
    
  2. Clone this project to your local machine, or download the zip file and unzip it.

    $ git clone https://github.com/rmcao/nstm.git
    
  3. Install CUDA and cuDNN in the conda virtual environment. You may skip this step if CUDA is already installed on your system and you know what you are doing.

    $ conda install -c conda-forge cudatoolkit~=11.8.0 cudnn~=8.8.0
    $ conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
    
  4. Install jaxlib. The commands below target CUDA 11.x and cuDNN 8.2+, matching the versions installed in step 3. NumPy is pinned below 2.0 first because this jaxlib release was built against NumPy 1.x. If you have a different CUDA version, refer to the JAX installation guide and make sure the jaxlib and jax version numbers match (as specified in requirements.txt).

    $ pip install 'numpy<2.0.0'
    $ pip install jaxlib==0.3.18+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    
  5. Install optional dependencies for interactive visualization in JupyterLab

    $ conda install -c conda-forge jupyterlab nodejs ipympl
    
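  6. Install the helper library (CalCIL) and this codebase. The editable install (-e) points at the nstm folder created in step 2, so run it from the directory where you cloned the project.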

    $ pip install git+https://github.com/rmcao/CalCIL.git
    $ pip install -e ./nstm
    
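  7. Test the installation. The command below should print [1. 1. 1. 1. 1.]. If JAX also warns that no GPU/TPU was found and it is falling back to CPU, then the CUDA setup from steps 3 and 4 is not being picked up.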

    $ python -c "import jax.numpy as jnp; print(jnp.ones(5)+jnp.zeros(5))"
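Step 4 asks that the jax and jaxlib versions match what requirements.txt specifies. The following is a minimal sketch of one way to check this after installing. It is not part of nstm. The helper names (`pinned_versions`, `check_pins`, `report`), the assumption that requirements.txt pins versions with `==` (e.g. `jax==0.3.18`), and the default path `nstm/requirements.txt` are all hypothetical. The CUDA-specific local tag on jaxlib (`+cuda11.cudnn82`) is ignored when comparing.

```python
"""Hypothetical helper: compare installed jax/jaxlib against requirements.txt pins."""
import os
import re
from importlib import metadata

# Assumption: pins are exact, e.g. "jax==0.3.18"; other specifiers (>=, ~=) are ignored.
PIN_RE = re.compile(r"^\s*([A-Za-z0-9_.\-]+)\s*==\s*([^\s;#]+)")


def pinned_versions(requirements_text, packages=("jax", "jaxlib")):
    """Return {package: version} for exact '==' pins of the given packages."""
    pins = {}
    for line in requirements_text.splitlines():
        match = PIN_RE.match(line)
        if match and match.group(1).lower() in packages:
            pins[match.group(1).lower()] = match.group(2)
    return pins


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def base_version(version):
    """Drop a local version tag, e.g. '0.3.18+cuda11.cudnn82' -> '0.3.18'."""
    return version.split("+", 1)[0] if version else None


def check_pins(requirements_text, lookup=installed_version):
    """List mismatches between pinned and installed versions (empty list means OK)."""
    problems = []
    for package, pinned in pinned_versions(requirements_text).items():
        found = lookup(package)
        if base_version(found) != base_version(pinned):
            problems.append(f"{package}: requirements.txt pins {pinned}, found {found}")
    return problems


def report(path="nstm/requirements.txt"):
    """Print mismatches; the default path assumes you run from the clone's parent directory."""
    if not os.path.exists(path):
        print(f"requirements file not found: {path}")
        return
    with open(path) as f:
        problems = check_pins(f.read())
    for line in problems or ["jax/jaxlib versions match requirements.txt"]:
        print(line)


if __name__ == "__main__":
    report()
```

Passing `lookup` as a parameter keeps the comparison logic testable without jax installed. If requirements.txt pins only jax, jaxlib is simply not checked.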