From 03d95e33d6b947b2828a83d9b95fe1ba2d188561 Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Mon, 1 Dec 2025 10:25:21 -0700 Subject: [PATCH 1/5] Accelerated `optika.materials.snells_law()` using Numba. --- optika/materials/_snells_law.py | 110 +++++++++++++++++++++++++++----- 1 file changed, 94 insertions(+), 16 deletions(-) diff --git a/optika/materials/_snells_law.py b/optika/materials/_snells_law.py index ea9a268..8ea0c0b 100644 --- a/optika/materials/_snells_law.py +++ b/optika/materials/_snells_law.py @@ -1,5 +1,6 @@ -from __future__ import annotations +import math import numpy as np +import numba as nb import astropy.units as u import named_arrays as na @@ -271,22 +272,99 @@ def snells_law( \pm \sqrt{\left( n_2 / n_1 \right)^2 + (\mathbf{k}_1 \cdot \hat{\mathbf{n}})^2 - k_1^2 } \right) \hat{\mathbf{n}} \right]} """ - a = direction - n1 = index_refraction # noqa: F841 - n2 = index_refraction_new # noqa: F841 - if normal is None: normal = na.Cartesian3dVectorArray(0, 0, -1) - a_x = a.x # noqa: F841 - a_y = a.y # noqa: F841 - a_z = a.z # noqa: F841 - u_x = normal.x # noqa: F841 - u_y = normal.y # noqa: F841 - u_z = normal.z # noqa: F841 - - return na.numexpr.evaluate( - "(n1 / n2) * (a + (-(a_x*u_x + a_y*u_y + a_z*u_z) + sign(-(a_x*u_x + a_y*u_y + a_z*u_z)) " - "* (2 * is_mirror - 1) * sqrt(1 / (n1 / n2)**2 + (a_x*u_x + a_y*u_y + a_z*u_z)**2" - "- (a_x*a_x + a_y*a_y + a_z*a_z))) * normal)" + direction = direction << u.dimensionless_unscaled + index_refraction = index_refraction << u.dimensionless_unscaled + index_refraction_new = index_refraction_new << u.dimensionless_unscaled + normal = normal << u.dimensionless_unscaled + + b_x, b_y, b_z = _snells_law_numba( + direction.x.value, + direction.y.value, + direction.z.value, + index_refraction.value, + index_refraction_new.value, + normal.x.value, + normal.y.value, + normal.z.value, + is_mirror, ) + + return na.Cartesian3dVectorArray(b_x, b_y, b_z) + + +@nb.guvectorize( + [ + "void(float64,float64,float64,float64,float64,float64,float64,float64,bool,float64[:],float64[:],float64[:])" + ], + "(),(),(),(),(),(),(),(),()->(),(),()", + target="parallel", + nopython=True, + cache=True, +) +def _snells_law_numba( + direction_x: float, + direction_y: float, + direction_z: float, + index_refraction: float, + index_refraction_new: float, + normal_x: float, + normal_y: float, + normal_z: float, + is_mirror: bool, + result_x: np.ndarray, + result_y: np.ndarray, + result_z: np.ndarray, +): + """ + A :mod:`numba`-accelerated version of Snell's law. + + Parameters + ---------- + direction_x + The :math:`x` component of the propagation direction of the incident light. + direction_y + The :math:`y` component of the propagation direction of the incident light. + direction_z + The :math:`z` component of the propagation direction of the incident light. + index_refraction + The index of refraction of the current medium. + index_refraction_new + The index of refraction of the new medium. + normal_x + The :math:`x` component of the vector perpendicular to the interface. + normal_y + The :math:`y` component of the vector perpendicular to the interface. + normal_z + The :math:`z` component of the vector perpendicular to the interface. + is_mirror + Whether the incident light is reflected or not. + """ + a_x = direction_x + a_y = direction_y + a_z = direction_z + + n1 = index_refraction + n2 = index_refraction_new + + u_x = normal_x + u_y = normal_y + u_z = normal_z + + a2 = a_x * a_x + a_y * a_y + a_z * a_z + + r = n1 / n2 + r2 = r * r + + au = a_x * u_x + a_y * u_y + a_z * u_z + au2 = au * au + + sgn = -math.copysign(1, au) + + d = -au + sgn * (2 * is_mirror - 1) * math.sqrt(1 / r2 + au2 - a2) + + result_x[:] = r * (a_x + d * u_x) + result_y[:] = r * (a_y + d * u_y) + result_z[:] = r * (a_z + d * u_z) From 4961ca7c259ca04e712180c166f282c2882a9605 Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Mon, 1 Dec 2025 16:05:35 -0700 Subject: [PATCH 2/5] remove wavelength --- optika/materials/_snells_law.py | 6 +----- optika/materials/_tests/test_snells_law.py | 9 --------- optika/surfaces.py | 1 - 3 files changed, 1 insertion(+), 15 deletions(-) diff --git a/optika/materials/_snells_law.py b/optika/materials/_snells_law.py index 8ea0c0b..8eb0cf6 100644 --- a/optika/materials/_snells_law.py +++ b/optika/materials/_snells_law.py @@ -39,11 +39,10 @@ def snells_law_scalar( def snells_law( - wavelength: u.Quantity | na.AbstractScalar, direction: na.AbstractCartesian3dVectorArray, index_refraction: float | na.AbstractScalar, index_refraction_new: float | na.AbstractScalar, - normal: None | na.AbstractCartesian3dVectorArray, + normal: None | na.AbstractCartesian3dVectorArray = None, is_mirror: bool | na.AbstractScalar = False, ) -> na.Cartesian3dVectorArray: r""" @@ -51,8 +50,6 @@ def snells_law( Parameters ---------- - wavelength - The wavelength of the incoming light direction The propagation direction of the incoming light index_refraction @@ -91,7 +88,6 @@ def snells_law( # Define the keyword arguments that are common # to both the reflected and transmitted ray kwargs = dict( - wavelength=350 * u.nm, direction=direction, index_refraction=1, normal=na.Cartesian3dVectorArray(0, 0, 1), diff --git a/optika/materials/_tests/test_snells_law.py b/optika/materials/_tests/test_snells_law.py index 1422d55..9ce7c2e 100644 --- a/optika/materials/_tests/test_snells_law.py +++ b/optika/materials/_tests/test_snells_law.py @@ -46,13 +46,6 @@ def test_snells_law_scalar( assert np.allclose(result, result_expected) -@pytest.mark.parametrize( - argnames="wavelength", - argvalues=[ - 350 * u.nm, - na.linspace(300 * u.nm, 400 * u.nm, axis="wavelength", num=3), - ], -) @pytest.mark.parametrize( argnames="direction", argvalues=[ @@ -87,7 +80,6 @@ def test_snells_law_scalar( ], ) def test_snells_law( - wavelength: u.Quantity | na.AbstractScalar, direction: na.AbstractCartesian3dVectorArray, index_refraction: float | na.AbstractScalar, index_refraction_new: float | na.AbstractScalar, @@ -95,7 +87,6 @@ def test_snells_law( is_mirror: bool | na.AbstractScalar, ): result = optika.materials.snells_law( - wavelength=wavelength, direction=direction, index_refraction=index_refraction, index_refraction_new=index_refraction_new, diff --git a/optika/surfaces.py b/optika/surfaces.py index 3112c96..ed1b9e1 100644 --- a/optika/surfaces.py +++ b/optika/surfaces.py @@ -163,7 +163,6 @@ def propagate_rays( wavelength_2 = wavelength_1 / r b = optika.materials.snells_law( - wavelength=wavelength_1, direction=a, index_refraction=n1, index_refraction_new=n2, From 5d2b832608479f24d8595adcf2ee93ff8b602f6f Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Mon, 1 Dec 2025 16:47:55 -0700 Subject: [PATCH 3/5] fix --- optika/materials/matrices.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optika/materials/matrices.py b/optika/materials/matrices.py index 40343d9..090e96c 100644 --- a/optika/materials/matrices.py +++ b/optika/materials/matrices.py @@ -320,7 +320,6 @@ def transfer( """ direction_internal = snells_law( - wavelength=wavelength, direction=direction, index_refraction=1, index_refraction_new=np.real(n), From 01959d97c171458ae77e52f6526d510bad7c0667 Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Tue, 2 Dec 2025 09:16:00 -0700 Subject: [PATCH 4/5] docs --- optika/rulings/_spacing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optika/rulings/_spacing.py b/optika/rulings/_spacing.py index ccde71e..6885602 100644 --- a/optika/rulings/_spacing.py +++ b/optika/rulings/_spacing.py @@ -191,7 +191,6 @@ class HolographicRulingSpacing( # Compute the output direction of the diffracted rays. direction_output = optika.materials.snells_law( - wavelength=wavelength, direction=direction_input, index_refraction=1, index_refraction_new=1, From 74e424e62e4180b766fd4a37655c7f369212304c Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Tue, 2 Dec 2025 09:16:46 -0700 Subject: [PATCH 5/5] coverage --- optika/materials/_snells_law.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optika/materials/_snells_law.py b/optika/materials/_snells_law.py index 8eb0cf6..d3e964c 100644 --- a/optika/materials/_snells_law.py +++ b/optika/materials/_snells_law.py @@ -313,7 +313,7 @@ def _snells_law_numba( result_x: np.ndarray, result_y: np.ndarray, result_z: np.ndarray, -): +): # pragma: nocover """ A :mod:`numba`-accelerated version of Snell's law.