From b4dd727cc3a559bdd102e5eb13b7b714355849dd Mon Sep 17 00:00:00 2001 From: Songyuan Tang Date: Sat, 14 Mar 2026 21:32:26 -0500 Subject: [PATCH 1/5] added AI-based center of rotation detection as a method in find_center --- src/tomocupy/__init__.py | 1 + src/tomocupy/__main__.py | 37 +- src/tomocupy/ai/__init__.py | 1 + src/tomocupy/ai/inference.py | 136 +++ src/tomocupy/ai/model_archs.py | 781 ++++++++++++++++++ src/tomocupy/config.py | 72 +- src/tomocupy/find_center.py | 5 + src/tomocupy/rec.py | 17 +- src/tomocupy/rec_steps.py | 11 +- .../reconstruction/backproj_parallel.py | 47 +- 10 files changed, 1092 insertions(+), 16 deletions(-) create mode 100644 src/tomocupy/ai/__init__.py create mode 100644 src/tomocupy/ai/inference.py create mode 100644 src/tomocupy/ai/model_archs.py diff --git a/src/tomocupy/__init__.py b/src/tomocupy/__init__.py index cdd9843..6e16c5c 100644 --- a/src/tomocupy/__init__.py +++ b/src/tomocupy/__init__.py @@ -58,3 +58,4 @@ from tomocupy.processing import * from tomocupy.dataio import * from tomocupy.global_vars import * +from tomocupy.ai import * diff --git a/src/tomocupy/__main__.py b/src/tomocupy/__main__.py index 96f0edb..442d59c 100644 --- a/src/tomocupy/__main__.py +++ b/src/tomocupy/__main__.py @@ -86,20 +86,34 @@ def run_rec(args, cl_reader, cl_writer): args.rotate_proj_angle = 0 args.lamino_angle = 0 # rotation axis search - if args.rotation_axis_auto == 'auto': + if (args.rotation_axis_auto == 'auto') and (args.rotation_axis_method != 'ai'): clrotthandle = FindCenter(cl_reader) args.rotation_axis = clrotthandle.find_center() params.center = args.rotation_axis log.warning(f'set rotaion axis {args.rotation_axis}') # create reconstruction object and run reconstruction - clpthandle = GPURec(cl_reader, cl_writer) + if (args.reconstruction_type == 'try') and (args.rotation_axis_auto == 'auto') and (args.rotation_axis_method == 'ai'): + cache_to_infer = True + else: + cache_to_infer = False + clpthandle = GPURec(cl_reader, cl_writer, cache_to_infer=cache_to_infer) if args.reconstruction_type == 'full': clpthandle.recon_all() if args.reconstruction_type == 'try': - clpthandle.recon_try() + if (args.rotation_axis_auto == 'auto') and (args.rotation_axis_method == 'ai'): + img_cache, center_of_rotation_cache, id_slice_cache = clpthandle.recon_try() + clrotthandle = FindCenter(cl_reader) + + args.rotation_axis = clrotthandle.find_center_ai(args, img_cache, center_of_rotation_cache, params.fnameout[:-6]) + params.center = args.rotation_axis + log.warning(f'set rotaion axis {args.rotation_axis}') + else: + clpthandle.recon_try() + + rec_time = (time.time()-t) log.warning(f'Reconstruction time {rec_time:.1e}s') @@ -112,15 +126,26 @@ def run_recsteps(args, cl_reader, cl_writer): exit() t = time.time() - if args.rotation_axis_auto == 'auto': + if (args.rotation_axis_auto == 'auto') and (args.rotation_axis_method != 'ai'): clrotthandle = FindCenter(cl_reader) args.rotation_axis = clrotthandle.find_center() params.center = args.rotation_axis log.warning(f'set rotaion axis {args.rotation_axis}') - clpthandle = GPURecSteps(cl_reader, cl_writer) + if (args.reconstruction_type == 'try') and (args.rotation_axis_auto == 'auto') and (args.rotation_axis_method == 'ai'): + cache_to_infer = True + else: + cache_to_infer = False + clpthandle = GPURecSteps(cl_reader, cl_writer,cache_to_infer=cache_to_infer) # does all preprocessing for both full and try reconstructions - clpthandle.recon_steps_all() + if (args.rotation_axis_auto == 'auto') and (args.rotation_axis_method == 'ai'): + img_cache, center_of_rotation_cache, id_slice_cache = clpthandle.recon_steps_all() + clrotthandle = FindCenter(cl_reader) + args.rotation_axis = clrotthandle.find_center_ai(args, img_cache, center_of_rotation_cache, params.fnameout[:-6]) + params.center = args.rotation_axis + log.warning(f'set rotaion axis {args.rotation_axis}') + else: + clpthandle.recon_steps_all() log.warning(f'Reconstruction time {(time.time()-t):.01f}s') diff --git a/src/tomocupy/ai/__init__.py b/src/tomocupy/ai/__init__.py new file mode 100644 index 0000000..4287ca8 --- /dev/null +++ b/src/tomocupy/ai/__init__.py @@ -0,0 +1 @@ +# \ No newline at end of file diff --git a/src/tomocupy/ai/inference.py b/src/tomocupy/ai/inference.py new file mode 100644 index 0000000..d75e5b3 --- /dev/null +++ b/src/tomocupy/ai/inference.py @@ -0,0 +1,136 @@ +import time +import torch +import numpy as np +from pathlib import Path +from PIL import Image + +from tomocupy.ai.model_archs import ClassificationModel, _make_dinov2_model + +def sample_patch_corner(mask,window_size,num_windows): + sample_patch_probs = (mask / mask.sum()).reshape((-1,1)).squeeze().astype(np.float64) + grid_indices = np.where(np.random.multinomial(1,sample_patch_probs/sample_patch_probs.sum(),num_windows))[1] + patch_corners = [] + for grid_idx in grid_indices: + grid_idx_ = [] + img_grids = np.indices(mask.shape) + for d in range(len(list(mask.shape))): + grid_idx_.append(img_grids[d].reshape((-1,1)).squeeze()[grid_idx]) + if grid_idx_[-1] == 0: + grid_idx_ = grid_idx_[:-1] + patch_corner = [grid_idx_[i]-window_size//2 for i in range(len(grid_idx_))] + patch_corner = [max(0, pc) for pc in patch_corner] + patch_corner = [min(pc, mask.shape[i] - window_size - 1) for i, pc in enumerate(patch_corner)] + patch_corner = tuple(patch_corner) + patch_corners.append(patch_corner) + + return patch_corners + +def inference_pipeline(args, img_cache_original, center_of_rotation_cache, out_dir, preprocessed=False): + + use_8bits = args.infer_use_8bits + downsample_factors = args.infer_downsample_factor + nums_windows = args.infer_num_windows + szs = args.infer_window_size + assert isinstance(downsample_factors,list) + assert isinstance(nums_windows,list) + assert isinstance(szs,list) + seed_number = args.infer_seed_number + model_path = args.infer_model_path + if len(nums_windows)>1: + multi_instances = True + elif len(nums_windows)==1 and nums_windows[0]>1: + multi_instances = True + else: + multi_instances = False + + np.random.seed(seed_number) + device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' + + model_ = _make_dinov2_model() + model = ClassificationModel(model_,embed_dim=model_.embed_dim,num_windows=nums_windows,multi_instances=multi_instances) + states = torch.load(model_path, map_location='cpu')['state_dict'] + states = {(k.replace("module.", "") if "module." in k else k): v for k, v in states.items()} + msg = model.load_state_dict(states,strict=False) + model.to(device) + + print('starting model inference...') + t_start3 = time.time() + + imgs_cache = [] + for downsample_factor in downsample_factors: + if downsample_factor > 1: + print(f"Resizing with downsample factor {downsample_factor}.") + else: + print(f"Downsample factor is {downsample_factor}. No resizing applied.") + if use_8bits: + print("Requantizing using 8 bits.") + img_cache = [] + + for img_ in img_cache_original: + if not preprocessed: + if downsample_factor>1: + + img_ = Image.fromarray(img_,mode='F') + img_array = np.array(img_.resize((img_.size[0]//downsample_factor,img_.size[1]//downsample_factor),Image.BILINEAR),dtype=np.float32) + else: + + img_array = img_.copy().astype(np.float32) + + img_array = ((img_array - img_array.min()) / (img_array.max() - img_array.min() + 1e-8)) + + if use_8bits: + + img_array = (img_array * 255).astype(np.uint8) + img_array = img_array.astype(np.float32) / 255. + else: + img_array = img_.copy().astype(np.float32) + img_cache.append(img_array[None,...]) + img_cache = np.concatenate(img_cache,axis=0) + imgs_cache.append(img_cache) + + if multi_instances: + patches_corners = [] + for img_cache,num_windows,sz in zip(imgs_cache,nums_windows,szs): + row, col = img_cache.shape[1:] + x_coords, y_coords = np.meshgrid(np.arange(col)-(col-1)/2, np.arange(row)-(row-1)/2) + mask = (x_coords**2+y_coords**2) <= ((row-1) / 2)**2 + patch_corners = sample_patch_corner(mask,sz,num_windows) + patches_corners.append(patch_corners) + else: + row, col = imgs_cache[0].shape[1:] + sz = szs[0] + patches_corners = [(row//2-sz//2, col//2-sz//2)] + + features = [] + + + for idx in range(imgs_cache[0].shape[0]): + samples = [] + for img_cache,patch_corners,sz in zip(imgs_cache,patches_corners,szs): + img_array = img_cache[idx] + if multi_instances: + imgs = [] + for patch_corner in patch_corners: + img = img_array[patch_corner[0]:patch_corner[0]+sz,patch_corner[1]:patch_corner[1]+sz] + img = torch.from_numpy(img).to(device=device,dtype=torch.float32).unsqueeze(0).unsqueeze(0).unsqueeze(0) + imgs.append(img) + sample = {'images':torch.cat(imgs,dim=1)} + else: + img = img_array[patch_corner[0]:patch_corner[0]+sz,patch_corner[1]:patch_corner[1]+sz] + img = torch.from_numpy(img).to(device=device,dtype=torch.float32).unsqueeze(0).unsqueeze(0).unsqueeze(0) + sample = {'images': img} + samples.append(sample) + with torch.no_grad(): + feature = model(samples) + features.append(feature) + t_stop3 = time.time() + print(f"done. Elapsed time is {t_stop3-t_start3} s.") + + features_all = torch.cat(features,dim=0).detach().cpu().numpy() + if args.infer_save_intermediate_data: + np.savez(Path(out_dir)/'predicts_all',features_all,center_of_rotation_cache) + scores = np.exp(features_all[:,1])/(np.exp(features_all[:,0])+np.exp(features_all[:,1])) + centers_of_rotation = [center_of_rotation_cache[i] for i in np.where(scores==max(scores))[0]] + with open(Path(out_dir)/'center_of_rotation.txt','a') as f: + for cor in centers_of_rotation: + f.write(f"{cor:.1f}\n") \ No newline at end of file diff --git a/src/tomocupy/ai/model_archs.py b/src/tomocupy/ai/model_archs.py new file mode 100644 index 0000000..22055d2 --- /dev/null +++ b/src/tomocupy/ai/model_archs.py @@ -0,0 +1,781 @@ +# dinov2 ViT +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +# attention pooling +# MIT License + +# Copyright (c) 2018 Maximilian Ilse and Jakub Tomczak + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import trunc_normal_ + +from pathlib import Path +from einops import rearrange +import numpy as np +from functools import partial +import math + +from typing import Sequence, Tuple, Union, Callable, Optional, List + + +def _make_dinov2_model(img_size:int=518,patch_size:int=14,init_values:float=1.0,ffn_layer:str='mlp',block_chunks: int = 0,\ + num_register_tokens:int= 0,interpolate_antialias:bool=False,interpolate_offset:float=0.1): + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + ) + return vit_base(**vit_kwargs) + + +def vit_base(patch_size=16, num_register_tokens=0, in_chans=3, channel_adaptive=False, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=Attention),#MemEffAttention), + num_register_tokens=num_register_tokens, + in_chans=in_chans, + channel_adaptive=channel_adaptive, + **kwargs, + ) + return model + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__() + self.inplace = inplace + self.init_values = init_values + self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.constant_(self.gamma, self.init_values) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def init_weights( + self, init_attn_std: float | None = None, init_proj_std: float | None = None, factor: float = 1.0 + ) -> None: + init_attn_std = init_attn_std or (self.dim**-0.5) + init_proj_std = init_proj_std or init_attn_std * factor + nn.init.normal_(self.qkv.weight, std=init_attn_std) + nn.init.normal_(self.proj.weight, std=init_proj_std) + if self.qkv.bias is not None: + nn.init.zeros_(self.qkv.bias) + if self.proj.bias is not None: + nn.init.zeros_(self.proj.bias) + + def forward(self, x: Tensor, is_causal: bool = False) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + q, k, v = torch.unbind(qkv, 2) + q, k, v = [t.transpose(1, 2) for t in [q, k, v]] + x = nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=self.attn_drop if self.training else 0, is_causal=is_causal + ) + x = x.transpose(1, 2).contiguous().view(B, N, C) + x = self.proj_drop(self.proj(x)) + return x + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + channel_adaptive=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + self.bag_of_channels = channel_adaptive + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = np.linspace(0, drop_path_rate, depth).tolist() # stochastic depth decay rule + #TODO: add logger and implement other ffn types + if ffn_layer == "mlp": + # logger.info("using MLP layer as FFN") + ffn_layer = Mlp + # elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + # logger.info("using SwiGLU layer as FFN") + # ffn_layer = SwiGLUFFNFused + # elif ffn_layer == "identity": + # logger.info("using Identity layer as FFN") + + # def f(*args, **kwargs): + # return nn.Identity() + + # ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + + if self.bag_of_channels: + B, C, H, W = x.shape + x = x.reshape(B * C, 1, H, W) # passing channels to batch dimension to get encodings for each channel + + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + + if self.bag_of_channels: + output = tuple(zip(outputs, class_tokens)) + output = list( + zip(*output) + ) # unzip the tuple: (list of patch_tokens per block, list of class tokens per block) + patch_tokens_per_block = output[0] # [BLOCK1, BLOCK2, ...] where BLOCK1.shape: B*C, N, D + cls_tokens_per_block = output[1] # [BLOCK1, BLOCK2, ...] where BLOCK1.shape: B*C, D + patch_tokens_per_block = [ + patch_tokens.reshape(B, C, patch_tokens.shape[-2], patch_tokens.shape[-1]) + for patch_tokens in patch_tokens_per_block + ] # [BLOCK1, BLOCK2, ...] where BLOCK1.shape: B, C, N, D + cls_tokens_per_block = [cls_tokens.reshape(B, -1) for cls_tokens in cls_tokens_per_block] + output = tuple(zip(patch_tokens_per_block, cls_tokens_per_block)) + return output + + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) +# class ClassificationModel(nn.Module): +# def __init__(self, model, num_classes=2): +# super().__init__() +# self.model = model +# self.num_classes = num_classes +# self.head = nn.Linear(model.embed_dim, num_classes) + +# def forward(self, x): +# features = self.model(x) +# return self.head(features) + +class ClassificationModel(nn.Module): + def __init__(self, model, embed_dim:int, num_windows:List[int], num_classes:int=2, multi_instances:bool=False,\ + attn_branches:int=1, attn_embed_dim:int=None): + super().__init__() + self.model = model + self.embed_dim = embed_dim + self.num_windows = num_windows + self.num_classes = num_classes + self.multi_instances = multi_instances + + if multi_instances: + if attn_embed_dim is None: + attn_embed_dim = embed_dim + assert type(attn_embed_dim) is int + + self.attention = nn.Sequential( + nn.Linear(embed_dim, attn_embed_dim), # matrix V + nn.Tanh(), + ) + self.gate = nn.Sequential( + nn.Linear(embed_dim, attn_embed_dim), # matrix U + nn.Sigmoid(), + ) + self.fc = nn.Linear(attn_embed_dim, attn_branches) + + self.attn_embed_dim = attn_embed_dim + self.attn_branches = attn_branches + + self.head = nn.Linear(embed_dim, num_classes) + + def load_weights(self,model_path, replace_pattern="module."): + if Path(model_path).suffix == '.pt': + states = torch.load(model_path, map_location='cpu')['state_dict'] + elif Path(model_path).suffix == '.pth': + states = torch.load(model_path, map_location='cpu')['model'] + states = {(k.replace(replace_pattern, "") if replace_pattern in k else k): v for k, v in states.items()} + msg = self.model.load_state_dict(states,strict=False) + print(f"missing keys: {msg.missing_keys}") + print(f"unexpected keys: {msg.unexpected_keys}") + + def forward(self, sample): + self.model.eval() + assert len(sample) == len(self.num_windows) + # if self.model is not None: + if len(self.num_windows) == 1: + + images = sample[0]['images'] + + with torch.no_grad(): + if not self.multi_instances: + assert self.num_windows[0] == 1 + features_all = self.model(images[:,0].repeat(1,3,1,1)) + else: + features_all = self.model(rearrange(images,'b k c h w -> (b k) c h w').repeat(1,3,1,1)) + features_all = rearrange(features_all,'(b k) c -> b k c', k=self.num_windows[0]) + elif len(self.num_windows) > 1: + assert self.multi_instances + features_all = [] + for idx_,sample_ in enumerate(sample): + images = sample_['images'] + with torch.no_grad(): + features_ = self.model(rearrange(images,'b k c h w -> (b k) c h w').repeat(1,3,1,1)) + features_ = rearrange(features_,'(b k) c -> b k c', k=self.num_windows[idx_]) + features_all.append(features_) + features_all = torch.cat(features_all,dim=1) + + + # else: + # features_ = sample['features_'] + + if self.multi_instances: + attn = self.fc(self.attention(features_all) * self.gate(features_all)) #features_ is b*k*c + attn = torch.transpose(attn, 2, 1) #attn is b*ATTENTION_BRANCHES*K after transposition + attn = F.softmax(attn, dim=2) # softmax over K + return torch.mean(self.head(torch.bmm(attn,features_all)),dim=1) + else: + return self.head(features_all) #features_ is b*c \ No newline at end of file diff --git a/src/tomocupy/config.py b/src/tomocupy/config.py index e8f7bcd..a942b6c 100644 --- a/src/tomocupy/config.py +++ b/src/tomocupy/config.py @@ -60,6 +60,11 @@ log = logging.getLogger(__name__) +def list_of_ints(arg): + if ',' in arg: + return [int(val) for val in (arg.split(','))] + else: + return [int(arg)] def default_parameter(func, param): """Get the default value for a function parameter. @@ -386,7 +391,7 @@ def default_parameter(func, param): 'default': 'sift', 'type': str, 'help': "Method for automatic rotation search.", - 'choices': ['sift', 'vo']}, + 'choices': ['sift', 'vo','ai']}, 'find-center-start-row': { 'type': int, 'default': 0, @@ -562,14 +567,73 @@ def default_parameter(func, param): } +SECTIONS['inference'] = { + 'infer-seed-number': { + 'default': 10, + 'type': int, + 'help': "Seed number for random number generator" + }, + 'infer-input-data-type': { + 'default': 'raw', + 'type': str, + 'help': "Center of rotation algorithm input", + 'choices': ['raw','try'] + }, + 'infer-use-8bits': { + 'default': True, + 'help': "When set requantize the pixel values with 8 bits", + 'action': 'store_true' + }, + 'infer-downsample-factor': { + 'default': [1], + 'type': list_of_ints, + 'help': "Downsample factor applied to the try reconstruction slices" + }, + 'infer-num-windows': { + 'default': [3], + 'type': list_of_ints, + 'help': "Number of windows to aggregate try recon image features" + }, + 'infer-window-size': { + 'default': [518], + 'type': list_of_ints, + 'help': "Size of each square window to aggregate try recon image features" + }, + 'infer-model-path': { + 'default': '/home/beams/TANGS/conda/tomocor_models/datav2_518_full_finetune/epoch_10.pt', + 'type': str, + 'help': "Path to the trained model weights" + }, + 'infer-input-dir': { + 'default': 'none', + 'type': str, + 'help': 'Input directory if tiff images are the direct inputs' + }, + 'infer-save-intermediate-data': { + 'default': False, + 'help': 'When set save the per-slice model predictions', + 'action': 'store_true' + }, + 'infer-batch-list': { + 'default': None, + 'type': str, + 'help': 'When supplied txt file process a list of input directories', + }, + 'infer-out-dir-name': { + 'default': None, + 'type': Path, + 'help': "Directory for output batches", + 'metavar': 'PATH' + }, +} RECON_PARAMS = ('file-reading', 'remove-stripe', - 'reconstruction', 'fw', 'ti', 'vo-all', 'lamino', 'reconstruction-types', 'beam-hardening') + 'reconstruction', 'fw', 'ti', 'vo-all', 'lamino', 'reconstruction-types', 'beam-hardening', 'inference') RECON_STEPS_PARAMS = ('file-reading', 'remove-stripe', 'reconstruction', - 'retrieve-phase', 'fw', 'ti', 'vo-all', 'lamino', 'reconstruction-steps-types', 'rotate-proj', 'beam-hardening') + 'retrieve-phase', 'fw', 'ti', 'vo-all', 'lamino', 'reconstruction-steps-types', 'rotate-proj', 'beam-hardening','inference') NICE_NAMES = ('General', 'File reading', 'Remove stripe', - 'Remove stripe FW', 'Remove stripe Titarenko', 'Remove stripe Vo', 'Retrieve phase', 'Reconstruction') + 'Remove stripe FW', 'Remove stripe Titarenko', 'Remove stripe Vo', 'Retrieve phase', 'Reconstruction','Inference') def get_config_name(): diff --git a/src/tomocupy/find_center.py b/src/tomocupy/find_center.py index 1dcce36..7810893 100644 --- a/src/tomocupy/find_center.py +++ b/src/tomocupy/find_center.py @@ -51,6 +51,8 @@ import signal import cv2 +from tomocupy.ai.inference import inference_pipeline + __author__ = "Viktor Nikitin" __copyright__ = "Copyright (c) 2022, UChicago Argonne, LLC." __docformat__ = 'restructuredtext en' @@ -82,6 +84,9 @@ def find_center(self): center = self.find_center_vo() return (center*2**args.binning).astype('float32') + def find_center_ai(self, args, img_cache, center_of_rotation_cache, out_dir): + return inference_pipeline(args, img_cache, center_of_rotation_cache, out_dir) + def find_center_sift(self): pairs = literal_eval(args.rotation_axis_pairs) diff --git a/src/tomocupy/rec.py b/src/tomocupy/rec.py index 75e2cac..608627b 100644 --- a/src/tomocupy/rec.py +++ b/src/tomocupy/rec.py @@ -66,7 +66,7 @@ class GPURec(): The implemented reconstruction method is Fourier-based with exponential functions for interpoaltion in the frequency domain (implemented with CUDA C). ''' - def __init__(self, cl_reader, cl_writer): + def __init__(self, cl_reader, cl_writer, cache_to_infer=False): # Set ^C, ^Z interrupt to abort and deallocate memory on GPU signal.signal(signal.SIGINT, utils.signal_handler) @@ -109,6 +109,7 @@ def __init__(self, cl_reader, cl_writer): # additional refs self.cl_reader = cl_reader self.cl_writer = cl_writer + self.cache_to_infer = cache_to_infer def recon_all(self): """Reconstruction of data from an h5file by splitting into sinogram chunks""" @@ -243,6 +244,10 @@ def recon_try(self): sht = cp.zeros(ncz, dtype='float32') # Conveyor for data cpu-gpu copy and reconstruction + if self.cache_to_infer: + img_cache = [] + center_of_rotation_cache = [] + id_slice_cache = [] for k in range(nschunk+2): utils.printProgressBar( k, nschunk+1, self.data_queue.qsize(), length=40) @@ -267,9 +272,19 @@ def recon_try(self): for kk in range(lschunk[k-2]): self.write_threads[ithread].run(self.cl_writer.write_data_try, ( rec_pinned[ithread, kk], params.save_centers[(k-2)*ncz+kk], id_slice)) + if self.cache_to_infer: + img_cache.append(np.copy(rec_pinned[ithread, kk:kk+1])) + center_of_rotation_cache.append(params.save_centers[(k-2)*ncz+kk]) + id_slice_cache.append(id_slice) self.stream1.synchronize() self.stream2.synchronize() for t in self.write_threads: t.join() + + if self.cache_to_infer: + img_cache = np.concatenate(img_cache,axis=0) + center_of_rotation_cache = np.array(center_of_rotation_cache) + id_slice_cache = np.array(id_slice_cache) + return img_cache, center_of_rotation_cache,id_slice_cache diff --git a/src/tomocupy/rec_steps.py b/src/tomocupy/rec_steps.py index bb8e182..d3c03d8 100644 --- a/src/tomocupy/rec_steps.py +++ b/src/tomocupy/rec_steps.py @@ -70,7 +70,7 @@ class GPURecSteps(): 2) Direct discretization of the backprojection intergral """ - def __init__(self, cl_reader, cl_writer): + def __init__(self, cl_reader, cl_writer, cache_to_infer=False): # Set ^C interrupt to abort and deallocate memory on GPU signal.signal(signal.SIGINT, utils.signal_handler) signal.signal(signal.SIGTERM, utils.signal_handler) @@ -109,7 +109,9 @@ def __init__(self, cl_reader, cl_writer): self.cl_backproj = backproj_lamfourier_parallel.BackprojLamFourierParallel( cl_writer) else: - self.cl_backproj = backproj_parallel.BackprojParallel(cl_writer) + self.cl_backproj = backproj_parallel.BackprojParallel(cl_writer,cache_to_infer=cache_to_infer) + + self.cache_to_infer = cache_to_infer def recon_steps_all(self): """GPU reconstruction by loading a full dataset in memory and processing by steps, with reading the whole data to memory """ @@ -122,7 +124,10 @@ def recon_steps_all(self): log.info('Processing by chunks in angles.') data = self.proc_proj_parallel(data) log.info('Filtered backprojection and writing by chunks.') - self.cl_backproj.rec_fun(data) + if self.cache_to_infer: + return self.cl_backproj.rec_fun(data) + else: + self.cl_backproj.rec_fun(data) def proc_sino_parallel(self, data, dark, flat): """Data processing by splitting into sinogram chunks""" diff --git a/src/tomocupy/reconstruction/backproj_parallel.py b/src/tomocupy/reconstruction/backproj_parallel.py index f94249e..e755f01 100644 --- a/src/tomocupy/reconstruction/backproj_parallel.py +++ b/src/tomocupy/reconstruction/backproj_parallel.py @@ -55,7 +55,7 @@ class BackprojParallel(): - def __init__(self, cl_writer): + def __init__(self, cl_writer, cache_to_infer = False): # init tomo functions self.cl_backproj_func = backproj_functions.BackprojFunctions() @@ -91,6 +91,8 @@ def __init__(self, cl_writer): self.rec_fun = rec_fun self.cl_writer = cl_writer + self.cache_to_infer = cache_to_infer + def recon_sino_proj_parallel(self, data): """Reconstruction by splitting into sinogram and projectionchunks""" @@ -194,6 +196,10 @@ def recon_try_sino_proj_parallel(self, data): rec_gpu = cp.zeros([2, *self.shape_recon_chunk], dtype=params.dtype) # Conveyor for data cpu-gpu copy and reconstruction + if self.cache_to_infer: + img_cache = [] + center_of_rotation_cache = [] + id_slice_cache = [] for id_slice in params.id_slices: log.info(f'Processing slice {id_slice}') for ks in range(nschunk+2): @@ -234,11 +240,20 @@ def recon_try_sino_proj_parallel(self, data): for kk in range(lschunk[ks-2]): self.write_threads[ithread].run(self.cl_writer.write_data_try, ( rec_pinned[ithread, kk], params.save_centers[(ks-2)*ncz+kk], id_slice)) - + if self.cache_to_infer: + img_cache.append(np.copy(rec_pinned[ithread, kk:kk+1])) + center_of_rotation_cache.append(params.save_centers[(ks-2)*ncz+kk]) + id_slice_cache.append(id_slice) self.stream1.synchronize() self.stream2.synchronize() for t in self.write_threads: t.join() + + if self.cache_to_infer: + img_cache = np.concatenate(img_cache,axis=0) + center_of_rotation_cache = np.array(center_of_rotation_cache) + id_slice_cache = np.array(id_slice_cache) + return img_cache, center_of_rotation_cache,id_slice_cache def recon_try_lamino_sino_proj_parallel(self, data): """Reconstruction of 1 slice with different lamino angles by splitting data into sinogram and projection chunks""" @@ -265,6 +280,10 @@ def recon_try_lamino_sino_proj_parallel(self, data): # gpu memory for reconstrution rec_gpu = cp.zeros([2, *self.shape_recon_chunk], dtype=params.dtype) + if self.cache_to_infer: + img_cache = [] + center_of_rotation_cache = [] + id_slice_cache = [] for id_slice in params.id_slices: log.info(f'Processing slice {id_slice}') # Conveyor for data cpu-gpu copy and reconstruction @@ -307,10 +326,20 @@ def recon_try_lamino_sino_proj_parallel(self, data): for kk in range(lschunk[ks-2]): self.write_threads[ithread].run(self.cl_writer.write_data_try, ( rec_pinned[ithread, kk], params.save_centers[(ks-2)*ncz+kk], id_slice)) + if self.cache_to_infer: + img_cache.append(np.copy(rec_pinned[ithread, kk:kk+1])) + center_of_rotation_cache.append(params.save_centers[(ks-2)*ncz+kk]) + id_slice_cache.append(id_slice) self.stream1.synchronize() self.stream2.synchronize() for t in self.write_threads: t.join() + + if self.cache_to_infer: + img_cache = np.concatenate(img_cache,axis=0) + center_of_rotation_cache = np.array(center_of_rotation_cache) + id_slice_cache = np.array(id_slice_cache) + return img_cache, center_of_rotation_cache,id_slice_cache def recon_sino_parallel(self, data): """Reconstruction by splitting into sinogram chunks""" @@ -391,6 +420,10 @@ def recon_try_sino_parallel(self, data): rec_gpu = cp.zeros([2, *self.shape_recon_chunk], dtype=dtype) # Conveyor for data cpu-gpu copy and reconstruction + if self.cache_to_infer: + img_cache = [] + center_of_rotation_cache = [] + id_slice_cache = [] for k in range(nschunk+2): utils.printProgressBar( k, nschunk+1, nschunk-k+1, length=40) @@ -414,9 +447,19 @@ def recon_try_sino_parallel(self, data): for kk in range(lschunk[k-2]): self.write_threads[ithread].run(self.cl_writer.write_data_try, ( rec_pinned[ithread, kk], params.save_centers[(k-2)*ncz+kk], id_slice)) + if self.cache_to_infer: + img_cache.append(np.copy(rec_pinned[ithread, kk:kk+1])) + center_of_rotation_cache.append(params.save_centers[(k-2)*ncz+kk]) + id_slice_cache.append(id_slice) self.stream1.synchronize() self.stream2.synchronize() for t in self.write_threads: t.join() + + if self.cache_to_infer: + img_cache = np.concatenate(img_cache,axis=0) + center_of_rotation_cache = np.array(center_of_rotation_cache) + id_slice_cache = np.array(id_slice_cache) + return img_cache, center_of_rotation_cache,id_slice_cache From c5bd5dbe6628d611e28419e7b6f3a490433efa0f Mon Sep 17 00:00:00 2001 From: Songyuan Tang Date: Sat, 14 Mar 2026 22:23:02 -0500 Subject: [PATCH 2/5] conda environment listed --- tomocupy_cor.yml | 269 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 269 insertions(+) create mode 100644 tomocupy_cor.yml diff --git a/tomocupy_cor.yml b/tomocupy_cor.yml new file mode 100644 index 0000000..c2fa9b3 --- /dev/null +++ b/tomocupy_cor.yml @@ -0,0 +1,269 @@ +name: tomocupy +channels: + - conda-forge + - defaults +dependencies: + - _openmp_mutex=4.5=20_gnu + - alsa-lib=1.2.15.3=hb03c661_0 + - aom=3.9.1=hac33072_0 + - attr=2.5.2=h39aace5_0 + - blosc=1.21.6=he440d0b_1 + - brunsli=0.1=he3183e4_1 + - bzip2=1.0.8=hda65f42_9 + - c-ares=1.34.6=hb03c661_0 + - c-blosc2=2.19.1=h4cfbee9_0 + - ca-certificates=2026.2.25=hbd8a1cb_0 + - cached-property=1.5.2=hd8ed1ab_1 + - cached_property=1.5.2=pyha770c72_1 + - cairo=1.18.4=h3394656_0 + - charls=2.4.3=hecca717_0 + - cmake=4.0.2=h74e3db0_0 + - cuda-cudart=12.9.79=h5888daf_0 + - cuda-cudart_linux-64=12.9.79=h3f2d84a_0 + - cuda-nvrtc=12.9.86=hecca717_1 + - cuda-nvtx=12.9.79=hecca717_1 + - cuda-version=12.9=h4f385c5_3 + - cupy=12.0.0=py310he66c036_4 + - cyrus-sasl=2.1.28=hd9c7081_0 + - dav1d=1.2.1=hd590300_0 + - dbus=1.16.2=h24cb091_1 + - distro=1.9.0=pyhd8ed1ab_1 + - double-conversion=3.3.1=h5888daf_0 + - fastrlock=0.8.3=py310h25320af_2 + - ffmpeg=7.1.1=gpl_h127656b_906 + - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 + - font-ttf-inconsolata=3.000=h77eed37_0 + - font-ttf-source-code-pro=2.038=h77eed37_0 + - font-ttf-ubuntu=0.83=h77eed37_3 + - fontconfig=2.17.1=h27c8c51_0 + - fonts-conda-ecosystem=1=0 + - fonts-conda-forge=1=hc364b38_1 + - freeglut=3.2.2=ha6d2627_3 + - freetype=2.14.2=ha770c72_0 + - fribidi=1.0.16=hb03c661_0 + - gdk-pixbuf=2.44.4=h2b0a6b4_0 + - giflib=5.2.2=hd590300_0 + - gmp=6.3.0=hac33072_2 + - graphite2=1.3.14=hecca717_2 + - h5py=3.14.0=nompi_py310h4aa865e_101 + - harfbuzz=12.2.0=h15599e2_0 + - hdf5=1.14.6=nompi_h19486de_106 + - icu=75.1=he02047a_0 + - imagecodecs=2025.3.30=py310h4eb8eaf_2 + - imath=3.1.12=h7955e40_0 + - jasper=4.2.9=he3c4edf_0 + - jxrlib=1.1=hd590300_3 + - keyutils=1.6.3=hb9d3cd8_0 + - krb5=1.21.3=h659f571_0 + - lame=3.100=h166bdaf_1003 + - lcms2=2.18=h0c24ade_0 + - ld_impl_linux-64=2.45.1=default_hbd61a6d_101 + - lerc=4.1.0=hdb68285_0 + - level-zero=1.28.2=hb700be7_0 + - libabseil=20250127.1=cxx17_hbbce691_0 + - libaec=1.1.5=h088129d_0 + - libasprintf=0.25.1=h3f43e3d_1 + - libass=0.17.3=hba53ac1_1 + - libavif16=1.3.0=h766b0b6_0 + - libblas=3.11.0=5_h4a7cf45_openblas + - libbrotlicommon=1.1.0=hb03c661_4 + - libbrotlidec=1.1.0=hb03c661_4 + - libbrotlienc=1.1.0=hb03c661_4 + - libcap=2.77=h3ff7636_0 + - libcblas=3.11.0=5_h0358290_openblas + - libclang-cpp21.1=21.1.8=default_h99862b1_3 + - libclang13=21.1.8=default_h746c552_3 + - libcublas=12.9.1.4=h676940d_1 + - libcufft=11.4.1.4=hecca717_1 + - libcups=2.3.3=hb8b1518_5 + - libcurand=10.3.10.19=h676940d_1 + - libcurl=8.18.0=h4e3cde8_0 + - libcusolver=11.7.5.82=h676940d_2 + - libcusparse=12.5.10.65=hecca717_2 + - libdeflate=1.24=h86f0d12_0 + - libdrm=2.4.125=hb03c661_1 + - libedit=3.1.20250104=pl5321h7949ede_0 + - libegl=1.7.0=ha4b6fd6_2 + - libev=4.33=hd590300_2 + - libexpat=2.7.4=hecca717_0 + - libffi=3.5.2=h3435931_0 + - libflac=1.5.0=he200343_1 + - libfreetype=2.14.2=ha770c72_0 + - libfreetype6=2.14.2=h73754d4_0 + - libgcc=15.2.0=he0feb66_18 + - libgcc-ng=15.2.0=h69a702a_18 + - libgettextpo=0.25.1=h3f43e3d_1 + - libgfortran=15.2.0=h69a702a_18 + - libgfortran5=15.2.0=h68bc16d_18 + - libgl=1.7.0=ha4b6fd6_2 + - libglib=2.86.2=h32235b2_0 + - libglu=9.0.3=h5888daf_1 + - libglvnd=1.7.0=ha4b6fd6_2 + - libglx=1.7.0=ha4b6fd6_2 + - libgomp=15.2.0=he0feb66_18 + - libhwloc=2.12.1=default_h3d81e11_1000 + - libhwy=1.3.0=h4c17acf_1 + - libiconv=1.18=h3b78370_2 + - libjpeg-turbo=3.1.2=hb03c661_0 + - libjxl=0.11.1=h6cb5226_4 + - liblapack=3.11.0=5_h47877c9_openblas + - liblapacke=3.11.0=5_h6ae95b6_openblas + - libllvm21=21.1.8=h5ad376a_0 + - liblzma=5.8.2=hb03c661_0 + - libnghttp2=1.67.0=had1ee68_0 + - libnsl=2.0.1=hb9d3cd8_1 + - libntlm=1.8=hb9d3cd8_0 + - libnvjitlink=12.9.86=hecca717_2 + - libogg=1.3.5=hd0c01bc_1 + - libopenblas=0.3.30=pthreads_h94d23a6_4 + - libopencv=4.11.0=qt6_py310h13b287b_609 + - libopengl=1.7.0=ha4b6fd6_2 + - libopenvino=2025.0.0=hdc3f47d_3 + - libopenvino-auto-batch-plugin=2025.0.0=h4d9b6c2_3 + - libopenvino-auto-plugin=2025.0.0=h4d9b6c2_3 + - libopenvino-hetero-plugin=2025.0.0=h981d57b_3 + - libopenvino-intel-cpu-plugin=2025.0.0=hdc3f47d_3 + - libopenvino-intel-gpu-plugin=2025.0.0=hdc3f47d_3 + - libopenvino-intel-npu-plugin=2025.0.0=hdc3f47d_3 + - libopenvino-ir-frontend=2025.0.0=h981d57b_3 + - libopenvino-onnx-frontend=2025.0.0=h0e684df_3 + - libopenvino-paddle-frontend=2025.0.0=h0e684df_3 + - libopenvino-pytorch-frontend=2025.0.0=h5888daf_3 + - libopenvino-tensorflow-frontend=2025.0.0=h684f15b_3 + - libopenvino-tensorflow-lite-frontend=2025.0.0=h5888daf_3 + - libopus=1.6.1=h280c20c_0 + - libpciaccess=0.18=hb9d3cd8_0 + - libpng=1.6.55=h421ea60_0 + - libpq=17.7=h5c52fec_1 + - libprotobuf=5.29.3=h7460b1f_3 + - librsvg=2.58.4=h49af25d_2 + - libsndfile=1.2.2=hc7d488a_2 + - libsqlite=3.52.0=h0c1763c_0 + - libssh2=1.11.1=hcf80075_0 + - libstdcxx=15.2.0=h934c35e_18 + - libstdcxx-ng=15.2.0=hdf11a46_18 + - libsystemd0=259.5=h6569c3e_0 + - libtiff=4.7.1=h8261f1e_0 + - libudev1=259.5=h6569c3e_0 + - libunwind=1.8.3=h65a8314_0 + - liburing=2.12=hb700be7_0 + - libusb=1.0.29=h73b1eb8_0 + - libuuid=2.41.3=h5347b49_0 + - libuv=1.51.0=hb03c661_1 + - libva=2.23.0=he1eb515_0 + - libvorbis=1.3.7=h54a6638_2 + - libvpx=1.14.1=hac33072_0 + - libvulkan-loader=1.4.341.0=h5279c79_0 + - libwebp-base=1.6.0=hd42ef1d_0 + - libxcb=1.17.0=h8a09558_0 + - libxcrypt=4.4.36=hd590300_1 + - libxkbcommon=1.11.0=he8b52b9_0 + - libxml2=2.13.9=h04c0eec_0 + - libzlib=1.3.1=hb9d3cd8_2 + - libzopfli=1.0.3=h9c3ff4c_0 + - lz4-c=1.10.0=h5888daf_1 + - mpg123=1.32.9=hc50e24c_0 + - ncurses=6.5=h2d0b736_3 + - nomkl=1.0=h5ca1d4c_0 + - numexpr=2.10.2=py310hdb6e06b_100 + - numpy=1.26.4=py310hb13e2d6_0 + - ocl-icd=2.3.3=hb9d3cd8_0 + - opencl-headers=2025.06.13=h5888daf_0 + - opencv=4.11.0=qt6_py310h630078d_609 + - openexr=3.3.5=h09fa569_0 + - openh264=2.6.0=hc22cd8d_0 + - openjpeg=2.5.4=h55fea9a_0 + - openldap=2.6.10=he970967_0 + - openssl=3.6.1=h35e630c_1 + - packaging=26.0=pyhcf101f3_0 + - pango=1.56.4=hadf4263_0 + - pcre2=10.46=h1321c63_0 + - pip=26.0.1=pyh8b19718_0 + - pixman=0.46.4=h54a6638_1 + - pthread-stubs=0.4=hb9d3cd8_1002 + - pugixml=1.15=h3f63f65_0 + - pulseaudio-client=17.0=h9a6aba3_3 + - py-opencv=4.11.0=qt6_py310h77b9700_609 + - python=3.10.20=h3c07f61_0_cpython + - python_abi=3.10=8_cp310 + - pywavelets=1.8.0=py310hf462985_0 + - qt6-main=6.9.2=h5bd77bc_1 + - rav1e=0.7.1=h8fae777_3 + - readline=8.3=h853b02a_0 + - rhash=1.4.6=hb9d3cd8_1 + - scikit-build=0.18.1=pyhae55e72_2 + - sdl2=2.32.56=h54a6638_0 + - sdl3=3.2.24=h68140b3_0 + - setuptools=82.0.1=pyh332efcf_0 + - snappy=1.2.2=h03e3b7b_1 + - svt-av1=3.0.2=h5888daf_0 + - swig=4.3.1=hf1419ba_4 + - tbb=2022.3.0=h8d10470_1 + - tifffile=2025.5.10=pyhd8ed1ab_0 + - tk=8.6.13=noxft_h366c992_103 + - tomli=2.4.0=pyhcf101f3_0 + - typing-extensions=4.15.0=h396c80c_0 + - typing_extensions=4.15.0=pyhcf101f3_0 + - tzdata=2025c=hc9c84f9_1 + - wayland=1.24.0=hd6090a7_1 + - wayland-protocols=1.47=hd8ed1ab_0 + - wheel=0.46.3=pyhd8ed1ab_0 + - x264=1!164.3095=h166bdaf_2 + - x265=3.5=h924138e_3 + - xcb-util=0.4.1=h4f16b4b_2 + - xcb-util-cursor=0.1.6=hb03c661_0 + - xcb-util-image=0.4.0=hb711507_2 + - xcb-util-keysyms=0.4.1=hb711507_0 + - xcb-util-renderutil=0.3.10=hb711507_0 + - xcb-util-wm=0.4.2=hb711507_0 + - xkeyboard-config=2.47=hb03c661_0 + - xorg-libice=1.1.2=hb9d3cd8_0 + - xorg-libsm=1.2.6=he73a12e_0 + - xorg-libx11=1.8.13=he1eb515_0 + - xorg-libxau=1.0.12=hb03c661_1 + - xorg-libxcomposite=0.4.7=hb03c661_0 + - xorg-libxcursor=1.2.3=hb9d3cd8_0 + - xorg-libxdamage=1.1.6=hb9d3cd8_0 + - xorg-libxdmcp=1.1.5=hb03c661_1 + - xorg-libxext=1.3.7=hb03c661_0 + - xorg-libxfixes=6.0.2=hb03c661_0 + - xorg-libxi=1.8.2=hb9d3cd8_0 + - xorg-libxrandr=1.5.5=hb03c661_0 + - xorg-libxrender=0.9.12=hb9d3cd8_0 + - xorg-libxscrnsaver=1.2.4=hb9d3cd8_0 + - xorg-libxtst=1.2.5=hb9d3cd8_3 + - xorg-libxxf86vm=1.1.7=hb03c661_0 + - zfp=1.0.1=h909a3a2_5 + - zlib=1.3.1=hb9d3cd8_2 + - zlib-ng=2.2.5=hde8ca8f_1 + - zstd=1.5.7=hb78ec9c_6 + - pip: + - einops==0.8.2 + - filelock==3.20.0 + - fsspec==2025.12.0 + - jinja2==3.1.6 + - markupsafe==3.0.2 + - mpmath==1.3.0 + - networkx==3.4.2 + - nvidia-cublas-cu12==12.8.4.1 + - nvidia-cuda-cupti-cu12==12.8.90 + - nvidia-cuda-nvrtc-cu12==12.8.93 + - nvidia-cuda-runtime-cu12==12.8.90 + - nvidia-cudnn-cu12==9.10.2.21 + - nvidia-cufft-cu12==11.3.3.83 + - nvidia-cufile-cu12==1.13.1.3 + - nvidia-curand-cu12==10.3.9.90 + - nvidia-cusolver-cu12==11.7.3.90 + - nvidia-cusparse-cu12==12.5.8.93 + - nvidia-cusparselt-cu12==0.7.1 + - nvidia-nccl-cu12==2.27.3 + - nvidia-nvjitlink-cu12==12.8.93 + - nvidia-nvtx-cu12==12.8.90 + - pillow==12.0.0 + - sympy==1.14.0 + - tomocupy==1.1.0 + - torch==2.8.0+cu128 + - torchaudio==2.8.0+cu128 + - torchvision==0.23.0+cu128 + - triton==3.4.0 +prefix: /home/beams/TANGS/conda/anaconda3/envs/tomocupy From 245f438b6a36c4bc6bbc3310054e32f7fbf63e04 Mon Sep 17 00:00:00 2001 From: Songyuan Tang Date: Sun, 15 Mar 2026 00:46:45 -0500 Subject: [PATCH 3/5] updated config --- src/tomocupy/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tomocupy/config.py b/src/tomocupy/config.py index a942b6c..a718e29 100644 --- a/src/tomocupy/config.py +++ b/src/tomocupy/config.py @@ -600,7 +600,7 @@ def default_parameter(func, param): 'help': "Size of each square window to aggregate try recon image features" }, 'infer-model-path': { - 'default': '/home/beams/TANGS/conda/tomocor_models/datav2_518_full_finetune/epoch_10.pt', + 'default': 'none', 'type': str, 'help': "Path to the trained model weights" }, From af5541d2aef3d98fc7cf390c236d21b34c0ce9e6 Mon Sep 17 00:00:00 2001 From: Viktor Nikitin Date: Thu, 19 Mar 2026 16:20:05 -0500 Subject: [PATCH 4/5] refactor AI center detection; add install docs and model-path guard - Extract _find_center, _find_center_ai, _check_use_ai helpers in __main__.py - _check_use_ai gracefully falls back to vo method if torch is not installed - Add ValueError in inference.py when --infer-model-path is not set (was silently passing 'none' to torch.load) - Lazy-import inference_pipeline in find_center.py to avoid hard torch dependency - Add AI dependencies section to docs/source/install.rst --- docs/source/install.rst | 12 +++++ src/tomocupy/__main__.py | 100 ++++++++++++++++++----------------- src/tomocupy/ai/inference.py | 2 + src/tomocupy/find_center.py | 2 +- 4 files changed, 66 insertions(+), 50 deletions(-) diff --git a/docs/source/install.rst b/docs/source/install.rst index 2e21eb8..6c87a0a 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -113,6 +113,18 @@ Additional instructions for Windows .. note:: It is better to have only one version of VS and one version of CUDA toolkit on your system to avoid problems with environmental variables +================================ +AI-based Center of Rotation +================================ + +To use the AI-based center of rotation detection (``--rotation-axis-method ai``), install the following additional dependencies:: + + (tomocupy)$ conda install -c conda-forge pytorch pillow einops + +Then run reconstruction with:: + + (tomocupy)$ tomocupy recon --file-name --rotation-axis-method ai --rotation-axis-auto auto --infer-model-path + ========== Unit tests ========== diff --git a/src/tomocupy/__main__.py b/src/tomocupy/__main__.py index 442d59c..04731c4 100644 --- a/src/tomocupy/__main__.py +++ b/src/tomocupy/__main__.py @@ -41,7 +41,6 @@ import sys import time import argparse -import time import os from pathlib import Path from datetime import datetime @@ -74,80 +73,83 @@ def run_status(args): config.log_values(args) +def _find_center(cl_reader): + clrotthandle = FindCenter(cl_reader) + args.rotation_axis = clrotthandle.find_center() + params.center = args.rotation_axis + params.centeri = args.rotation_axis + log.warning(f'set rotation axis {args.rotation_axis}') + + +def _find_center_ai(cl_reader, img_cache, center_of_rotation_cache): + clrotthandle = FindCenter(cl_reader) + args.rotation_axis = clrotthandle.find_center_ai(args, img_cache, center_of_rotation_cache, params.fnameout[:-6]) + params.center = args.rotation_axis + log.warning(f'set rotation axis {args.rotation_axis}') + + +def _check_use_ai(): + if args.rotation_axis_auto != 'auto' or args.rotation_axis_method != 'ai': + return False + try: + import torch + return True + except ImportError: + log.warning('torch is not installed — skipping AI center search, falling back to vo method') + args.rotation_axis_method = 'vo' + return False + + def run_rec(args, cl_reader, cl_writer): - file_name = Path(args.file_name) - if not file_name.is_file(): + if not Path(args.file_name).is_file(): log.error("File Name does not exist: %s" % args.file_name) exit() t = time.time() - # set the default parameters args.retrieve_phase_method = 'none' args.rotate_proj_angle = 0 args.lamino_angle = 0 - # rotation axis search - if (args.rotation_axis_auto == 'auto') and (args.rotation_axis_method != 'ai'): - clrotthandle = FindCenter(cl_reader) - args.rotation_axis = clrotthandle.find_center() - params.center = args.rotation_axis - log.warning(f'set rotaion axis {args.rotation_axis}') - - # create reconstruction object and run reconstruction - if (args.reconstruction_type == 'try') and (args.rotation_axis_auto == 'auto') and (args.rotation_axis_method == 'ai'): - cache_to_infer = True - else: - cache_to_infer = False + + use_ai = _check_use_ai() + if args.rotation_axis_auto == 'auto' and not use_ai: + _find_center(cl_reader) + + cache_to_infer = args.reconstruction_type == 'try' and use_ai clpthandle = GPURec(cl_reader, cl_writer, cache_to_infer=cache_to_infer) if args.reconstruction_type == 'full': - clpthandle.recon_all() - if args.reconstruction_type == 'try': - if (args.rotation_axis_auto == 'auto') and (args.rotation_axis_method == 'ai'): - img_cache, center_of_rotation_cache, id_slice_cache = clpthandle.recon_try() - clrotthandle = FindCenter(cl_reader) - - args.rotation_axis = clrotthandle.find_center_ai(args, img_cache, center_of_rotation_cache, params.fnameout[:-6]) - params.center = args.rotation_axis - log.warning(f'set rotaion axis {args.rotation_axis}') + elif args.reconstruction_type == 'try': + if use_ai: + img_cache, center_of_rotation_cache, _ = clpthandle.recon_try() + _find_center_ai(cl_reader, img_cache, center_of_rotation_cache) else: clpthandle.recon_try() - - rec_time = (time.time()-t) - - log.warning(f'Reconstruction time {rec_time:.1e}s') + log.warning(f'Reconstruction time {time.time()-t:.1e}s') def run_recsteps(args, cl_reader, cl_writer): - file_name = Path(args.file_name) - if not file_name.is_file(): + if not Path(args.file_name).is_file(): log.error("File Name does not exist: %s" % args.file_name) exit() + t = time.time() - if (args.rotation_axis_auto == 'auto') and (args.rotation_axis_method != 'ai'): - clrotthandle = FindCenter(cl_reader) - args.rotation_axis = clrotthandle.find_center() - params.center = args.rotation_axis - log.warning(f'set rotaion axis {args.rotation_axis}') + use_ai = _check_use_ai() + if args.rotation_axis_auto == 'auto' and not use_ai: + _find_center(cl_reader) - if (args.reconstruction_type == 'try') and (args.rotation_axis_auto == 'auto') and (args.rotation_axis_method == 'ai'): - cache_to_infer = True - else: - cache_to_infer = False - clpthandle = GPURecSteps(cl_reader, cl_writer,cache_to_infer=cache_to_infer) - # does all preprocessing for both full and try reconstructions - if (args.rotation_axis_auto == 'auto') and (args.rotation_axis_method == 'ai'): - img_cache, center_of_rotation_cache, id_slice_cache = clpthandle.recon_steps_all() - clrotthandle = FindCenter(cl_reader) - args.rotation_axis = clrotthandle.find_center_ai(args, img_cache, center_of_rotation_cache, params.fnameout[:-6]) - params.center = args.rotation_axis - log.warning(f'set rotaion axis {args.rotation_axis}') + cache_to_infer = use_ai + clpthandle = GPURecSteps(cl_reader, cl_writer, cache_to_infer=cache_to_infer) + + if use_ai: + img_cache, center_of_rotation_cache, _ = clpthandle.recon_steps_all() + _find_center_ai(cl_reader, img_cache, center_of_rotation_cache) else: clpthandle.recon_steps_all() - log.warning(f'Reconstruction time {(time.time()-t):.01f}s') + log.warning(f'Reconstruction time {time.time()-t:.1f}s') def main(): diff --git a/src/tomocupy/ai/inference.py b/src/tomocupy/ai/inference.py index d75e5b3..c7eb4f9 100644 --- a/src/tomocupy/ai/inference.py +++ b/src/tomocupy/ai/inference.py @@ -36,6 +36,8 @@ def inference_pipeline(args, img_cache_original, center_of_rotation_cache, out_d assert isinstance(szs,list) seed_number = args.infer_seed_number model_path = args.infer_model_path + if model_path == 'none': + raise ValueError("--infer-model-path must be set when using --rotation-axis-method ai") if len(nums_windows)>1: multi_instances = True elif len(nums_windows)==1 and nums_windows[0]>1: diff --git a/src/tomocupy/find_center.py b/src/tomocupy/find_center.py index 7810893..656977f 100644 --- a/src/tomocupy/find_center.py +++ b/src/tomocupy/find_center.py @@ -51,7 +51,6 @@ import signal import cv2 -from tomocupy.ai.inference import inference_pipeline __author__ = "Viktor Nikitin" __copyright__ = "Copyright (c) 2022, UChicago Argonne, LLC." @@ -85,6 +84,7 @@ def find_center(self): return (center*2**args.binning).astype('float32') def find_center_ai(self, args, img_cache, center_of_rotation_cache, out_dir): + from tomocupy.ai.inference import inference_pipeline return inference_pipeline(args, img_cache, center_of_rotation_cache, out_dir) def find_center_sift(self): From afe2ae87e84636506b5cc9779ba9c1fe90e0bb4e Mon Sep 17 00:00:00 2001 From: Songyuan Tang Date: Fri, 20 Mar 2026 17:06:57 -0500 Subject: [PATCH 5/5] included a pointer to the model file when --infer-model-path is not set --- src/tomocupy/ai/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tomocupy/ai/inference.py b/src/tomocupy/ai/inference.py index c7eb4f9..d831489 100644 --- a/src/tomocupy/ai/inference.py +++ b/src/tomocupy/ai/inference.py @@ -37,7 +37,7 @@ def inference_pipeline(args, img_cache_original, center_of_rotation_cache, out_d seed_number = args.infer_seed_number model_path = args.infer_model_path if model_path == 'none': - raise ValueError("--infer-model-path must be set when using --rotation-axis-method ai") + raise ValueError("--infer-model-path must be set when using --rotation-axis-method ai\n The model can be downloaded from: https://anl.box.com/s/4o8qcig6pl9k8p7x4z3qqbrpgnjipolq.") if len(nums_windows)>1: multi_instances = True elif len(nums_windows)==1 and nums_windows[0]>1: