Installation
Prerequisites
Step-by-step Installation
Create a conda virtual environment and activate it
$ conda create -n nstm python=3.9
$ conda activate nstm
Clone this project to your local machine, or download the zip file and unzip it.
$ git clone https://github.com/rmcao/nstm.git
Install CUDA and cuDNN in the conda virtual environment (you may skip this step if CUDA is already installed on your system and you know what you are doing)
$ conda install -c conda-forge cudatoolkit~=11.8.0 cudnn~=8.8.0
$ conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
Install jaxlib. The following commands are for CUDA 11.x and cuDNN 8.2+. If you have a different version of CUDA, refer to the JAX installation guide and make sure the version numbers of jaxlib and jax match (as specified in requirements.txt).
$ pip install 'numpy<2.0.0'
$ pip install jaxlib==0.3.18+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Install optional dependencies for interactive visualization in JupyterLab
$ conda install -c conda-forge jupyterlab nodejs ipympl
Install the helper library and this codebase
$ pip install git+https://github.com/rmcao/CalCIL.git
$ pip install -e ./nstm
Test the installation
$ python -c "import jax.numpy as jnp; print(jnp.ones(5)+jnp.zeros(5))"
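The one-liner above only confirms that JAX imports and runs. It does not show which versions were installed or whether the GPU was picked up. As a rough sketch (not part of the project), the script below prints the installed versions of numpy, jax, and jaxlib along with the backend and devices JAX reports. The function names `installed_version` and `describe_environment` are hypothetical helpers introduced here for illustration.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def describe_environment(packages=("numpy", "jax", "jaxlib")):
    """Collect package versions and, if JAX initializes, its backend and devices."""
    info = {name: installed_version(name) for name in packages}
    try:
        import jax

        info["backend"] = jax.default_backend()
        info["devices"] = [str(d) for d in jax.devices()]
    except Exception as exc:  # JAX is missing or failed to initialize
        info["backend"] = None
        info["devices"] = []
        info["error"] = repr(exc)
    return info


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

If the backend is reported as `cpu`, JAX did not pick up the GPU, and the CUDA, cuDNN, and jaxlib versions from the earlier steps are worth rechecking.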