nstm package

Subpackages

Submodules

nstm.datasets module

Data utilities and dataset loader for 3D SIM data.

The data loader is designed to load 3D SIM data together with OTF and timestamp information. It supports the .tif and Zeiss .czi file formats.

class nstm.datasets.Datasets(dir_path: str)

Bases: object

list()
load(data_name, ndim=3, fov_yxhw=None, resize_hw=None)
class nstm.datasets.SIM3DDataLoader(otf_path: str, meta_path: str, zoomfact: float, ndirs: int, nphases: int, start_timepoint: int)

Bases: object

load_3DSIM_OTF(list_otf_path=None)
load_3DSIM_raw(raw_path, fov_zyxshw, noise_std=0, background_int=0, normalize=True, i_timepoint_czi=None)
static load_OTF(otf_path)
load_metadata(normalize=True, avg_phase=True, single_plane_time=False)
meta_path: str
ndirs: int
nphases: int
otf_path: str
start_timepoint: int
zoomfact: float
class nstm.datasets.SIM3DDataLoaderMultitime(otf_path: str, meta_path: str, zoomfact: float, ndirs: int, nphases: int, start_timepoint: int, num_timepoint: int)

Bases: SIM3DDataLoader

list_files(raw_path_regex)
load_3DSIM_raw(raw_path_regex, fov_zyxshw, noise_std=0, background_int=0, normalize=True)
num_timepoint: int
nstm.datasets.image_resizing(I_image, out_dim)
nstm.datasets.image_upsampling(I_image, upsamp_factor=1.0, bg=0)

nstm.diffcam_flow module

Rolling-shutter DiffuserCam forward model with the neural space-time model.

This module contains the forward models for rolling-shutter DiffuserCam, both with and without space-time modeling, as well as the loss functions used for rolling-shutter DiffuserCam reconstruction.

The implementation draws heavily on the following code from Nick Antipa:

https://people.eecs.berkeley.edu/~nick.antipa/antipa_files/hsvideo_code_with_data.zip
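The rolling-shutter idea behind these forward models can be sketched as follows: every time step of the scene video is blurred by the PSF, and each sensor row only records the time step during which it is read out, so one captured frame mixes many time points. This is a minimal NumPy sketch under simplifying assumptions: circular FFT convolution without the padding and cropping of the reference code, a single shutter moving top to bottom, and `nlines` rows read out per time step. The function names are hypothetical and are not part of nstm.

```python
import numpy as np


def rolling_shutter_masks(ny, nx, nlines):
    """Hypothetical exposure masks: time step t exposes rows [t*nlines, (t+1)*nlines)."""
    nt = -(-ny // nlines)  # ceil(ny / nlines) time steps cover the sensor
    row_time = np.arange(ny) // nlines  # time step at which each row is read out
    masks = (row_time[None, :] == np.arange(nt)[:, None]).astype(float)  # [T, Y]
    return np.repeat(masks[:, :, None], nx, axis=2)  # broadcast to [T, Y, X]


def rolling_shutter_forward(video, psf, masks):
    """Simulate one rolling-shutter capture from a video of shape [T, Y, X]."""
    # Circular convolution of every frame with the (centered) PSF.
    otf = np.fft.fft2(np.fft.ifftshift(psf))
    blurred = np.fft.ifft2(np.fft.fft2(video) * otf).real
    # Each row keeps only the contribution of the time step at which it was read out.
    return np.sum(masks * blurred, axis=0)
```

With a delta-function PSF the measurement simply stitches together rows taken from different frames, which is the temporal information the space-time model later tries to recover.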
class nstm.diffcam_flow.DiffuserCamRSFlow(psf: jax._src.basearray.Array, nlines: int, spacetime_param: spacetime.SpaceTimeParameters, annealed_epoch: float = 1, ram_efficient: bool = False, parent: Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object>, name: str = None)

Bases: Model

annealed_epoch: float = 1
name: str = None
nlines: int
parent: Type[Module] | Type[Scope] | Type[_Sentinel] | None = <flax.linen.module._Sentinel object>
psf: Array
ram_efficient: bool = False
scope = None
setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    class MyModule(nn.Module):
      def setup(self):
        submodule = Conv(...)
    
        # Accessing `submodule` attributes does not yet work here.
    
        # The following line invokes `self.__setattr__`, which gives
        # `submodule` the name "conv1".
        self.conv1 = submodule
    
        # Accessing `submodule` attributes or methods is now safe and
        # either causes setup() to be called once.
    
  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

spacetime_param: SpaceTimeParameters
class nstm.diffcam_flow.DiffuserCanRS(dim_yx: Tuple[int, int], psf: jax._src.basearray.Array, nlines: int, downsample_t: bool, parent: Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object>, name: str = None)

Bases: Model

dim_yx: Tuple[int, int]
downsample_t: bool
efficient(x, t_mask)
name: str = None
nlines: int
parent: Type[Module] | Type[Scope] | Type[_Sentinel] | None = <flax.linen.module._Sentinel object>
psf: Array
scope = None
setup()

nstm.diffcam_flow.gen_loss_l2()
nstm.diffcam_flow.gen_loss_l2_row()
nstm.diffcam_flow.gen_nonneg_reg()

nstm.diffcam_main module

Main script for rolling-shutter DiffuserCam reconstruction with the neural space-time model.

nstm.diffcam_main.main(unused_argv)

nstm.diffcam_utils module

Rolling-shutter DiffuserCam utility functions.

Largely adapted from Nick Antipa’s MATLAB code:

https://people.eecs.berkeley.edu/~nick.antipa/antipa_files/hsvideo_code_with_data.zip
nstm.diffcam_utils.define_flags()
nstm.diffcam_utils.gen_indicator(dims, nlines, pad2d, downsample_t=True)
nstm.diffcam_utils.load_data_psf(raw_path, psf_path, background, downsample=8)

nstm.dpc_flow module

Differential phase contrast with space-time modeling.

This module contains the forward models for differential phase contrast (DPC) and differential phase contrast with space-time modeling. The loss functions used for the DPC reconstruction are also provided.

class nstm.dpc_flow.DPC(optical_param: ~nstm.utils.SystemParameters, list_source: ~numpy.ndarray, precision: str = 'float32', parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Model

Differential phase contrast forward model.

Parameters:
  • optical_param (utils.SystemParameters) – Optical parameters of the system.

  • list_source (np.ndarray) – List of illumination patterns used for the DPC system.

  • precision (str, optional) – Precision of the model. Defaults to ‘float32’.

list_source: ndarray
name: str = None
optical_param: SystemParameters
parent: Type[Module] | Type[Scope] | Type[_Sentinel] | None = <flax.linen.module._Sentinel object>
precision: str = 'float32'
scope = None
setup()

class nstm.dpc_flow.DPCFlow(optical_param: ~nstm.utils.SystemParameters, list_source: ~numpy.ndarray, spacetime_param: ~nstm.spacetime.SpaceTimeParameters, annealed_epoch: float = 1, phase_only: bool = False, precision: str = 'float32', parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Model

Differential phase contrast with space-time modeling.

Parameters:
  • optical_param (utils.SystemParameters) – Optical parameters of the system.

  • list_source (np.ndarray) – List of illumination patterns used for the DPC system.

  • spacetime_param (spacetime.SpaceTimeParameters) – Space-time modeling parameters.

  • annealed_epoch (float, optional) – Number of epochs over which coarse-to-fine annealing is applied. Defaults to 1, i.e., no coarse-to-fine annealing.

  • phase_only (bool, optional) – Whether to use phase-only input. Defaults to False.

  • precision (str, optional) – Precision of the model. Defaults to ‘float32’.

annealed_epoch: float = 1
list_source: ndarray
name: str = None
optical_param: SystemParameters
parent: Type[Module] | Type[Scope] | Type[_Sentinel] | None = <flax.linen.module._Sentinel object>
phase_only: bool = False
precision: str = 'float32'
scope = None
setup()

spacetime_param: SpaceTimeParameters
nstm.dpc_flow.gen_l2_reg_absorp(freq_space=False)

Returns the L2 regularization function for absorption term.

nstm.dpc_flow.gen_l2_reg_phase(freq_space=False)

Returns the L2 regularization function for phase term.

nstm.dpc_flow.gen_loss_l2(margin=0)

Returns the L2 loss function for the DPC model.
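gen_loss_l2 returns a loss function rather than computing a loss directly. The sketch below illustrates that factory pattern in NumPy (jax.numpy works the same way). It is a hypothetical example, not the nstm implementation: it assumes that `margin` crops that many pixels from each border of the last two axes before the mean squared error is computed.

```python
import numpy as np


def make_l2_loss(margin=0):
    """Return a mean-squared-error loss that ignores `margin` border pixels (hypothetical)."""
    def crop(a):
        if margin == 0:
            return a  # a[..., 0:-0] would be empty, so skip cropping explicitly
        return a[..., margin:-margin, margin:-margin]

    def loss_fn(pred, target):
        return np.mean((crop(pred) - crop(target)) ** 2)

    return loss_fn
```

Building the loss once and passing the returned closure to the training step keeps configuration such as `margin` out of the jitted function signature.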

nstm.dpc_utils module

Utility functions for differential phase contrast (DPC) imaging.

This module contains the utility functions for differential phase contrast (DPC) imaging. The functions include loading illumination patterns, generating source patterns, generating transfer functions, and conventional Tikhonov solver.

This module draws heavily on these two repositories:

https://github.com/Waller-Lab/DPC
https://github.com/Waller-Lab/DPC_withAberrationCorrection

nstm.dpc_utils.dpc_tikhonov_solver(imgs, Hu, Hp, amp_reg=5e-05, phase_reg=5e-05, wavelength=0.515)
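A conventional DPC Tikhonov solver works independently at each spatial frequency: given the absorption and phase transfer functions Hu_j and Hp_j of every DPC image j, it solves the regularized 2x2 normal equations for the absorption and phase spectra. The NumPy sketch below follows the approach of the Waller-Lab DPC repository cited above. The function names are hypothetical, the images are assumed to be background-normalized already, and the `wavelength` argument of dpc_tikhonov_solver is not modeled.

```python
import numpy as np


def dpc_forward(absorption, phase, Hu, Hp):
    """Linear weak-object model: image_j = IFFT(Hu_j * FFT(absorption) + Hp_j * FFT(phase))."""
    return np.fft.ifft2(Hu * np.fft.fft2(absorption) + Hp * np.fft.fft2(phase))


def tikhonov_dpc(imgs, Hu, Hp, reg_u=5e-5, reg_p=5e-5):
    """Recover absorption and phase from DPC images of shape [N, Y, X]."""
    f_imgs = np.fft.fft2(imgs)
    # Entries of A^H A + diag(reg_u, reg_p): one 2x2 system per frequency.
    a11 = np.sum(np.abs(Hu) ** 2, axis=0) + reg_u
    a12 = np.sum(np.conj(Hu) * Hp, axis=0)
    a21 = np.sum(np.conj(Hp) * Hu, axis=0)
    a22 = np.sum(np.abs(Hp) ** 2, axis=0) + reg_p
    # Right-hand side A^H y.
    b1 = np.sum(np.conj(Hu) * f_imgs, axis=0)
    b2 = np.sum(np.conj(Hp) * f_imgs, axis=0)
    # Solve each 2x2 system in closed form (Cramer's rule).
    det = a11 * a22 - a12 * a21
    absorption = np.fft.ifft2((a22 * b1 - a12 * b2) / det).real
    phase = np.fft.ifft2((a11 * b2 - a21 * b1) / det).real
    return absorption, phase
```

The closed-form 2x2 inverse avoids any iterative optimization, which is why this solver is a convenient baseline for comparison with the space-time reconstruction.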
nstm.dpc_utils.genMeasurmentsLinear(complexField, Hu, Hp)
nstm.dpc_utils.genSourceAngular(sourceCoeffs, rotationAngle, imgSize, systemNa, ps, wavelength)
nstm.dpc_utils.gen_transfer_func(list_source: ndarray, pupil: ndarray, wavelength: float, shifted_out=True)
nstm.dpc_utils.load_illum_pattern(param: SystemParameters, meta, list_led_na, large_led=True, first_n=None)
nstm.dpc_utils.sourceGen(dim_yx, na, ps, wavelength, rotation=None)

Generate DPC source patterns based on the rotation angles and numerical aperture of the illuminations.
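A DPC source is typically a half-circle of the illumination pupil, and rotating the half-plane gives the different DPC axes. The minimal NumPy sketch below assumes a source disk of radius na/wavelength in spatial-frequency units and a half-plane selected by the rotation angle. The function name, the default angles, and the sign convention are hypothetical and only loosely follow the Waller-Lab DPC code.

```python
import numpy as np


def half_circle_sources(dim_yx, na, ps, wavelength, rotations_deg=(0, 180, 90, 270)):
    """Return binary half-circle source patterns of shape [len(rotations_deg), Y, X]."""
    ny, nx = dim_yx
    # Centered spatial-frequency grids (cycles per unit length, same unit as ps).
    fy = np.fft.fftshift(np.fft.fftfreq(ny, d=ps))
    fx = np.fft.fftshift(np.fft.fftfreq(nx, d=ps))
    FY, FX = np.meshgrid(fy, fx, indexing="ij")
    pupil = np.sqrt(FX**2 + FY**2) <= na / wavelength
    sources = []
    for deg in rotations_deg:
        theta = np.deg2rad(deg)
        # Keep the half of the pupil on the non-negative side of the rotated axis.
        half = FY * np.cos(theta) + FX * np.sin(theta) >= 0
        sources.append((pupil & half).astype(float))
    return np.stack(sources)
```

Opposite rotations (0 and 180 degrees, 90 and 270 degrees) give complementary halves, which is what makes each DPC image pair sensitive to the phase gradient along one axis.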

nstm.hash_encoding module

JAX implementation of multiresolution hash encoding for accelerated implicit neural representations of 1D-4D scenes.

Our implementation is based on the paper:

Müller, Thomas, et al. "Instant neural graphics primitives with a multiresolution hash encoding." ACM Transactions on Graphics (TOG) 41.4 (2022): 1-15.

It also draws heavily on this PyTorch implementation: https://github.com/yashbhalgat/HashNeRF-pytorch/blob/main/utils.py

class nstm.hash_encoding.AnnealedHashEmbedding(hash_param: nstm.hash_encoding.HashParameters, n_input_features: int = -1, precision: str = 'float32', parent: Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object>, name: str = None)

Bases: Model

hash_param: HashParameters
n_input_features: int = -1
name: str = None
parent: Type[Module] | Type[Scope] | Type[_Sentinel] | None = <flax.linen.module._Sentinel object>
precision: str = 'float32'
scope = None
setup() → None

class nstm.hash_encoding.HashEmbedding(hash_param: nstm.hash_encoding.HashParameters, n_input_features: int = -1, precision: str = 'float32', parent: Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object>, name: str = None)

Bases: Model

encoding_at_level(x, level)
hash_param: HashParameters
n_input_features: int = -1
name: str = None
parent: Type[Module] | Type[Scope] | Type[_Sentinel] | None = <flax.linen.module._Sentinel object>
precision: str = 'float32'
scope = None
setup() → None

class nstm.hash_encoding.HashEmbeddingTime(hash_param_space: Optional[nstm.hash_encoding.HashParameters], hash_param_time: nstm.hash_encoding.HashParameters, precision: str = 'float32', parent: Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object>, name: str = None)

Bases: Model

hash_param_space: HashParameters | None
hash_param_time: HashParameters
name: str = None
parent: Type[Module] | Type[Scope] | Type[_Sentinel] | None = <flax.linen.module._Sentinel object>
precision: str = 'float32'
scope = None
setup() → None

class nstm.hash_encoding.HashEmbeddingTimeCombined(hash_param_spacetime: nstm.hash_encoding.HashParameters, precision: str = 'float32', parent: Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object>, name: str = None)

Bases: Model

hash_param_spacetime: HashParameters
name: str = None
parent: Type[Module] | Type[Scope] | Type[_Sentinel] | None = <flax.linen.module._Sentinel object>
precision: str = 'float32'
scope = None
setup() → None

class nstm.hash_encoding.HashParameters(bounding_box: Tuple[jax._src.basearray.Array, jax._src.basearray.Array], n_levels: int = 16, n_features_per_level: int = 2, log2_hashmap_size: int = 19, base_resolution: int | numpy.ndarray = 16, finest_resolution: int | numpy.ndarray = 512, init_uniform_std: float = 0.0001)

Bases: object

base_resolution: int | ndarray = 16
bounding_box: Tuple[Array, Array]
finest_resolution: int | ndarray = 512
init_uniform_std: float = 0.0001
log2_hashmap_size: int = 19
n_features_per_level: int = 2
n_levels: int = 16
replace(**updates)

Returns a new object replacing the specified fields with new values.
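HashParameters specifies the coarsest and finest grid resolutions together with the number of levels. In the Instant-NGP paper cited above, the per-level resolutions grow geometrically: N_l = floor(N_min * b^l), with b = exp((ln N_max - ln N_min) / (L - 1)). The NumPy sketch below computes them, assuming that nstm derives its level resolutions the same way; the helper name is hypothetical. It also accepts per-axis arrays, since base_resolution and finest_resolution may be ndarrays.

```python
import numpy as np


def level_resolutions(base_resolution=16, finest_resolution=512, n_levels=16):
    """Per-level grid resolutions N_l = floor(N_min * b**l) for l = 0..n_levels-1."""
    base = np.asarray(base_resolution, dtype=float)
    finest = np.asarray(finest_resolution, dtype=float)
    # Growth factor b chosen so that the last level reaches the finest resolution.
    b = np.exp((np.log(finest) - np.log(base)) / (n_levels - 1))
    levels = np.arange(n_levels).reshape((-1,) + (1,) * base.ndim)
    return np.floor(base * b ** levels).astype(int)
```

Because of floating-point rounding, the floor at the last level can land one below finest_resolution.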

nstm.hash_encoding.bilinear_interp(x, pixel_min_vertex, pixel_max_vertex, pixel_embedds)
nstm.hash_encoding.get_pixel_vertices(x, bounding_box, resolution, log2_hashmap_size, box_offsets, box_dim)

Parameters:
  • x – 1D-3D coordinates of the samples, with shape B x [1-3].

  • bounding_box – minimum and maximum x, y, z coordinates of the object bounding box.

  • resolution – number of voxels per axis.
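The geometric part of this lookup can be sketched as follows. The bounding box is divided into `resolution` cells per axis, and each sample is assigned the integer index of its enclosing cell, the cell's minimum and maximum corners, and the integer indices of the cell's 2^D corners (which are then hashed). This NumPy sketch omits the hashing step and the `box_offsets`/`box_dim` arguments. The helper name is hypothetical and modeled on HashNeRF-pytorch's get_voxel_vertices.

```python
import itertools

import numpy as np


def cell_vertices(x, bounding_box, resolution):
    """Return (min_vertex, max_vertex, corner_indices) for samples x of shape [B, D]."""
    box_min, box_max = (np.asarray(b, dtype=float) for b in bounding_box)
    x = np.clip(x, box_min, box_max)  # samples outside the box are clamped onto it
    grid_size = (box_max - box_min) / resolution
    idx = np.floor((x - box_min) / grid_size).astype(int)  # enclosing cell index
    min_vertex = idx * grid_size + box_min
    max_vertex = min_vertex + grid_size
    # Integer indices of the 2**D cell corners; these are what would be hashed.
    offsets = np.array(list(itertools.product((0, 1), repeat=x.shape[-1])))
    corner_indices = idx[:, None, :] + offsets[None, :, :]
    return min_vertex, max_vertex, corner_indices
```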

nstm.hash_encoding.hash_fn(coords, log2_hashmap_size)

Parameters:
  • coords – integer grid coordinates with up to 7 dimensions.

  • log2_hashmap_size – base-2 logarithm of the hash table size T (log2T).
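The spatial hash from the Instant-NGP paper multiplies each integer coordinate by a large per-axis prime, XORs the results, and keeps the lowest log2_hashmap_size bits. The seven primes below are the ones used in HashNeRF-pytorch, which is consistent with the 7-dimension limit. That nstm uses the same primes is an assumption, and the function name is hypothetical.

```python
import numpy as np

# Per-axis primes from HashNeRF-pytorch; the first is 1, as in Instant-NGP.
PRIMES = np.array([1, 2654435761, 805459861, 3674653429,
                   2097192037, 1434869437, 2165219737], dtype=np.uint32)


def spatial_hash(coords, log2_hashmap_size):
    """Hash integer coordinates of shape [..., D] (D <= 7) to table indices."""
    coords = np.asarray(coords, dtype=np.uint32)
    d = coords.shape[-1]
    if d > len(PRIMES):
        raise ValueError("at most 7 coordinate dimensions are supported")
    result = np.zeros(coords.shape[:-1], dtype=np.uint32)
    for i in range(d):
        # uint32 multiplication wraps modulo 2**32, as in the reference implementations.
        result ^= coords[..., i] * PRIMES[i]
    # Keep the lowest log2_hashmap_size bits, i.e. reduce modulo T = 2**log2_hashmap_size.
    return result & np.uint32((1 << log2_hashmap_size) - 1)
```

Hash collisions are not resolved explicitly; in Instant-NGP the gradient-based optimization and the multiresolution structure are relied upon to disambiguate them.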

nstm.hash_encoding.linear_interp(x, pixel_min_vertex, pixel_max_vertex, pixel_embedds)
nstm.hash_encoding.precision_to_dtype(precision)
nstm.hash_encoding.quadrilinear_interp(x, voxel_min_vertex, voxel_max_vertex, voxel_embedds)

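Parameters:
  • x – sample coordinates, batch x 4.

  • voxel_min_vertex – minimum corner of the enclosing voxel, batch x 4.

  • voxel_max_vertex – maximum corner of the enclosing voxel, batch x 4.

  • voxel_embedds – embeddings at the 16 voxel corners, batch x 16 x num_feature.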

nstm.hash_encoding.trilinear_interp(x, voxel_min_vertex, voxel_max_vertex, voxel_embedds)

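Parameters:
  • x – sample coordinates, batch x 3.

  • voxel_min_vertex – minimum corner of the enclosing voxel, batch x 3.

  • voxel_max_vertex – maximum corner of the enclosing voxel, batch x 3.

  • voxel_embedds – embeddings at the 8 voxel corners, batch x 8 x num_feature.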
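linear_interp, bilinear_interp, trilinear_interp and quadrilinear_interp are the 1D to 4D cases of the same operation. Each sample's feature is a weighted sum of the embeddings at the 2^D corners of its enclosing cell, and each weight is a product of fractional coordinates. The dimension-generic NumPy sketch below assumes that corner k corresponds to the k-th entry of itertools.product((0, 1), repeat=D), as with HashNeRF-pytorch's BOX_OFFSETS; nstm's corner ordering may differ, and the function name is hypothetical.

```python
import itertools

import numpy as np


def multilinear_interp(x, min_vertex, max_vertex, corner_embeds):
    """Interpolate corner embeddings [B, 2**D, F] at samples x of shape [B, D]."""
    d = x.shape[-1]
    w = (x - min_vertex) / (max_vertex - min_vertex)  # fractional position in the cell
    out = np.zeros((x.shape[0], corner_embeds.shape[-1]))
    # Corner k has bits (b_0, ..., b_{D-1}) in itertools.product order (assumed).
    for k, bits in enumerate(itertools.product((0, 1), repeat=d)):
        weight = np.ones(x.shape[0])
        for axis, bit in enumerate(bits):
            weight = weight * (w[:, axis] if bit else 1.0 - w[:, axis])
        out += weight[:, None] * corner_embeds[:, k]
    return out
```

A quick sanity check is that interpolating a linear function sampled at the corners reproduces it exactly inside the cell.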

nstm.pos_encoding module

JAX implementation of positional encoding for coordinate-based neural networks.

This module draws heavily on the following repository:

https://github.com/google-research/google-research/tree/master/jaxnerf
class nstm.pos_encoding.AnnealedPosenc(posenc_param: nstm.pos_encoding.PosencParameters, dim: Optional[jax._src.basearray.Array] = None, parent: Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object>, name: str = None)

Bases: Model

dim: Array | None = None
name: str = None
parent: Type[Module] | Type[Scope] | Type[_Sentinel] | None = <flax.linen.module._Sentinel object>
posenc_param: PosencParameters
scope = None
class nstm.pos_encoding.PosencParameters(posenc_min: int, posenc_max: int, num_freqs: int)

Bases: object

num_freqs: int
posenc_max: int
posenc_min: int
replace(**updates)

Returns a new object replacing the specified fields with new values.
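An annealed positional encoding can be sketched by combining the jaxnerf encoding layout [x, sin(2^k x), cos(2^k x)] with the coarse-to-fine window popularized by Nerfies, where frequency band k is faded in as a progress variable alpha sweeps from k to k + 1. How AnnealedPosenc maps posenc_min, posenc_max and num_freqs to bands, and how alpha is tied to training progress, are assumptions here; the function names are hypothetical.

```python
import numpy as np


def posenc(x, min_deg, max_deg):
    """jaxnerf-style encoding: [x, sin(2^k x), cos(2^k x)] for k in [min_deg, max_deg)."""
    scales = 2.0 ** np.arange(min_deg, max_deg)
    xb = (x[..., None, :] * scales[:, None]).reshape(x.shape[:-1] + (-1,))
    four_feat = np.sin(np.concatenate([xb, xb + 0.5 * np.pi], axis=-1))  # sin, then cos
    return np.concatenate([x, four_feat], axis=-1)


def annealed_posenc(x, min_deg, max_deg, alpha):
    """Coarse-to-fine variant: band k is faded in as alpha goes from k to k + 1."""
    d = x.shape[-1]
    bands = np.arange(max_deg - min_deg)
    window = 0.5 * (1.0 - np.cos(np.pi * np.clip(alpha - bands, 0.0, 1.0)))
    # Feature layout is [sin(all bands), cos(all bands)], each band spanning d entries.
    window = np.tile(np.repeat(window, d), 2)
    feats = posenc(x, min_deg, max_deg)
    return np.concatenate([feats[..., :d], feats[..., d:] * window], axis=-1)
```

With alpha = 0 only the raw coordinates pass through; with alpha equal to the number of bands the full encoding is recovered. A linear ramp of alpha over the annealing period is a common choice.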

nstm.sim3d_flow module

3D SIM forward model with the neural space-time model.

This module contains the forward models for 3D SIM and 3D SIM with space-time modeling. The loss functions used for the 3D SIM reconstruction are also provided.

class nstm.sim3d_flow.FluoSIM3D(sim_param: ~nstm.sim3d_utils.SIMParameter3D, optical_param: ~nstm.utils.SystemParameters3D, order0_grad_reduction: float = 0.0, apo_filter: bool = True, parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Model

This class mimics the setting and forward model of the original 3D SIM paper (Eq. 9), Biophysical Journal 94(12): 4957–4970.

apo_filter: bool = True
name: str = None
optical_param: SystemParameters3D
order0_grad_reduction: float = 0.0
parent: Type[Module] | Type[Scope] | Type[_Sentinel] | None = <flax.linen.module._Sentinel object>
scope = None
setup()

sim_param: SIMParameter3D
class nstm.sim3d_flow.FluoSIM3DWrapper(sim_param: nstm.sim3d_utils.SIMParameter3D, optical_param: nstm.utils.SystemParameters3D, parent: Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object>, name: str = None)

Bases: Model

name: str = None
optical_param: SystemParameters3D
parent: Type[Module] | Type[Scope] | Type[_Sentinel] | None = <flax.linen.module._Sentinel object>
scope = None
setup()

sim_param: SIMParameter3D
class nstm.sim3d_flow.SIM3DSpacetime(sim_param: nstm.sim3d_utils.SIMParameter3D, spacetime_param: nstm.spacetime.SpaceTimeParameters, optical_param: nstm.utils.SystemParameters3D, annealed_epoch: float = 1, order0_grad_reduction: float = 0.0, parent: Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object>, name: str = None)

Bases: Model

annealed_epoch: float = 1
name: str = None
optical_param: SystemParameters3D
order0_grad_reduction: float = 0.0
parent: Type[Module] | Type[Scope] | Type[_Sentinel] | None = <flax.linen.module._Sentinel object>
scope = None
setup()

sim_param: SIMParameter3D
spacetime_param: SpaceTimeParameters
nstm.sim3d_flow.gen_loss_l2(margin=1)
nstm.sim3d_flow.gen_loss_l2_stack(margin=1)
nstm.sim3d_flow.gen_loss_nonneg_reg()

nstm.sim3d_main module

3D SIM reconstruction with the neural space-time model.

Example

To run the 3D SIM reconstruction, download the dataset and run the following command:

$ python ./nstm/sim3d_main.py --config=mito_cell

This loads the arguments stored in examples/configs/mito_cell.yaml and runs the 3D SIM reconstruction.

nstm.sim3d_main.define_model(optical_param, sim_param)
nstm.sim3d_main.define_training_params(num_batches_per_epoch)
nstm.sim3d_main.main(unused_argv)

nstm.sim3d_utils module

Utility functions for 3D SIM reconstruction.

Some functions are adapted from the CUDA-accelerated three-beam SIM reconstruction code.

class nstm.sim3d_utils.SIMParameter3D(OTF: jax._src.basearray.Array, nphases: int, ndirs: int, k0angles: Tuple[float], line_spacing: Tuple[float], starting_phases: Tuple[Tuple[float]], origin_pixel_offset_yx: Tuple[float])

Bases: object

OTF: Array
k0angles: Tuple[float]
line_spacing: Tuple[float]
ndirs: int
nphases: int
origin_pixel_offset_yx: Tuple[float]
replace(**updates)

Returns a new object replacing the specified fields with new values.

starting_phases: Tuple[Tuple[float]]
nstm.sim3d_utils.define_flags()

Define flags.

nstm.sim3d_utils.estimate_mod_illum(bandplus_img: ndarray, otf: List[ndarray], img_param: SystemParameters3D, otf_param: SystemParameters3D, ndirs: int, k0angles: Sequence[float], line_spacing: Sequence[float], crop_boundary_zyx: Sequence[int], noisy=True)
nstm.sim3d_utils.gen_dampen_order0_mask(param: SystemParameters3D, inverted=False)
nstm.sim3d_utils.generate_exp(param: SystemParameters | SystemParameters3D, k0angle: float, k0mag: float, phase: float, order: int, origin_pixel_offset_yx: Tuple[float] | None = None)
nstm.sim3d_utils.generate_sinusoidal(param: SystemParameters | SystemParameters3D, k0angle: float, k0mag: float, phase: float, order: int, origin_pixel_offset_yx: Tuple[float] | None = None)

Generate a sinusoidal pattern for 3D SIM.
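As an illustration of such a pattern, the m-th lateral harmonic of a sinusoidal SIM illumination at angle k0angle with spatial frequency k0mag (assumed to be 1/line_spacing) can be written as cos(m(2π k0·r + φ)). This is a hypothetical NumPy sketch of the lateral term only. The coordinate origin, units, axial modulation, handling of origin_pixel_offset_yx, and phase convention used by generate_sinusoidal are not documented here and may differ.

```python
import numpy as np


def sinusoidal_pattern(ny, nx, pixel_size, k0angle, k0mag, phase, order):
    """Hypothetical lateral SIM harmonic cos(order * (2*pi*k0 . r + phase))."""
    y = np.arange(ny)[:, None] * pixel_size
    x = np.arange(nx)[None, :] * pixel_size
    # Projection of each pixel position onto the pattern's wave-vector direction.
    r_par = x * np.cos(k0angle) + y * np.sin(k0angle)
    return np.cos(order * (2 * np.pi * k0mag * r_par + phase))
```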

nstm.sim3d_utils.get_modamp(overlap1: ndarray, overlap2: ndarray, crop_boundary_zyx: Sequence[int] | None = None, intercept: bool = True, drop_half: bool = False)

Estimate the relative modulation amplitude and starting phase of overlap2 with respect to overlap1 (from order 0).
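One way to obtain such an estimate is a complex least-squares fit of overlap2 ≈ a · overlap1 (+ c when an intercept is fitted), reading the relative modulation amplitude from |a| and the starting phase from arg(a). This NumPy sketch illustrates that fit only. That get_modamp works this way is an assumption, its crop_boundary_zyx and drop_half options are not modeled, and the function name is hypothetical.

```python
import numpy as np


def estimate_modamp(overlap1, overlap2, intercept=True):
    """Least-squares fit overlap2 ~ a * overlap1 (+ c); returns (|a|, angle(a))."""
    o1 = np.ravel(overlap1).astype(complex)
    o2 = np.ravel(overlap2).astype(complex)
    columns = [o1, np.ones_like(o1)] if intercept else [o1]
    design = np.stack(columns, axis=1)  # [num_pixels, 1 or 2]
    coef, *_ = np.linalg.lstsq(design, o2, rcond=None)
    a = coef[0]
    return np.abs(a), np.angle(a)
```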

nstm.sim3d_utils.get_otf(bandplus_img: ndarray, otf: List[ndarray], img_param: SystemParameters3D, otf_param: SystemParameters3D, ndirs: int, nphases: int, k0angles: Sequence[float], line_spacing: Sequence[float], crop_boundary_zyx: Sequence[int], noisy=True, notch=False, notch_width=0.5)
nstm.sim3d_utils.make_overlaps(band1, band2, order1, order2, otf, img_param: SystemParameters3D, otf_param: SystemParameters3D, k0angle: float, k0mag: float, phase2: float, normalize_otf=True)

Counterpart of makeOverlaps0Kernel and makeOverlaps1Kernel in the CUDA-accelerated SIM reconstruction code.

nstm.sim3d_utils.otf_support_mask(param: SystemParameters3D, otf_param: SystemParameters3D, sim_param: SIMParameter3D, otf, otf_cutoff: float = 1e-08)

Output a 3D binary mask of the frequency-space regions that 3D SIM can support.

nstm.sim3d_utils.rad_avg_OTF_expansion(rad_avg_OTF: ndarray, img_param: SystemParameters3D, otf_param: SystemParameters3D, kxy_shift: Tuple[float] | None = None, freq_cutoff: bool = False, order: int = 0, with_padding: bool = False)
nstm.sim3d_utils.separate_bands(imgs: ndarray, nphases: int = 5, norders: int = 3, out_positive_bands: bool = False)

Separate the bands of the raw SIM images.

Parameters:
  • imgs (np.ndarray) – Raw SIM images with shape [nphases, z, y, x].

  • nphases (int) – Number of phases.

  • norders (int) – Number of band orders. Defaults to 3 orders for three-beam 3D SIM.

  • out_positive_bands (bool) – Whether to output only the positive bands or all bands.
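Band separation can be sketched as inverting a phase-mixing matrix. With equally spaced pattern phases φ_n = 2πn/nphases, each raw image is D_n = sum over m of S_m exp(i m φ_n), for orders m = 0, ±1, ..., ±(norders - 1). The bands S_m are therefore recovered by applying the (pseudo-)inverse of that matrix along the phase axis, which requires nphases >= 2 * norders - 1. This NumPy sketch assumes ideal, equally spaced phases starting at zero and returns all bands. The function name is hypothetical, and the real separate_bands may use measured starting phases and a different band ordering.

```python
import numpy as np


def separate_bands_sketch(imgs, nphases=5, norders=3):
    """Unmix raw SIM images [nphases, z, y, x] into bands for orders 0, +1, -1, +2, -2, ..."""
    orders = [0]
    for m in range(1, norders):
        orders += [m, -m]
    phases = 2 * np.pi * np.arange(nphases) / nphases  # assumed ideal phase steps
    # mix[n, k] = exp(i * orders[k] * phases[n]), so imgs[n] = sum_k mix[n, k] * bands[k].
    mix = np.exp(1j * np.outer(phases, orders))
    bands = np.tensordot(np.linalg.pinv(mix), imgs, axes=(1, 0))
    return orders, bands
```

For real-valued raw images the negative-order bands are complex conjugates of the positive ones, which is why keeping only the positive bands loses no information.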

nstm.spacetime module

Implementation of neural space-time model for 2D/3D+time dynamic scene representations.

class nstm.spacetime.MLP(net_depth: int = 8, net_width: int = 256, net_activation: ~typing.Callable = <jax._src.custom_derivatives.custom_jvp object>, skip_layer: int = 4, num_output_channels: int = 1, kernel_init: ~typing.Callable = <function variance_scaling.<locals>.init>, precision: None | str | ~jax._src.lax.lax.Precision | ~typing.Tuple[str, str] | ~typing.Tuple[~jax._src.lax.lax.Precision, ~jax._src.lax.lax.Precision] = None, param_dtype: ~typing.Any = <class 'jax.numpy.float32'>, parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Model

A simple MLP with an option to add skip connections.

Parameters:
  • net_depth (int) – The depth of the MLP.

  • net_width (int) – The width of the MLP.

  • net_activation (Callable) – The activation function.

  • skip_layer (int) – The layer at which to add the skip connection.

  • num_output_channels (int) – The number of output channels.

  • kernel_init (Callable) – weight initializer

  • precision (nn.linear.PrecisionLike) – arithmetic precision of the network

  • param_dtype (jnp.dtype) – data type of the parameters

kernel_init(shape: ~typing.Sequence[int | ~typing.Any], dtype: ~typing.Any = <class 'jax.numpy.float64'>) Any
name: str = None
net_activation: Callable = <jax._src.custom_derivatives.custom_jvp object>
net_depth: int = 8
net_width: int = 256
num_output_channels: int = 1
param_dtype

alias of float32

parent: Type[Module] | Type[Scope] | Type[_Sentinel] | None = <flax.linen.module._Sentinel object>
precision: None | str | Precision | Tuple[str, str] | Tuple[Precision, Precision] = None
scope = None
skip_layer: int = 4
class nstm.spacetime.MLPParameters(net_depth: int, net_width: int, net_activation: ~typing.Callable[[...], ~typing.Any], skip_layer: int = 4, kernel_init: ~typing.Callable = <function variance_scaling.<locals>.init>)

Bases: object

Parameters for the MLP model.

Parameters:
  • net_depth (int) – The depth of the MLP.

  • net_width (int) – The width of the MLP.

  • net_activation (Callable) – The activation function.

  • skip_layer (int) – The layer at which to add the skip connection.

  • kernel_init (Callable) – network weight initializer

kernel_init(shape: ~typing.Sequence[int | ~typing.Any], dtype: ~typing.Any = <class 'jax.numpy.float64'>) Any
net_activation: Callable[[...], Any]
net_depth: int
net_width: int
replace(**updates)

Returns a new object replacing the specified fields with new values.

skip_layer: int = 4
class nstm.spacetime.SpaceTimeMLP(optical_param: ~nstm.utils.SystemParameters | ~nstm.utils.SystemParameters3D | ~typing.Tuple[int, int] | ~typing.Tuple[int, int, int], spacetime_param: ~nstm.spacetime.SpaceTimeParameters, num_output_channels: int, precision: str = 'float32', parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Model

Implementation of neural space-time model.

The model takes spatial and temporal coordinates as input and outputs the object’s properties at the given time and spatial location.

Parameters:
  • optical_param (Union[utils.SystemParameters, utils.SystemParameters3D, Tuple[int, int], Tuple[int, int, int]]) – the optical parameters that specify the matrix size. Alternatively, the matrix size can be given directly as a tuple of (y, x) or (z, y, x).

  • spacetime_param (SpaceTimeParameters) – the space-time parameters.

  • num_output_channels (int) – the number of output channels.

  • precision (str, optional) – the arithmetic precision of the model.
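Here is a toy illustration of the description above. It assumes the motion network predicts a displacement that warps each query coordinate, and that the scene network is then evaluated at the warped coordinate. Both networks are replaced by hypothetical closed-form functions (motion_net, object_net). In nstm they are MLPs configured through SpaceTimeParameters, and the sign convention of the warp is an assumption.

```python
import numpy as np


def motion_net(coords_yx, t):
    """Hypothetical motion network: a uniform shift that grows linearly with time."""
    velocity_yx = np.array([0.1, -0.05])
    return np.broadcast_to(velocity_yx * t, coords_yx.shape)


def object_net(coords_yx):
    """Hypothetical scene network: a Gaussian blob centered at the origin."""
    return np.exp(-np.sum(coords_yx**2, axis=-1) / 0.02)


def space_time_query(coords_yx, t):
    """Warp the query coordinates by the predicted motion, then query the scene."""
    displacement = motion_net(coords_yx, t)
    return object_net(coords_yx + displacement)
```

With this convention the blob appears at -velocity_yx * t at time t. For example, the query evaluates to 1.0 at (y, x) = (-0.1, 0.05) when t = 1.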

get_motion_map(t: Array, coord_offset: Array, alpha: float = 100000.0)

Get the pixel/voxel-level motion map at the given timestamps.

Parameters:
  • t (jnp.ndarray) – the list of timestamps to query the space-time model, [batch].

  • coord_offset (jnp.ndarray) – the offset to the spatial coordinates, [batch, (2 or 3)].

  • alpha (float) – the annealing parameter for the positional encoding.

Returns:

the motion map at the given timestamps, 2D - [batch, y, x, 2], 3D - [batch, z, y, x, 3].

Return type:

jnp.ndarray
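One common way to anneal a positional encoding is the coarse-to-fine window popularized by Nerfies. In that scheme, frequency band k is weighted by (1 - cos(pi * clip(alpha - k, 0, 1))) / 2. The sketch below assumes alpha plays that role here; the function name and the frequency layout are hypothetical. Under this assumption, the default alpha = 100000.0 enables every frequency band.

```python
import numpy as np


def annealed_posenc_sketch(x, num_freqs, alpha):
    """Hypothetical windowed positional encoding.

    x: [..., d] coordinates. Returns [..., d + 2 * d * num_freqs] features.
    """
    freqs = 2.0 ** np.arange(num_freqs)                     # [num_freqs]
    scaled = np.pi * x[..., None, :] * freqs[:, None]       # [..., num_freqs, d]
    # Band k is off while alpha <= k and fully on once alpha >= k + 1.
    k = np.arange(num_freqs)
    window = 0.5 * (1.0 - np.cos(np.pi * np.clip(alpha - k, 0.0, 1.0)))
    enc = np.concatenate([np.sin(scaled), np.cos(scaled)], axis=-1)  # [..., num_freqs, 2d]
    enc = enc * window[:, None]
    # Keep the raw coordinates and append the flattened, windowed features.
    return np.concatenate([x, enc.reshape(*x.shape[:-1], -1)], axis=-1)
```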

name: str = None
num_output_channels: int
optical_param: SystemParameters | SystemParameters3D | Tuple[int, int] | Tuple[int, int, int]
parent: Type[Module] | Type[Scope] | Type[_Sentinel] | None = <flax.linen.module._Sentinel object>
precision: str = 'float32'
scope = None
setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    import flax.linen as nn

    class MyModule(nn.Module):
      def setup(self):
        submodule = nn.Conv(features=8, kernel_size=(3, 3))

        # Accessing `submodule` attributes does not yet work here.

        # The following line invokes `self.__setattr__`, which gives
        # `submodule` the name "conv1".
        self.conv1 = submodule

        # Accessing `submodule` attributes or methods is now safe and
        # causes setup() to be called once.

  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

spacetime_param: SpaceTimeParameters
class nstm.spacetime.SpaceTimeParameters(motion_mlp_param: MLPParameters, object_mlp_param: MLPParameters, motion_embedding: str | None, motion_embedding_param: Dict | HashParameters | PosencParameters, object_embedding: str | None, object_embedding_param: Dict | HashParameters | PosencParameters, out_activation: Callable[[...], Any])

Bases: object

Parameters for the space-time model.

Parameters:
  • motion_mlp_param (MLPParameters) – Parameters for the motion network.

  • object_mlp_param (MLPParameters) – Parameters for the scene network.

  • motion_embedding (Union[str, None]) – The type of input embedding for the motion network.

  • motion_embedding_param (Union[Dict, HashParameters, PosencParameters]) – Parameters for the motion network’s input embedding.

  • object_embedding (Union[str, None]) – The type of input embedding for the scene network.

  • object_embedding_param (Union[Dict, HashParameters, PosencParameters]) – Parameters for the object network’s input embedding.

  • out_activation (Callable[..., Any]) – The activation function for the MLP output.

motion_embedding: str | None
motion_embedding_param: Dict | HashParameters | PosencParameters
motion_mlp_param: MLPParameters
object_embedding: str | None
object_embedding_param: Dict | HashParameters | PosencParameters
object_mlp_param: MLPParameters
out_activation: Callable[[...], Any]
replace(**updates)

Returns a new object replacing the specified fields with new values.

nstm.spacetime.generate_dense_yx_coords(dim_yx, normalize=True)

Generate a list of coordinates for a 2D grid.

Parameters:
  • dim_yx (Tuple[int, int]) – the dimension of the 2D grid.

  • normalize (bool) – whether to normalize the coordinates to [-1, 1].

Returns:

the list of coordinates (number of coordinates, [y, x]).

Return type:

np.ndarray
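The sketch below shows the kind of grid this function returns. It assumes that normalize=True maps pixel indices linearly onto [-1, 1] with np.linspace (endpoints included), that unnormalized coordinates are integer pixel indices, and that coordinates are listed in row-major order. The function name is hypothetical.

```python
import numpy as np


def dense_yx_coords_sketch(dim_yx, normalize=True):
    """Hypothetical grid: one (y, x) row per pixel, in row-major order."""
    ny, nx = dim_yx
    if normalize:
        ys, xs = np.linspace(-1, 1, ny), np.linspace(-1, 1, nx)
    else:
        ys, xs = np.arange(ny), np.arange(nx)
    yy, xx = np.meshgrid(ys, xs, indexing="ij")
    return np.stack([yy.ravel(), xx.ravel()], axis=-1)  # [ny * nx, 2]
```

Under these assumptions, a (3, 4) grid yields 12 rows that start at (-1, -1) and end at (1, 1).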

nstm.spacetime.generate_dense_zyx_coords(dim_zyx, start_coord_zyx=None, normalize=True)

Generate a list of coordinates for a 3D grid.

Parameters:
  • dim_zyx (Tuple[int, int, int]) – the dimension of the 3D grid.

  • start_coord_zyx (Tuple[int, int, int]) – the starting coordinate of the 3D grid.

  • normalize (bool) – whether to normalize the coordinates to [-1, 1].

Returns:

the list of coordinates (number of coordinates, [z, y, x]).

Return type:

np.ndarray

nstm.spacetime.get_trajectory(motion_zyx: ndarray, target_pts: List[Tuple[int, int]] | List[Tuple[int, int, int]], interpolate: bool = False)

Helper function to get the trajectories of the target points from the motion maps generated by the motion network.

Parameters:
  • motion_zyx (np.ndarray) – a list of motion maps at different timepoints. 2D - [num_frames, y, x, 2], 3D - [num_frames, z, y, x, 3].

  • target_pts (Union[List[Tuple[int, int]], List[Tuple[int, int, int]]]) – the list of target points to track. 2D - [(y, x)]*num_targets, 3D - [(z, y, x)]*num_targets.

  • interpolate (bool) – whether to interpolate the trajectory.

Returns:

the list of trajectories for the target points.

2D - [num_targets, num_frames, 2], 3D - [num_targets, num_frames, 3].

Return type:

List[np.ndarray]
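The next sketch covers only the non-interpolated case (interpolate=False). It assumes each target's position in a frame is its initial integer position plus the motion vector stored at that pixel. The function name is hypothetical, and nstm may instead accumulate or resample the motion along the path.

```python
import numpy as np


def get_trajectory_sketch(motion_zyx, target_pts):
    """Hypothetical non-interpolated trajectory lookup.

    motion_zyx: [num_frames, y, x, 2] (2D) or [num_frames, z, y, x, 3] (3D).
    target_pts: list of integer (y, x) or (z, y, x) tuples.
    Returns a list with one [num_frames, ndim] array per target point.
    """
    trajectories = []
    for pt in target_pts:
        # Read the motion vector at the target's (fixed) pixel in every frame.
        motion_at_pt = motion_zyx[(slice(None),) + tuple(pt)]
        trajectories.append(np.asarray(pt, dtype=float) + motion_at_pt)
    return trajectories
```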

nstm.utils module

General utility functions and dynamic simulation tools used in the neural space-time model paper.

nstm.utils.OTF_3D_fluo(param: SystemParameters3D, rfft: bool = False)

Compute the optical transfer function (OTF) of a 3D fluorescence microscopy system based on angular spectrum propagation. Note that the z dimension of the returned matrix is in the real (spatial) domain, not the frequency domain.

Parameters:
  • param – optical parameters of the 3D z-scan system

  • rfft – whether to return the OTF for a real FFT (rfft)

Returns:

OTF, the optical transfer function in (z, fy, fx)

class nstm.utils.PhantomTemporal(param)

Bases: object

generate_bead_phantom(coordinates, phase=1.0)
generate_shepp_logan(coordinates, max_value=1.0, max_phantom=1.0)
generate_shepp_logan_2channel(coordinates)
generate_shepp_logan_affine(coordinates, max_value=1.0)
generate_shepp_logan_swirl(coordinates, max_value=1.0)
generate_usaf_target(coordinates, max_value=1.0)
generate_usaf_target_affine(coordinates, max_value=1.0)
generate_usaf_target_shear(shear: float, scale: float = 0.5, max_value=1.0)
generate_usaf_target_swirl(coordinates, max_value=1.0)
class nstm.utils.SystemParameters(dim_yx: Tuple[int, int], wavelength: float, na: float, pixel_size: float, RI_medium: float, padding_yx: Tuple[int, int] = (0, 0), mean_background_amp: float = 1.0, wavelength_exc: float = 0.5)

Bases: object

Imaging system parameters for 2D imaging systems.

RI_medium: float
dim_yx: Tuple[int, int]
mean_background_amp: float = 1.0
na: float
padding_yx: Tuple[int, int] = (0, 0)
pixel_size: float
replace(**updates)

Returns a new object replacing the specified fields with new values.

wavelength: float
wavelength_exc: float = 0.5
class nstm.utils.SystemParameters3D(dim_zyx: Tuple[int, int, int], wavelength: float, wavelength_exc: float, na: float, pixel_size: float, pixel_size_z: float, RI_medium: float, padding_zyx: Tuple[int, int, int])

Bases: object

RI_medium: float
dim_zyx: Tuple[int, int, int]
na: float
padding_zyx: Tuple[int, int, int]
pixel_size: float
pixel_size_z: float
replace(**updates)

Returns a new object replacing the specified fields with new values.

wavelength: float
wavelength_exc: float
nstm.utils.apodization(img_param: SystemParameters3D | SystemParameters, k0mag: float, norder: int = 3, inverted=False, min_height=0.01)
nstm.utils.apodize_edge(img, napodize=10)
nstm.utils.brownian(x0, n, dt, delta, seed, out=None)

Generate an instance of Brownian motion (i.e. the Wiener process):

X(t) = X(0) + N(0, delta**2 * t; 0, t)

where N(a,b; t0, t1) is a normally distributed random variable with mean a and variance b. The parameters t0 and t1 make explicit the statistical independence of N on different time intervals; that is, if [t0, t1) and [t2, t3) are disjoint intervals, then N(a, b; t0, t1) and N(a, b; t2, t3) are independent.

Written as an iteration scheme,

X(t + dt) = X(t) + N(0, delta**2 * dt; t, t+dt)

If x0 is an array (or array-like), each value in x0 is treated as an initial condition, and the value returned is a numpy array with one more dimension than x0.

Parameters:
  • x0 (float or array-like) – The initial condition(s) (i.e. position(s)) of the Brownian motion. Anything that can be converted with numpy.asarray(x0) is accepted.

  • n (int) – The number of steps to take.

  • dt (float) – The time step.

  • delta (float) – delta determines the “speed” of the Brownian motion. The random variable of the position at time t, X(t), has a normal distribution whose mean is the position at time t=0 and whose variance is delta**2*t.

  • seed (int or generator instance) – Random seed for reproducibility.

  • out (numpy array or None) – If out is not None, it specifies the array in which to put the result. If out is None, a new numpy array is created and returned.

Returns:

A numpy array of floats with shape x0.shape + (n,). Note that the initial value x0 is not included in the returned array.

Source: https://scipy-cookbook.readthedocs.io/items/BrownianMotion.html
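The behavior described above follows the cited SciPy Cookbook recipe. Below is a NumPy sketch of that recipe. Using np.random.default_rng(seed) in place of scipy.stats.norm is an assumption about how the seed argument is consumed, and brownian_sketch is a hypothetical name.

```python
import numpy as np


def brownian_sketch(x0, n, dt, delta, seed, out=None):
    """Sketch of the SciPy Cookbook recipe with a seeded NumPy Generator."""
    x0 = np.asarray(x0)
    rng = np.random.default_rng(seed)  # accepts an int or an existing Generator
    # Independent increments, each drawn from N(0, delta**2 * dt).
    r = rng.normal(loc=0.0, scale=delta * np.sqrt(dt), size=x0.shape + (n,))
    if out is None:
        out = np.empty(r.shape)
    # The cumulative sum of the increments is the displacement after each step.
    np.cumsum(r, axis=-1, out=out)
    # Offset by the initial condition (x0 itself is not included in the output).
    out += np.expand_dims(x0, axis=-1)
    return out
```

Because the increments are independent, the variance of the final position is delta**2 * n * dt. For example, this gives 5.0 for n=10, dt=0.5 and delta=1.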

nstm.utils.generate_affine_motion(t, start_pos_yx, end_pos_yx, rot_start=0, rot_end=0, scale_start=1, scale_end=1, shear_start=0, shear_end=0)
nstm.utils.generate_linear_motion(t, start_pos_yx, end_pos_yx, rot_start=0, rot_end=0)
nstm.utils.generate_rod_phantom(dim_yx: Tuple[int, int], num_rods: int, rod_length: int, rod_width: int, triangle_mix: bool = False, seed: int = 219) ndarray

Draw a phantom with rods of length rod_length and width rod_width on a matrix of size dim_yx.

Parameters:
  • dim_yx – tuple of integers, size of the output matrix

  • num_rods – number of rods to draw

  • rod_length – length of the rod

  • rod_width – width of the rod

  • triangle_mix – whether to draw half of the rods as triangles

  • seed – random seed for reproducibility

Returns:

phantom matrix with rods

Return type:

ndarray

nstm.utils.hotpixel_removal(imgs, int_thres_percentile=85, pixel_on_rate=0.9, detect_only=False)
nstm.utils.load_video(filename, fov=None, single_channel=False, target_dim=None)
nstm.utils.notch_filter(img_param: SystemParameters3D, order: int, d: float, w: float, kz_offset: float = 0, inverted=False)
nstm.utils.object_transform(obj: ndarray, target_dim_yx: Tuple[int, int], coord: List[float] | Tuple[float, float, float, float]) ndarray

Linear transformation of a given object.

Parameters:
  • obj – original object to be transformed

  • target_dim_yx – output matrix dimension

  • coord – tuple or list to specify the transformation as (x, y, orientation, scale)

Return obj_transformed:

transformed object

nstm.utils.object_transform_affine(obj: ndarray, target_dim_yx: Tuple[int, int], coord: List[float] | Tuple[float, float, float, float, float]) ndarray

Affine transformation of a given object.

Parameters:
  • obj – original object to be transformed

  • target_dim_yx – output matrix dimension

  • coord – tuple or list to specify the transformation as (x, y, orientation, scale, shear)

Return obj_transformed:

transformed object

nstm.utils.object_transform_swirl(obj: ndarray, target_dim_yx: Tuple[int, int], scale, strength, radius)
nstm.utils.psf_gaussian_approx(dim_yx: Tuple[int, int], pixel_size: float, na: float, wavelength: float, paraxial: bool = True, ri: float = 1.0) ndarray

Generate a Gaussian-approximated PSF, based on: https://opg.optica.org/ao/fulltext.cfm?uri=ao-46-10-1819&id=130945

Parameters:
  • dim_yx – y-x matrix dimensions of the output

  • pixel_size – pixel size in micron

  • na – numerical aperture

  • wavelength – emission wavelength in micron

  • paraxial – whether to use paraxial approximation

  • ri – refractive index of the medium

Returns:

Gaussian-approximated PSF in 2D

Return type:

ndarray
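For the paraxial widefield case, the Gaussian approximation in the cited paper uses a lateral standard deviation of sigma ≈ 0.21 * wavelength / na. The sketch below makes three further assumptions: it uses that constant, centers the PSF at pixel (ny // 2, nx // 2), and normalizes it to unit sum. The name is hypothetical, and the non-paraxial branch (which uses ri) is not shown.

```python
import numpy as np


def psf_gaussian_paraxial_sketch(dim_yx, pixel_size, na, wavelength):
    """Hypothetical paraxial widefield Gaussian PSF, normalized to unit sum.

    pixel_size and wavelength are both in microns.
    """
    sigma_px = 0.21 * wavelength / na / pixel_size  # lateral sigma in pixels
    ny, nx = dim_yx
    y = np.arange(ny) - ny // 2  # center at the (ny // 2, nx // 2) pixel
    x = np.arange(nx) - nx // 2
    yy, xx = np.meshgrid(y, x, indexing="ij")
    psf = np.exp(-(yy**2 + xx**2) / (2 * sigma_px**2))
    return psf / psf.sum()
```

For example, an emission wavelength of 0.5 µm, NA 1.0 and 0.1 µm pixels give sigma ≈ 1.05 pixels.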

nstm.utils.update_flags(args)

Update the flags in args with the contents of the config YAML file.

Copied from https://github.com/google-research/google-research/blob/master/jaxnerf/nerf/utils.py

nstm.viz_utils module

Common tools for interactive 3D/4D visualization in Jupyter Notebook and JupyterLab.

nstm.viz_utils.add_inset(ax_, mat, extent)
nstm.viz_utils.add_scalebar(ax_, length, label, loc='upper left', pad=0.2, color='white', size_vertical=1, font_size=9, label_top=True)
nstm.viz_utils.color_coded_projection(stack_3d)

Convert a 3D microscopy image stack to a color-coded 2D projection.

Parameters:
  • stack_3d (numpy.ndarray) – A 3D numpy array representing the image stack.

Returns:

A 2D color-coded projection image.

Return type:

numpy.ndarray
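One simple way to build such a projection is to tint each z slice with its own hue and take the per-channel maximum over z. The sketch below uses the standard-library colorsys module for the colors. The hue range, the normalization and the function name are assumptions, not nstm's implementation.

```python
import colorsys

import numpy as np


def color_coded_projection_sketch(stack_3d, max_hue=0.8):
    """Hypothetical depth color-coding: stack_3d is [z, y, x]; returns [y, x, 3] RGB."""
    stack = np.asarray(stack_3d, dtype=float)
    nz = stack.shape[0]
    stack = stack / max(stack.max(), 1e-12)  # normalize intensities to [0, 1]
    # One RGB color per z slice, sweeping the hue from red (0) toward violet (max_hue).
    hues = np.linspace(0.0, max_hue, nz)
    colors = np.array([colorsys.hsv_to_rgb(h, 1.0, 1.0) for h in hues])  # [z, 3]
    # Tint each slice, then take the per-channel maximum over z.
    tinted = stack[..., None] * colors[:, None, None, :]  # [z, y, x, 3]
    return tinted.max(axis=0)
```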

nstm.viz_utils.volume(v, cmap='magma')
nstm.viz_utils.volume4d(v, cmap='magma')

Module contents