Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 95 additions & 21 deletions optika/materials/_snells_law.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
import math
import numpy as np
import numba as nb
import astropy.units as u
import named_arrays as na

Expand Down Expand Up @@ -38,20 +39,17 @@ def snells_law_scalar(


def snells_law(
wavelength: u.Quantity | na.AbstractScalar,
direction: na.AbstractCartesian3dVectorArray,
index_refraction: float | na.AbstractScalar,
index_refraction_new: float | na.AbstractScalar,
normal: None | na.AbstractCartesian3dVectorArray,
normal: None | na.AbstractCartesian3dVectorArray = None,
is_mirror: bool | na.AbstractScalar = False,
) -> na.Cartesian3dVectorArray:
r"""
A `vector form of Snell's law <https://en.wikipedia.org/wiki/Snell%27s_law#Vector_form>`_.

Parameters
----------
wavelength
The wavelength of the incoming light
direction
The propagation direction of the incoming light
index_refraction
Expand Down Expand Up @@ -90,7 +88,6 @@ def snells_law(
# Define the keyword arguments that are common
# to both the reflected and transmitted ray
kwargs = dict(
wavelength=350 * u.nm,
direction=direction,
index_refraction=1,
normal=na.Cartesian3dVectorArray(0, 0, 1),
Expand Down Expand Up @@ -271,22 +268,99 @@ def snells_law(
\pm \sqrt{\left( n_2 / n_1 \right)^2 + (\mathbf{k}_1 \cdot \hat{\mathbf{n}})^2 - k_1^2 }
\right) \hat{\mathbf{n}} \right]}
"""
a = direction
n1 = index_refraction # noqa: F841
n2 = index_refraction_new # noqa: F841

if normal is None:
normal = na.Cartesian3dVectorArray(0, 0, -1)

a_x = a.x # noqa: F841
a_y = a.y # noqa: F841
a_z = a.z # noqa: F841
u_x = normal.x # noqa: F841
u_y = normal.y # noqa: F841
u_z = normal.z # noqa: F841

return na.numexpr.evaluate(
"(n1 / n2) * (a + (-(a_x*u_x + a_y*u_y + a_z*u_z) + sign(-(a_x*u_x + a_y*u_y + a_z*u_z)) "
"* (2 * is_mirror - 1) * sqrt(1 / (n1 / n2)**2 + (a_x*u_x + a_y*u_y + a_z*u_z)**2"
"- (a_x*a_x + a_y*a_y + a_z*a_z))) * normal)"
direction = direction << u.dimensionless_unscaled
index_refraction = index_refraction << u.dimensionless_unscaled
index_refraction_new = index_refraction_new << u.dimensionless_unscaled
normal = normal << u.dimensionless_unscaled

b_x, b_y, b_z = _snells_law_numba(
direction.x.value,
direction.y.value,
direction.z.value,
index_refraction.value,
index_refraction_new.value,
normal.x.value,
normal.y.value,
normal.z.value,
is_mirror,
)

return na.Cartesian3dVectorArray(b_x, b_y, b_z)


@nb.guvectorize(
[
"void(float64,float64,float64,float64,float64,float64,float64,float64,bool,float64[:],float64[:],float64[:])"
],
"(),(),(),(),(),(),(),(),()->(),(),()",
target="parallel",
nopython=True,
cache=True,
)
def _snells_law_numba(
direction_x: float,
direction_y: float,
direction_z: float,
index_refraction: float,
index_refraction_new: float,
normal_x: float,
normal_y: float,
normal_z: float,
is_mirror: bool,
result_x: np.ndarray,
result_y: np.ndarray,
result_z: np.ndarray,
): # pragma: nocover
"""
A :mod:`numba`-accelerated version of Snell's law.

Parameters
----------
direction_x
The :math:`x` component of the propagation direction of the incident light.
direction_y
The :math:`y` component of the propagation direction of the incident light.
direction_z
The :math:`z` component of the propagation direction of the incident light.
index_refraction
The index of refraction of the current medium.
index_refraction_new
The index of refraction of the new medium.
normal_x
The :math:`x` component of the vector perpendicular to the interface.
normal_y
The :math:`y` component of the vector perpendicular to the interface.
normal_z
The :math:`z` component of the vector perpendicular to the interface.
is_mirror
Whether the incident light is reflected or not.
"""
a_x = direction_x
a_y = direction_y
a_z = direction_z

n1 = index_refraction
n2 = index_refraction_new

u_x = normal_x
u_y = normal_y
u_z = normal_z

a2 = a_x * a_x + a_y * a_y + a_z * a_z

r = n1 / n2
r2 = r * r

au = a_x * u_x + a_y * u_y + a_z * u_z
au2 = au * au

sgn = -math.copysign(1, au)

d = -au + sgn * (2 * is_mirror - 1) * math.sqrt(1 / r2 + au2 - a2)

result_x[:] = r * (a_x + d * u_x)
result_y[:] = r * (a_y + d * u_y)
result_z[:] = r * (a_z + d * u_z)
9 changes: 0 additions & 9 deletions optika/materials/_tests/test_snells_law.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,6 @@ def test_snells_law_scalar(
assert np.allclose(result, result_expected)


@pytest.mark.parametrize(
argnames="wavelength",
argvalues=[
350 * u.nm,
na.linspace(300 * u.nm, 400 * u.nm, axis="wavelength", num=3),
],
)
@pytest.mark.parametrize(
argnames="direction",
argvalues=[
Expand Down Expand Up @@ -87,15 +80,13 @@ def test_snells_law_scalar(
],
)
def test_snells_law(
wavelength: u.Quantity | na.AbstractScalar,
direction: na.AbstractCartesian3dVectorArray,
index_refraction: float | na.AbstractScalar,
index_refraction_new: float | na.AbstractScalar,
normal: None | na.AbstractCartesian3dVectorArray,
is_mirror: bool | na.AbstractScalar,
):
result = optika.materials.snells_law(
wavelength=wavelength,
direction=direction,
index_refraction=index_refraction,
index_refraction_new=index_refraction_new,
Expand Down
1 change: 0 additions & 1 deletion optika/materials/matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,6 @@ def transfer(
"""

direction_internal = snells_law(
wavelength=wavelength,
direction=direction,
index_refraction=1,
index_refraction_new=np.real(n),
Expand Down
1 change: 0 additions & 1 deletion optika/rulings/_spacing.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ class HolographicRulingSpacing(

# Compute the output direction of the diffracted rays.
direction_output = optika.materials.snells_law(
wavelength=wavelength,
direction=direction_input,
index_refraction=1,
index_refraction_new=1,
Expand Down
1 change: 0 additions & 1 deletion optika/surfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ def propagate_rays(
wavelength_2 = wavelength_1 / r

b = optika.materials.snells_law(
wavelength=wavelength_1,
direction=a,
index_refraction=n1,
index_refraction_new=n2,
Expand Down