Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 128 additions & 10 deletions simSPI/linear_simulator/projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,136 @@ def __init__(self, config):
super(Projector, self).__init__()

self.config = config
self.vol = torch.rand([self.config.side_len] * 3, dtype=torch.float32)
lin_coords = torch.linspace(-1.0, 1.0, self.config.side_len)
[x, y, z] = torch.meshgrid(
[
lin_coords,
]
* 3
)
coords = torch.stack([y, x, z], dim=-1)
self.register_buffer("vol_coords", coords.reshape(-1, 3))
self.space = config.space

if self.space == "real":
self.vol = torch.rand([self.config.side_len] * 3, dtype=torch.float32)
lin_coords = torch.linspace(-1.0, 1.0, self.config.side_len)
[x, y, z] = torch.meshgrid(
[
lin_coords,
]
* 3
)
coords = torch.stack([y, x, z], dim=-1)

self.register_buffer("vol_coords", coords.reshape(-1, 3))
elif self.space == "fourier":
# Assume DC coefficient is at self.vol[n//2+1,n//2+1]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment could be moved in a docstring, so that it is more visible by users --- when we will publish the documentation website with Sphinx.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This relates a bit to my comment in issue #102, we should try to standardize somewhere the conventions for Fourier space representations. Then it doesn't really need to be anywhere... I'm also not sure that this is currently accurate.

# this means that self.vol = fftshift(fft3(fftshift(real_vol)))
self.vol = torch.rand([self.config.side_len] * 3, dtype=torch.complex64)
freq_coords = torch.fft.fftfreq(self.config.side_len, dtype=torch.float32)
[x, y] = torch.meshgrid(
[
freq_coords,
]
* 2
)
coords = torch.stack([y, x], dim=-1)
# Rescale coordinates to [-1,1] to be compatible with
# torch.nn.functional.grid_sample
coords = 2 * coords
self.register_buffer("vol_coords", coords.reshape(-1, 2))

def forward(self, rot_params, proj_axis=-1):
"""Forward method for projection.

Parameters
----------
rot_params : tensor of rotation matrices
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The description of rot_params does not strictly follow the docstring convention:

  • if it is a tensor, what is its shape
  • it should be on two lines, where the first line explains the type of datastructure and the second explains what the parameter represents.

This will become important to generate a clean documentation website.

Additionally, this description says that rot_params is a tensor; yet _forward_fourier docstring mentions a dict: which one is true?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this docstring could point to _forward_fourier docstring about the details? I find it well explained there that rot_params is a dictionary that contains a tensor of specific shape.

"""
if self.space == "real":
return self._forward_real(rot_params, proj_axis)
elif self.space == "fourier":
return self._forward_fourier(rot_params)

def _forward_fourier(self, rot_params):
"""Output the tomographic projection of the volume in Fourier space.

Take a slide through the Fourier space volume whose normal is
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: slice

oriented according to rot_params. The volume is assumed to be cube
represented in the fourier space. The output image follows
(batch x channel x height x width) convention of pytorch. Therefore,
a dummy channel dimension is added at the end to projection.

Parameters
----------
rot_params: dict of type str to {tensor}
Dictionary containing parameters for rotation, with keys
rotmat: str map to tensor
rotation matrix (batch_size x 3 x 3) to rotate the volume

Returns
-------
projection: tensor
Tensor containing tomographic projection in the Fourier domain
(batch_size x 1 x sidelen x sidelen)

Comments
--------
Note that the Fourier volumes are arbitrary
channel x height x width complex valued tensors,
they are not assumed to be Fourier transforms of a real valued 3D functions.

Note that the tomographic projection is interpolated on a rotated 2D grid.
The rotated 2D grid extends outside the boundaries of the 3D grid.
The values outside the boundaries are not defined in a useful way.
Therefore, in most applications, it make sense to apply a radial filter
to the sample.

"""
rotmat = rot_params["rotmat"]
batch_sz = rotmat.shape[0]

# print(rotmat[0])
rotmat = torch.transpose(rotmat, -1, -2)
# print(rotmat[0])
rot_vol_coords = self.vol_coords.repeat((batch_sz, 1, 1)).bmm(rotmat[:, :2, :])
# print(rot_vol_coords[0])
# print(rot_vol_coords[1])

