[SPH][Doc] add a pyvista integration example #1474
base: main
Conversation
Thanks @tdavidcl for opening this PR! You can do multiple things directly here: once the workflow completes, a message will appear displaying information related to the run. The PR also gets automatically reviewed by Gemini.
Summary of Changes

Hello @tdavidcl, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a new Sphinx documentation example that integrates PyVista.
Workflow report

Workflow report corresponding to commit 34e38f2. Light CI is enabled. This will only run the basic tests and not the full tests.

Pre-commit check report

Some failures were detected in base source checks. ❌ black

Suggested changes (detailed diff):

diff --git a/doc/sphinx/examples/sph/run_circular_disc_lense_thirringpyvista2.py b/doc/sphinx/examples/sph/run_circular_disc_lense_thirringpyvista2.py
index c456a4f9..56bc3a39 100644
--- a/doc/sphinx/examples/sph/run_circular_disc_lense_thirringpyvista2.py
+++ b/doc/sphinx/examples/sph/run_circular_disc_lense_thirringpyvista2.py
@@ -100,7 +100,7 @@ if shamrock.sys.world_rank() == 0:
# Sink parameters
center_racc = rin / 2.0 # [au]
-inclination = np.pi / 4
+inclination = np.pi / 4
# Viscosity parameter
@@ -417,7 +417,7 @@ pv.set_plot_theme("dark")
p, mesh_actor, text_actor = None, None, None
-base_cam_pos = [rout* 1.7, 0, rout*0.4]
+base_cam_pos = [rout * 1.7, 0, rout * 0.4]
rotate_video = True
rotation_vector = [0, 0, 1] # axis around which you rotate your POV
@@ -443,7 +443,7 @@ def analysis(ianalysis):
dic_sham = ctx.collect_data()
- #print(dic_sham["part_id"])
+ # print(dic_sham["part_id"])
global p, mesh_actor, text_actor, base_cam_pos, orig_r
if p is None:
@@ -454,7 +454,7 @@ def analysis(ianalysis):
point_cloud = pv.PolyData(dic_sham["xyz"])
r = np.linalg.norm(dic_sham["xyz"], axis=1)
- orig_r = [ 0 for i in range(len(r))]
+ orig_r = [0 for i in range(len(r))]
for i in range(len(r)):
orig_r[dic_sham["part_id"][i]] = r[i]
@@ -466,20 +466,20 @@ def analysis(ianalysis):
mesh_actor = p.add_mesh(
point_cloud,
cmap="magma_r",
- #opacity="geom",
- #clim=(-7.209004326372496, -6.96752862264146),
+ # opacity="geom",
+ # clim=(-7.209004326372496, -6.96752862264146),
render_points_as_spheres=True,
point_size=15.0,
)
p.show_bounds(
- bounds=[-rout, rout, -rout, rout, -rout*0.2, rout*0.2],
- grid='back',
- location='outer',
- ticks='both',
+ bounds=[-rout, rout, -rout, rout, -rout * 0.2, rout * 0.2],
+ grid="back",
+ location="outer",
+ ticks="both",
n_xlabels=2,
n_ylabels=2,
- n_zlabels=2
+ n_zlabels=2,
)
p.show(auto_close=False, interactive_update=True)
@@ -488,15 +488,15 @@ def analysis(ianalysis):
R = rotation_matrix(rotation_vector, -2 * np.pi * t / 45000.0)
cam_pos = R @ base_cam_pos
- #print(cam_pos, base_cam_pos)
- #p.camera_position = [cam_pos, (0, 0, 0), (0, 0, 1)] #'iso'
+ # print(cam_pos, base_cam_pos)
+ # p.camera_position = [cam_pos, (0, 0, 0), (0, 0, 1)] #'iso'
# update *existing* mesh, don't recreate the plotter
new_cloud = pv.PolyData(dic_sham["xyz"])
new_cloud["original r [code unit]"] = orig_r[dic_sham["part_id"]]
tmp = np.log10(pmass * (1.2 / dic_sham["hpart"]) ** (1 / 3))
- #print(tmp.min(), tmp.max())
+ # print(tmp.min(), tmp.max())
# overwrite the actor's mesh in place
mesh_actor.mapper.SetInputData(new_cloud)
@@ -512,7 +512,7 @@ def analysis(ianalysis):
p.render() # use render(), not update()
p.update() # optional — processes UI events
- #print(p.camera_position)
+ # print(p.camera_position)
p.show(auto_close=False, interactive_update=True)
# p.show(
Code Review
This pull request adds a new example script demonstrating a black hole disc simulation with PyVista integration for on-the-fly visualization. The script is a great addition, but it has several issues that need to be addressed. The most critical issue is that the visualization part is not compatible with MPI, which will cause problems in parallel runs. There are also several medium-severity issues related to code style, maintainability, and correctness, such as inconsistent path handling, use of magic numbers, and dead code. I've provided specific comments and suggestions to resolve these issues.
```python
def analysis(ianalysis):
    dic_sham = ctx.collect_data()
    #print(dic_sham["part_id"])

    global p, mesh_actor, text_actor, base_cam_pos, orig_r
    if p is None:
        p = pv.Plotter()
        p.camera_position = (base_cam_pos[0], base_cam_pos[1], base_cam_pos[2])  #'iso'

        point_cloud = pv.PolyData(dic_sham["xyz"])

        r = np.linalg.norm(dic_sham["xyz"], axis=1)
        orig_r = [ 0 for i in range(len(r))]
        for i in range(len(r)):
            orig_r[dic_sham["part_id"][i]] = r[i]
        orig_r = np.array(orig_r)

        point_cloud["original r [code unit]"] = orig_r[dic_sham["part_id"]]

        mesh_actor = p.add_mesh(
            point_cloud,
            cmap="magma_r",
            #opacity="geom",
            #clim=(-7.209004326372496, -6.96752862264146),
            render_points_as_spheres=True,
            point_size=15.0,
        )

        p.show_bounds(
            bounds=[-rout, rout, -rout, rout, -rout*0.2, rout*0.2],
            grid='back',
            location='outer',
            ticks='both',
            n_xlabels=2,
            n_ylabels=2,
            n_zlabels=2
        )

        p.show(auto_close=False, interactive_update=True)

    t = model.get_time() + 3000

    R = rotation_matrix(rotation_vector, -2 * np.pi * t / 45000.0)
    cam_pos = R @ base_cam_pos
    #print(cam_pos, base_cam_pos)
    #p.camera_position = [cam_pos, (0, 0, 0), (0, 0, 1)] #'iso'

    # update *existing* mesh, don't recreate the plotter
    new_cloud = pv.PolyData(dic_sham["xyz"])
    new_cloud["original r [code unit]"] = orig_r[dic_sham["part_id"]]

    tmp = np.log10(pmass * (1.2 / dic_sham["hpart"]) ** (1 / 3))
    #print(tmp.min(), tmp.max())

    # overwrite the actor's mesh in place
    mesh_actor.mapper.SetInputData(new_cloud)

    if text_actor is not None:
        text_actor.SetVisibility(0)

    text_actor = p.add_text(
        "t = {:.03f} [code unit] dt = {:.03f} [code unit]".format(model.get_time(), model.get_dt())
    )
    gc.collect()
    # update rendering
    p.render()  # use render(), not update()
    p.update()  # optional — processes UI events

    #print(p.camera_position)

    p.show(auto_close=False, interactive_update=True)
    # p.show(
    #     screenshot=plot_folder + "pyvista_{:05d}.png".format(ianalysis),
    #     window_size=[1920, 1080],
    #     auto_close=False,
    #     interactive_update=True,
    # )
```
The `analysis` function has several critical issues that will cause it to fail or behave incorrectly when run in parallel with MPI:

- PyVista on all ranks: The PyVista plotting logic is not guarded by a rank check. This will cause each MPI rank to create a `pv.Plotter()` instance, leading to multiple (and likely undesired) plot windows opening.
- Incorrect `orig_r` calculation: The `orig_r` array is initialized with a size based on the local number of particles (`len(r)`), but it is indexed using global `part_id` values. In an MPI run, this will almost certainly cause an `IndexError` because a global `part_id` can be larger than the local particle count. The `orig_r` array should be sized with the total number of particles (`Npart`) and populated by gathering data from all ranks.
- Inefficient array creation: The creation of `orig_r` using a list comprehension and a `for` loop is inefficient for a large number of particles. Using `numpy` vectorized operations would be significantly faster.
- Misleading `orig_r` computation: The variable is named `orig_r`, suggesting it stores the initial radius. However, it is computed during the first call to `analysis`, not at the beginning of the simulation. If this is the intended behavior, the variable name could be more descriptive. If it is meant to be the initial radius, it should be computed once, right after particle generation.

The entire `analysis` function should be refactored to be MPI-aware, likely by gathering all necessary data to rank 0 and performing the visualization there.
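As a concrete illustration of the vectorized, globally sized construction suggested above, here is a minimal sketch. The array names mirror the script, but the small input data is made up for the example, and the MPI gather step is omitted:

```python
import numpy as np

# Hypothetical stand-ins for one rank's slice of dic_sham.
Npart = 8                                  # total (global) particle count
xyz = np.array([[1.0, 0.0, 0.0],
                [0.0, 2.0, 0.0],
                [0.0, 0.0, 3.0]])          # local particle positions
part_id = np.array([5, 0, 7])              # global ids of the local particles

r = np.linalg.norm(xyz, axis=1)            # local radii: [1.0, 2.0, 3.0]

# Vectorized scatter: size the array by the *global* count and index by
# global ids. This replaces the Python-level for loop and avoids the
# IndexError that local sizing would cause when a global id exceeds len(r).
orig_r = np.zeros(Npart)
orig_r[part_id] = r
```

Entries for particles held on other ranks stay at zero until their data is gathered in.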
```python
sim_folder = f"_to_trash/black_hole_disc_lense_thirring_{Npart}/"

dump_folder = sim_folder + "dump/"
analysis_folder = sim_folder + "analysis/"
plot_folder = analysis_folder + "plots/"

dump_prefix = dump_folder + "dump_"
```
The script constructs file paths using string concatenation with `/`. This is not portable across different operating systems (e.g., Windows uses `\`). It also leads to double slashes in paths (e.g., `sim_folder` ends with `/` and is concatenated with `"dump/"`), which, while often tolerated by filesystems, is unclean.

It's better to use `os.path.join()` for all path constructions to ensure portability and correctness. The script already uses `os.path.join()` in the `save_analysis_data` function, so for consistency it should be used everywhere.

This change should be applied to all manual path concatenations in the file, for example in `save_rho_integ` and `save_vxyz_integ` as well.
Suggested change:

```diff
-sim_folder = f"_to_trash/black_hole_disc_lense_thirring_{Npart}/"
-dump_folder = sim_folder + "dump/"
-analysis_folder = sim_folder + "analysis/"
-plot_folder = analysis_folder + "plots/"
-dump_prefix = dump_folder + "dump_"
+sim_folder = f"_to_trash/black_hole_disc_lense_thirring_{Npart}"
+dump_folder = os.path.join(sim_folder, "dump")
+analysis_folder = os.path.join(sim_folder, "analysis")
+plot_folder = os.path.join(analysis_folder, "plots")
+dump_prefix = os.path.join(dump_folder, "dump_")
```
```python
def get_last_dump():
    res = glob.glob(dump_prefix + "*.sham")

    num_max = -1

    for f in res:
        try:
            dump_num = int(f[len(dump_prefix) : -5])
            if dump_num > num_max:
                num_max = dump_num
        except ValueError:
            pass

    if num_max == -1:
        return None
    else:
        return num_max
```
The implementation of `get_last_dump` uses slicing with a magic number (`f[len(dump_prefix) : -5]`) to extract the dump number. This is fragile and depends on the length of the file extension (`.sham`). If the extension changes, this code will break.

For Python 3.9+, you can use `str.removeprefix()` and `str.removesuffix()` for a more robust and readable implementation. If you need to support older Python versions, combining `os.path.basename()` and `os.path.splitext()` would be a more robust alternative to magic-number slicing.
Suggested change:

```diff
 def get_last_dump():
     res = glob.glob(dump_prefix + "*.sham")
     num_max = -1
     for f in res:
         try:
-            dump_num = int(f[len(dump_prefix) : -5])
+            dump_num_str = f.removeprefix(dump_prefix).removesuffix(".sham")
+            dump_num = int(dump_num_str)
             if dump_num > num_max:
                 num_max = dump_num
         except ValueError:
             pass
     if num_max == -1:
         return None
     else:
         return num_max
```
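For interpreters older than 3.9, where `str.removeprefix()`/`str.removesuffix()` are unavailable, the `os.path.basename()`/`os.path.splitext()` alternative mentioned above could look like this sketch. The helper name `dump_number` is made up for the example:

```python
import os

def dump_number(path, prefix="dump_"):
    """Extract the integer dump index from a path like '.../dump_00042.sham'.

    Returns None when the file name does not match the expected pattern,
    without relying on magic-number slicing of the full path.
    """
    name, ext = os.path.splitext(os.path.basename(path))  # e.g. ('dump_00042', '.sham')
    if ext != ".sham" or not name.startswith(prefix):
        return None
    try:
        return int(name[len(prefix):])
    except ValueError:
        return None
```

This keeps the extension handling in one place, so changing `.sham` only requires touching the comparison, not a slice offset.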
```python
t = model.get_time() + 3000

R = rotation_matrix(rotation_vector, -2 * np.pi * t / 45000.0)
```
The numbers 3000 and 45000.0 are magic numbers, which make the code harder to read and maintain. They should be defined as named constants at the top of the script with comments explaining their purpose. This would clarify their meaning in the context of the camera rotation calculation.
Suggested change:

```diff
+# It's a good practice to define these as constants at the top of the script, e.g.:
+# CAMERA_TIME_OFFSET = 3000.0
+# CAMERA_ROTATION_PERIOD = 45000.0
 t = model.get_time() + 3000
 R = rotation_matrix(rotation_vector, -2 * np.pi * t / 45000.0)
```
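To make the suggestion concrete, here is a small self-contained sketch of the camera orbit with named constants. The constant names, the sample time `t`, and the Rodrigues-formula `rotation_matrix` helper are illustrative assumptions, not taken from the script under review:

```python
import numpy as np

CAMERA_TIME_OFFSET = 3000.0       # [code unit] phase offset of the camera orbit (assumed)
CAMERA_ROTATION_PERIOD = 45000.0  # [code unit] time for one full revolution (assumed)

def rotation_matrix(axis, theta):
    """Rodrigues formula: rotation by angle theta around the vector `axis`."""
    axis = np.asarray(axis, dtype=float)
    axis /= np.linalg.norm(axis)
    ux, uy, uz = axis
    K = np.array([[0.0, -uz, uy], [uz, 0.0, -ux], [-uy, ux, 0.0]])  # cross-product matrix
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)

t = 1500.0  # stand-in for model.get_time()
R = rotation_matrix([0, 0, 1], -2 * np.pi * (t + CAMERA_TIME_OFFSET) / CAMERA_ROTATION_PERIOD)
base_cam_pos = np.array([17.0, 0.0, 4.0])
cam_pos = R @ base_cam_pos  # rotated camera position; distance to origin is preserved
```

Naming the period makes it obvious that the camera completes one revolution every `CAMERA_ROTATION_PERIOD` code units of simulation time.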
```python
new_cloud = pv.PolyData(dic_sham["xyz"])
new_cloud["original r [code unit]"] = orig_r[dic_sham["part_id"]]

tmp = np.log10(pmass * (1.2 / dic_sham["hpart"]) ** (1 / 3))
```
```python
text_actor = p.add_text(
    "t = {:.03f} [code unit] dt = {:.03f} [code unit]".format(model.get_time(), model.get_dt())
)
gc.collect()
```
Explicitly calling gc.collect() is generally not recommended as it can negatively impact performance and may hide underlying memory management issues. The Python garbage collector is usually efficient enough on its own. If this call is necessary to work around a specific memory leak in a library (e.g., PyVista in a loop), it should be accompanied by a comment explaining why it's needed. Otherwise, it should be removed.
```python
analysis(iplot)
iplot += 1
ttarget += 0.75
```
The value 0.75 is a magic number. It appears to be the time interval for the analysis steps. It should be defined as a named constant at the top of the script (e.g., ANALYSIS_DT = 0.75) to improve readability and make it easier to change.
Suggested change:

```diff
-ttarget += 0.75
+ttarget += 0.75  # TODO: Use a named constant, e.g., ANALYSIS_DT
```
No description provided.