Build your own nstm reconstruction
This tutorial will guide you through the process of building your own space-time reconstruction using nstm method.
Write forward model
NSTM relies on a deterministic and differentiable forward model to get meaningful gradients to update the model’s weights during reconstruction.
Here we will use a simple linear model to demonstrate the process. The forward model is defined as:
\[y = x \cdot w\]
where \(x\) is the object to recover and \(w\) is a pre-defined weight matrix. The forward model is implemented as follows:
import jax.numpy as jnp import calcil as cc class NewImager(cc.forward.Model): w: jnp.ndarray # pre-defined weight matrix def setup(): # nothing to do here pass def __call__(self, x): """Forward model""" return x @ self.w
The forward model is a class that inherits from cc.forward.Model. The setup method is used to initialize the model’s parameters. The __call__ method is the forward model itself. The forward model takes an input x and returns the output y.
Combine with space-time model
Once the forward model is defined, we can combine it with the nstm method to render the measurement at different timepoints.
from nstm import spacetime class NewImagerWithNSTM(cc.forward.Model): w: jnp.ndarray # pre-defined weight matrix spacetime_param: spacetime.SpaceTimeParameters # nstm parameters def setup(): self.forward = NewImager(w=self.w) # Initialize the forward model # Initialize the nstm reconstruction self.spacetime = spacetime.SpaceTimeMLP(optical_param=self.w.shape, # specify the shape and dim of the nstm reconstruction spacetime_param=self.spacetime_param, # nstm parameters num_output_channels=1) # assume the output is a single channel def __call__(self, input_dict): t = input_dict['t'] # given timepoint obj = self.spacetime(t, coord_offset=jnp.zeros((1, 2))) img = self.forward(obj) return img
The NewImagerWithNSTM class combines the forward model with the nstm method. The setup method initializes the forward model and the nstm reconstruction.
The __call__ method takes an input timepoint and feeds it into the nstm to get the reconstructed object at that timepoint.
The forward model then takes the reconstructed object and returns the rendered measurement.
Note
Isn’t that simple? Now you can build your own nstm reconstruction with your favorite imaging system!
All left is to define a loss function and train the network weights of nstm using your data. We found that L2 loss between the rendered measurement and the actual measurement works well.
If you’re bored with copying and pasting template code to do gradient descent-based image reconstruction, calcil package can do some brainless work for you 🧠.