Skip to content

Commit 47474ce

Browse files
committed
follow-up #141: tweak ConvergentSource interface and add tests
1 parent 586d45a commit 47474ce

2 files changed

Lines changed: 133 additions & 7 deletions

File tree

pytissueoptics/rayscattering/source.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -341,18 +341,22 @@ class ConvergentSource(DirectionalSource):
341341
def __init__(
342342
self,
343343
position: Vector,
344-
focal_point: Vector,
344+
direction: Vector,
345345
diameter: float,
346+
focalLength: float,
346347
N: int,
347348
useHardwareAcceleration: bool = True,
348349
displaySize: float = 0.1,
349350
seed: Optional[int] = None,
350351
):
351-
self._focal_point = focal_point
352+
if focalLength <= 0:
353+
raise ValueError("The focal length of a convergent source must be positive.")
354+
355+
self._focalLength = focalLength
352356

353357
super().__init__(
354358
position=position,
355-
direction=focal_point - position,
359+
direction=direction,
356360
diameter=diameter,
357361
N=N,
358362
useHardwareAcceleration=useHardwareAcceleration,
@@ -361,11 +365,12 @@ def __init__(
361365
)
362366

363367
def getInitialPositionsAndDirections(self) -> Tuple[np.ndarray, np.ndarray]:
364-
positions = self._getUniformlySampledDisc(self._diameter) + self._position.array
365-
directions = self._focal_point.array - positions
368+
positions = self._getInitialPositions()
369+
focalPoint = self._position + self._direction * self._focalLength
370+
directions = focalPoint.array - positions
366371
directions /= np.linalg.norm(directions, axis=1, keepdims=True)
367372
return positions, directions
368373

369374
@property
370375
def _hashComponents(self) -> tuple:
371-
return self._position, self._direction, self._diameter, self._focal_point
376+
return self._position, self._direction, self._diameter, self._focalLength

pytissueoptics/rayscattering/tests/testSource.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88
from pytissueoptics.rayscattering import EnergyLogger, PencilPointSource, Photon
99
from pytissueoptics.rayscattering.materials import ScatteringMaterial
1010
from pytissueoptics.rayscattering.scatteringScene import ScatteringScene
11-
from pytissueoptics.rayscattering.source import DirectionalSource, DivergentSource, IsotropicPointSource, Source
11+
from pytissueoptics.rayscattering.source import (
12+
DirectionalSource,
13+
DivergentSource,
14+
IsotropicPointSource,
15+
Source,
16+
ConvergentSource,
17+
)
1218
from pytissueoptics.scene.geometry import Environment, Vector
1319
from pytissueoptics.scene.logger import Logger
1420
from pytissueoptics.scene.solids import Solid
@@ -252,3 +258,118 @@ def testGivenTwoDivergentSourcesThatDifferInDivergence_shouldNotHaveSameHash(sel
252258
position=Vector(), direction=sourceDirection, diameter=1, divergence=divergence2, N=1
253259
)
254260
self.assertNotEqual(hash(divergentSource1), hash(divergentSource2))
261+
262+
263+
class TestConvergentSource(unittest.TestCase):
264+
def testGivenNegativeOrZeroFocalLength_shouldRaiseValueError(self):
265+
with self.assertRaises(ValueError):
266+
ConvergentSource(
267+
position=Vector(),
268+
direction=Vector(0, 0, 1),
269+
focalLength=0,
270+
diameter=1,
271+
N=1,
272+
useHardwareAcceleration=False,
273+
)
274+
with self.assertRaises(ValueError):
275+
ConvergentSource(
276+
position=Vector(),
277+
direction=Vector(0, 0, 1),
278+
focalLength=-1,
279+
diameter=1,
280+
N=1,
281+
useHardwareAcceleration=False,
282+
)
283+
284+
def testShouldHavePhotonsPointingTowardTheFocalPoint(self):
285+
np.random.seed(0)
286+
position = Vector(0, 0, 0)
287+
direction = Vector(0, 0, 1)
288+
focalLength = 5.0
289+
diameter = 2.0
290+
source = ConvergentSource(
291+
position=position,
292+
direction=direction,
293+
focalLength=focalLength,
294+
diameter=diameter,
295+
N=10,
296+
useHardwareAcceleration=False,
297+
)
298+
299+
focalPoint = position + direction * focalLength
300+
for photon in source.photons:
301+
expectedDirection = focalPoint - photon.position
302+
expectedDirection.normalize()
303+
self.assertEqual(expectedDirection, photon.direction)
304+
305+
def testGivenInfiniteFocalLength_shouldHavePhotonsAllPointingInTheSourceDirection(self):
306+
sourceDirection = Vector(1, 0, 0)
307+
source = ConvergentSource(
308+
position=Vector(),
309+
direction=sourceDirection,
310+
focalLength=1e10,
311+
diameter=1.0,
312+
N=10,
313+
useHardwareAcceleration=False,
314+
)
315+
for photon in source.photons:
316+
self.assertEqual(sourceDirection, photon.direction)
317+
318+
def testShouldHavePhotonsUniformlyPositionedInsideTheSourceDiameter(self):
319+
np.random.seed(0)
320+
sourcePosition = Vector(3, 3, 0)
321+
sourceDiameter = 2.0
322+
source = ConvergentSource(
323+
position=sourcePosition,
324+
direction=Vector(0, 1, 0),
325+
focalLength=5.0,
326+
diameter=sourceDiameter,
327+
N=10,
328+
useHardwareAcceleration=False,
329+
)
330+
for photon in source.photons:
331+
self.assertTrue(np.isclose(photon.position.y, sourcePosition.y))
332+
self.assertTrue(
333+
sourcePosition.x - sourceDiameter / 2 <= photon.position.x <= sourcePosition.x + sourceDiameter / 2
334+
)
335+
self.assertTrue(
336+
sourcePosition.z - sourceDiameter / 2 <= photon.position.z <= sourcePosition.z + sourceDiameter / 2
337+
)
338+
339+
def testGivenTwoConvergentSourcesWithSamePropertiesExceptPhotonCount_shouldHaveSameHash(self):
340+
source1 = ConvergentSource(
341+
position=Vector(),
342+
direction=Vector(0, 0, 1),
343+
focalLength=5.0,
344+
diameter=1.0,
345+
N=1,
346+
useHardwareAcceleration=False,
347+
)
348+
source2 = ConvergentSource(
349+
position=Vector(),
350+
direction=Vector(0, 0, 1),
351+
focalLength=5.0,
352+
diameter=1.0,
353+
N=2,
354+
useHardwareAcceleration=False,
355+
)
356+
self.assertEqual(hash(source1), hash(source2))
357+
358+
def testGivenTwoConvergentSourcesThatDifferInFocalLength_shouldNotHaveSameHash(self):
359+
source1 = ConvergentSource(
360+
position=Vector(),
361+
direction=Vector(0, 0, 1),
362+
focalLength=5.0,
363+
diameter=1.0,
364+
N=1,
365+
useHardwareAcceleration=False,
366+
)
367+
source2 = ConvergentSource(
368+
position=Vector(),
369+
direction=Vector(0, 0, 1),
370+
focalLength=10.0,
371+
diameter=1.0,
372+
N=1,
373+
useHardwareAcceleration=False,
374+
)
375+
self.assertNotEqual(hash(source1), hash(source2))

0 commit comments

Comments
 (0)