# rescale the coordinates to be compatible with the edge alignment of
# torch.nn.functional.grid_sample
if 0 == self.config.side_len % 2: # even case
rot_vol_coords = (
(rot_vol_coords + 1)
* (self.config.side_len)
/ (self.config.side_len - 1)
) - 1
else: # odd case
rot_vol_coords = (
(rot_vol_coords) * (self.config.side_len) / (self.config.side_len - 1)
)

projection = torch.empty(
(batch_sz, self.config.side_len, self.config.side_len),
dtype=torch.complex64,
)
# interpolation is decomposed to real and imaginary parts due to torch
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be put in a docstring?

# grid_sample type rules. Requires data and coordinates of same type.
# padding_mode="reflection" is required due to possible pathologies
# right on the border.
# however, padding_mode="zeros" is what users might expect in most
# cases other than these axis aligned cases.
padding_mode = "zeros"
projection.real = torch.nn.functional.grid_sample(
self.vol.real.repeat((batch_sz, 1, 1, 1, 1)),
rot_vol_coords[:, None, None, :, :],
align_corners=True,
padding_mode=padding_mode,
).reshape(batch_sz, self.config.side_len, self.config.side_len)

projection.imag = torch.nn.functional.grid_sample(
self.vol.imag.repeat((batch_sz, 1, 1, 1, 1)),
rot_vol_coords[:, None, None, :, :],
align_corners=True,
padding_mode=padding_mode,
).reshape(batch_sz, self.config.side_len, self.config.side_len)

projection = projection[:, None, :, :]
return projection

def _forward_real(self, rot_params, proj_axis=-1):
"""Output the tomographic projection of the volume.

First rotate the volume and then sum it along an axis.
Expand Down
40 changes: 39 additions & 1 deletion tests/test_projector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test function for projector module."""

import numpy as np
import torch

from simSPI.linear_simulator.projector import Projector

Expand Down Expand Up @@ -40,6 +41,7 @@ def init_data(path):
config_dict = saved_data["config_dict"]
else:
config_dict = {}
config_dict["space"] = "real"
config = AttrDict(config_dict)
return saved_data, config

Expand All @@ -63,15 +65,51 @@ def normalized_mse(a, b):
return (a - b).pow(2).sum().sqrt() / a.pow(2).sum().sqrt()


def test_projector():
def test_projector_real():
"""Test accuracy of projector function."""
path = "tests/data/projector_data.npy"

saved_data, config = init_data(path)
config.space = "real"
rot_params = saved_data["rot_params"]
projector = Projector(config)
projector.vol = saved_data["volume"]

out = projector(rot_params)
error = normalized_mse(saved_data["projector_output"], out).item()
assert (error < 0.01) == 1


def test_projector_fourier():
"""Test accuracy of projector function.

Note: corrent test only checks that the scaling is compatible.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: current

"""
path = "tests/data/projector_data.npy"

saved_data, config = init_data(path)
config.space = "fourier"
rot_params = saved_data["rot_params"]
projector = Projector(config)
projector.vol = torch.fft.fftshift(
torch.fft.fftn(torch.fft.fftshift(saved_data["volume"], dim=[-3, -2, -1])),
dim=[-3, -2, -1],
)

sz = projector.vol.shape[0]

out = projector(rot_params)
fft_proj_out = torch.fft.fft2(
torch.fft.fftshift(saved_data["projector_output"], dim=(2, 3))
)

print(out.dtype)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should avoid prints in the tests because our logs might become overcrowded when we add more tests: could these become asserts?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, maybe these could be replaced with logging, so they could be displayed if really needed?

print("ratio", sz, (fft_proj_out.real / out.real).median())
print("ratio", sz, 1 / (fft_proj_out.real[0, 0, 0, 0] / out.real[0, 0, 0, 0]))
print("ratio", sz, 1 / (fft_proj_out.real[:, 0, 0, 0] / out.real[:, 0, 0, 0]))
assert 0.01 > (fft_proj_out.real[0, 0, 0, 0] / out.real[0, 0, 0, 0] - 1).abs()
# print(out.shape[0],np.sqrt(out.shape[0]),1.0/np.sqrt(out.shape[0]))
# error_r = normalized_mse(fft_proj_out.real, out.real).item()
# error_i = normalized_mse(fft_proj_out.imag, out.imag).item()
# assert (error_r < 0.01) == 1
# assert (error_i < 0.01) == 1