Build your own NSTM reconstruction

This tutorial guides you through building your own space-time reconstruction with the NSTM method.

Write the forward model

NSTM requires a deterministic, differentiable forward model: during reconstruction, the network weights are updated with gradients that are backpropagated through the forward model, so those gradients must be well defined and meaningful.

Here we will use a simple linear model to demonstrate the process. The forward model is defined as:

\[y = x \cdot w\]

where \(x\) is the object to recover and \(w\) is a pre-defined weight matrix. The forward model is implemented as follows:

import jax.numpy as jnp
import calcil as cc

class NewImager(cc.forward.Model):
    w: jnp.ndarray  # pre-defined weight matrix

    def setup(self):
        # Nothing to initialize here: w is supplied as a module attribute
        pass

    def __call__(self, x):
        """Forward model: y = x @ w"""
        return x @ self.w

The forward model is a class that inherits from cc.forward.Model. The setup method is where submodules and parameters are initialized; NewImager needs none, because w is passed in as an attribute. The __call__ method is the forward model itself: it takes an input x and returns the output y = x · w.
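Because the forward model is written with jax.numpy, JAX can differentiate through it. The sketch below is a quick sanity check of that property: it compares the gradient from jax.grad with the analytic gradient of the L2 data-fidelity term, \(\nabla_x \lVert x w - y \rVert_2^2 = 2 (x w - y) w^\top\). It uses plain functions with the hypothetical names forward and l2_loss instead of the cc.forward.Model class, so it runs with JAX alone; the array sizes are arbitrary toy values.

```python
import jax
import jax.numpy as jnp

def forward(x, w):
    """Linear forward model y = x @ w (same math as NewImager.__call__)."""
    return x @ w

def l2_loss(x, w, y_meas):
    """Squared L2 distance between the rendered and the measured data."""
    return jnp.sum((forward(x, w) - y_meas) ** 2)

key_x, key_w, key_y = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (4,))       # object estimate, shape (N,)
w = jax.random.normal(key_w, (4, 3))     # pre-defined weight matrix, shape (N, M)
y_meas = jax.random.normal(key_y, (3,))  # measurement, shape (M,)

# Gradient via automatic differentiation (with respect to the first argument, x)
grad_auto = jax.grad(l2_loss)(x, w, y_meas)

# Analytic gradient: 2 (x w - y) w^T
grad_analytic = 2.0 * (forward(x, w) - y_meas) @ w.T

print(jnp.allclose(grad_auto, grad_analytic, atol=1e-4))
```

If the two gradients disagree for your own forward model, a non-differentiable or non-JAX operation is usually hiding somewhere inside it.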

Combine with the space-time model

Once the forward model is defined, we can combine it with the NSTM space-time model to render the measurement at different timepoints.

from nstm import spacetime

class NewImagerWithNSTM(cc.forward.Model):
    w: jnp.ndarray  # pre-defined weight matrix
    spacetime_param: spacetime.SpaceTimeParameters  # nstm parameters

    def setup(self):
        self.forward = NewImager(w=self.w)  # Initialize the forward model

        # Initialize the nstm reconstruction
        self.spacetime = spacetime.SpaceTimeMLP(optical_param=self.w.shape,  # specify the shape and dim of the nstm reconstruction
                                                spacetime_param=self.spacetime_param,  # nstm parameters
                                                num_output_channels=1)  # assume the output is a single channel

    def __call__(self, input_dict):
        t = input_dict['t']  # given timepoint
        obj = self.spacetime(t, coord_offset=jnp.zeros((1, 2)))  # reconstructed object at timepoint t
        img = self.forward(obj)  # render the measurement with the forward model
        return img

The NewImagerWithNSTM class combines the forward model with the NSTM space-time model. The setup method initializes both the forward model and the NSTM network. The __call__ method reads the timepoint t from input_dict and feeds it into the NSTM to get the reconstructed object at that timepoint; the forward model then maps this object to the rendered measurement.
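To make this data flow concrete without relying on the internals of spacetime.SpaceTimeMLP, the sketch below uses a hypothetical stand-in: a tiny two-layer MLP (init_toy_mlp and toy_spacetime, names invented for illustration) that maps each pixel coordinate together with the timepoint \(t\) to an object value. It is not the NSTM architecture; it only mirrors the two steps of NewImagerWithNSTM.__call__: query the space-time model at \(t\), then pass the resulting object through the linear forward model.

```python
import jax
import jax.numpy as jnp

N, M, HIDDEN = 8, 5, 16  # object size, measurement size, MLP width (toy values)

def init_toy_mlp(key, hidden=HIDDEN):
    """Hypothetical stand-in for the space-time network: (coord, t) -> value."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": 0.5 * jax.random.normal(k1, (2, hidden)),  # inputs: coordinate and time
        "b1": jnp.zeros(hidden),
        "w2": 0.5 * jax.random.normal(k2, (hidden, 1)),
        "b2": jnp.zeros(1),
    }

def toy_spacetime(params, t, coords):
    """Evaluate the toy MLP at every coordinate for a single timepoint t."""
    t_col = jnp.full_like(coords, t)                 # broadcast t to every coordinate
    inputs = jnp.stack([coords, t_col], axis=-1)     # shape (N, 2)
    h = jnp.tanh(inputs @ params["w1"] + params["b1"])
    return (h @ params["w2"] + params["b2"])[:, 0]   # object at time t, shape (N,)

def render(params, t, coords, w):
    """Same two steps as NewImagerWithNSTM.__call__: object at t, then forward model."""
    obj = toy_spacetime(params, t, coords)
    return obj @ w                                   # linear forward model y = x w

key_p, key_w = jax.random.split(jax.random.PRNGKey(0))
params = init_toy_mlp(key_p)
w = jax.random.normal(key_w, (N, M))                 # pre-defined weight matrix
coords = jnp.linspace(-1.0, 1.0, N)                  # 1D pixel coordinates

y_t0 = render(params, 0.0, coords, w)
y_t1 = render(params, 1.0, coords, w)
print(y_t0.shape, y_t1.shape)  # (5,) (5,)
```

The same network parameters produce a different measurement at each timepoint, which is what lets a single set of weights explain a whole time series of measurements.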

Note

Isn’t that simple? Now you can build your own NSTM reconstruction for your favorite imaging system!

All that is left is to define a loss function and train the NSTM network weights on your data. We found that an L2 loss between the rendered and the actual measurements works well.
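A minimal training sketch under toy assumptions: synthetic measurements are generated from a known time-varying object, and the parameters of a hypothetical space-time stand-in are fit by plain gradient descent on the L2 loss over all timepoints. The stand-in (an object that varies linearly in time, obj(t) = a + b·t) and the names toy_object, loss_fn, and sgd_step are invented for illustration and kept deliberately simpler than an MLP; with a real model you would render with NewImagerWithNSTM instead and likely use a proper optimizer.

```python
import jax
import jax.numpy as jnp

N, M, LR = 6, 4, 0.05
key_w, key_a, key_b = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(key_w, (N, M))   # pre-defined weight matrix
ts = jnp.linspace(0.0, 1.0, 5)         # measurement timepoints

def toy_object(params, t):
    """Hypothetical space-time model: the object varies linearly in time."""
    return params["a"] + params["b"] * t

def render(params, t):
    return toy_object(params, t) @ w   # linear forward model y = x w

# Synthetic ground truth and one measurement per timepoint, shape (T, M)
true_params = {"a": jax.random.normal(key_a, (N,)), "b": jax.random.normal(key_b, (N,))}
measurements = jax.vmap(lambda t: render(true_params, t))(ts)

def loss_fn(params):
    """L2 loss between rendered and actual measurements, averaged over entries."""
    rendered = jax.vmap(lambda t: render(params, t))(ts)
    return jnp.mean((rendered - measurements) ** 2)

@jax.jit
def sgd_step(params):
    """One step of plain gradient descent on all parameters."""
    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return new_params, loss

params = {"a": jnp.zeros(N), "b": jnp.zeros(N)}
initial_loss = loss_fn(params)
for _ in range(200):
    params, _ = sgd_step(params)
final_loss = loss_fn(params)
print(float(initial_loss), float(final_loss))
```

Because the toy model is linear in its parameters, this loss is a convex quadratic and a small fixed step size is enough; a neural space-time model is non-convex, so the loss curve there will be less predictable.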

If you’re tired of copying and pasting template code for gradient-descent-based image reconstruction, the calcil package can do the brainless work for you 🧠.