From 73e596632e269ba0872b7019501e60c60a1f7a90 Mon Sep 17 00:00:00 2001 From: Krzysztof Maziarz Date: Tue, 2 Jun 2026 08:01:56 +0000 Subject: [PATCH 1/8] feat(opennmt): Use inference_mode instead of no_grad --- retrochimera/opennmt/decode/translator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/retrochimera/opennmt/decode/translator.py b/retrochimera/opennmt/decode/translator.py index 3d55a66..9ec992c 100644 --- a/retrochimera/opennmt/decode/translator.py +++ b/retrochimera/opennmt/decode/translator.py @@ -13,6 +13,7 @@ 1. Simplified the structure by removing the `build_translator` function, and merged the `Inference` class into the `Translator` class. 2. Refactored the `_run_encoder`, `_decode_and_generate`, and `_translate_batch_with_strategy` methods to align with the `retrochimera.models.smiles_transformer.SmilesTransformerModel` class. 3. Introduced the `customised_beam_search` attribute and corresponding logic to the `Translator` class, enabling optimized beam search for retrosynthesis prediction. +4. Switched `torch.no_grad()` to `torch.inference_mode()` for better performance. 
""" from typing import Any, Optional @@ -114,7 +115,7 @@ def translate_batch(self, batch: dict[str, Any], attn_debug: bool = False) -> di - src_lengths (torch.Tensor): shape of src_lengths: (batch_size,) - batch_size: int """ - with torch.no_grad(): + with torch.inference_mode(): # TODO: support these blacklisted features decode_strategy = BeamSearch( pad=self._tgt_pad_idx, From 215a1f826f544643c8fad2bf9695bd4480d9e171 Mon Sep 17 00:00:00 2001 From: Krzysztof Maziarz Date: Tue, 2 Jun 2026 08:09:29 +0000 Subject: [PATCH 2/8] feat(opennmt): Use torch.where for masked selection --- retrochimera/opennmt/decode/translator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/retrochimera/opennmt/decode/translator.py b/retrochimera/opennmt/decode/translator.py index 9ec992c..3ad33cb 100644 --- a/retrochimera/opennmt/decode/translator.py +++ b/retrochimera/opennmt/decode/translator.py @@ -219,8 +219,8 @@ def _translate_batch_with_strategy( decoder_input.squeeze() == self._tgt_eos_idx ) # shape: (batch_size * beam_size,) log_prob_mask = is_end_token.unsqueeze(1) # shape: (batch_size * beam_size, 1) - log_probs = (log_prob_mask * complete_seq_log_prob) + ( - ~log_prob_mask * log_probs + log_probs = torch.where( + log_prob_mask, complete_seq_log_prob, log_probs ) # shape: (batch_size * beam_size, vocab_size) assert log_probs.dim() == 2, f"Expected 2D tensor, got {log_probs.dim()}" From 4c89d73f31316f4ad1c3e5823de81aa7fbd03006 Mon Sep 17 00:00:00 2001 From: Krzysztof Maziarz Date: Tue, 2 Jun 2026 08:18:55 +0000 Subject: [PATCH 3/8] feat(opennmt): Compute complete_seq_log_prob once outside of the loop --- retrochimera/opennmt/decode/translator.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/retrochimera/opennmt/decode/translator.py b/retrochimera/opennmt/decode/translator.py index 3ad33cb..3ac864f 100644 --- a/retrochimera/opennmt/decode/translator.py +++ b/retrochimera/opennmt/decode/translator.py @@ -14,6 +14,7 @@ 2. 
Refactored the `_run_encoder`, `_decode_and_generate`, and `_translate_batch_with_strategy` methods to align with the `retrochimera.models.smiles_transformer.SmilesTransformerModel` class. 3. Introduced the `customised_beam_search` attribute and corresponding logic to the `Translator` class, enabling optimized beam search for retrosynthesis prediction. 4. Switched `torch.no_grad()` to `torch.inference_mode()` for better performance. +5. Optimized code for more efficient tensor operations and memory usage. """ from typing import Any, Optional @@ -187,6 +188,14 @@ def _translate_batch_with_strategy( if fn_map_state is not None: self.model.decoder.map_state(fn_map_state, only_map_src=True) + complete_seq_log_prob = None + if self.customised_beam_search: + vocab_size = self._tgt_vocab_len + complete_seq_log_prob = torch.full( + (1, vocab_size), -1e5, device=memory_bank.device, dtype=torch.float32 + ) + complete_seq_log_prob[:, self._tgt_eos_idx] = 0.0 + # (3) Begin decoding step by step: for step in range(decode_strategy.max_length): decoder_input = decode_strategy.current_predictions.view( @@ -206,14 +215,6 @@ def _translate_batch_with_strategy( if self.customised_beam_search: # modify the log_probs for finished sentences. - _, vocab_size = tuple(log_probs.shape) - bad_token_log_prob = -1e5 - - complete_seq_log_prob = (torch.ones((1, vocab_size)) * bad_token_log_prob).to( - log_probs.device - ) # shape: (1, vocab_size) - complete_seq_log_prob[:, self._tgt_eos_idx] = 0.0 # shape: (1, vocab_size) - # Use this vector in the output for sequences which are complete. 
is_end_token = ( decoder_input.squeeze() == self._tgt_eos_idx From b62f4c091998b87ebe16e0109682b6a01210b241 Mon Sep 17 00:00:00 2001 From: Krzysztof Maziarz Date: Tue, 2 Jun 2026 08:39:49 +0000 Subject: [PATCH 4/8] feat(opennmt): Pre-transpose memory bank and precompute the mask --- retrochimera/opennmt/decode/translator.py | 36 +++++++++++++---------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/retrochimera/opennmt/decode/translator.py b/retrochimera/opennmt/decode/translator.py index 3ac864f..c7bd656 100644 --- a/retrochimera/opennmt/decode/translator.py +++ b/retrochimera/opennmt/decode/translator.py @@ -188,11 +188,18 @@ def _translate_batch_with_strategy( if fn_map_state is not None: self.model.decoder.map_state(fn_map_state, only_map_src=True) + memory_bank_bt = memory_bank.transpose(0, 1).contiguous() + src_pad_len = memory_bank_bt.size(1) + memory_padding_mask = ( + torch.arange(0, src_pad_len, device=memory_bank_bt.device) + >= memory_lengths.unsqueeze(1) + ) + complete_seq_log_prob = None if self.customised_beam_search: vocab_size = self._tgt_vocab_len complete_seq_log_prob = torch.full( - (1, vocab_size), -1e5, device=memory_bank.device, dtype=torch.float32 + (1, vocab_size), -1e5, device=memory_bank_bt.device, dtype=torch.float32 ) complete_seq_log_prob[:, self._tgt_eos_idx] = 0.0 @@ -205,9 +212,10 @@ def _translate_batch_with_strategy( log_probs, attn = self._decode_and_generate( decoder_input, - memory_bank, + memory_bank_bt, batch, memory_lengths=memory_lengths, + memory_padding_mask=memory_padding_mask, src_map=src_map, step=step, batch_offset=decode_strategy.batch_offset, # type: ignore @@ -243,12 +251,15 @@ def _translate_batch_with_strategy( if any_finished: # Reorder states. 
- if isinstance(memory_bank, tuple): - memory_bank = tuple(x.index_select(1, select_indices) for x in memory_bank) + if isinstance(memory_bank_bt, tuple): + memory_bank_bt = tuple( + x.index_select(0, select_indices) for x in memory_bank_bt + ) else: - memory_bank = memory_bank.index_select(1, select_indices) + memory_bank_bt = memory_bank_bt.index_select(0, select_indices) memory_lengths = memory_lengths.index_select(0, select_indices) + memory_padding_mask = memory_padding_mask.index_select(0, select_indices) if src_map is not None: src_map = src_map.index_select(1, select_indices) @@ -304,9 +315,10 @@ def _run_encoder(self, batch: dict[str, Any]) -> tuple: def _decode_and_generate( self, decoder_in: torch.Tensor, - memory_bank: torch.Tensor, + memory_bank_bt: torch.Tensor, batch: dict[str, Any], memory_lengths: torch.Tensor, + memory_padding_mask: torch.Tensor, src_map=None, step: Optional[int] = None, batch_offset: torch.LongTensor = None, @@ -315,13 +327,14 @@ def _decode_and_generate( Args: decoder_in (torch.Tensor): shape: (1, batch_size * beam_size, 1), due to kv_cache mechanism - memory_bank (torch.Tensor): shape: (padded_src_len, batch_size * beam_size, hidden_dim) + memory_bank_bt (torch.Tensor): shape: (batch_size * beam_size, padded_src_len, hidden_dim) batch (dict), keys: - src: Tuple(src, src_lengths) - src (torch.Tensor): shape of src: (padded_src_len, batch_size, 1) - src_lengths (torch.Tensor): shape of src_lengths: (batch_size,) - batch_size: int memory_lengths (torch.Tensor): shape: (batch_size * beam_size,) + memory_padding_mask (torch.Tensor): shape: (batch_size * beam_size, padded_src_len) src_map (torch.Tensor): None step (int): current step batch_offset (int): batch offset @@ -356,20 +369,13 @@ def _decode_and_generate( decoder_padding_mask = decoder_in.eq(self._tgt_pad_idx).squeeze( 2 ) # shape: (1, batch_size * beam_size) - memory_padding_mask = torch.arange( - 0, max(memory_lengths), device=memory_bank.device - ) >= 
memory_lengths.unsqueeze( - 1 - ) # shape: (batch_size * beam_size, padded_src_len) dec_out, dec_attn = self.model.decoder.forward( tgt=decoder_embedding, # shape: (batch_size * beam_size, 1, hidden_dim) tgt_key_padding_mask=decoder_padding_mask.transpose( 0, 1 ).contiguous(), # shape: (batch_size * beam_size, 1) - enc_out=memory_bank.transpose( - 0, 1 - ).contiguous(), # shape: (batch_size * beam_size, padded_src_len, hidden_dim) + enc_out=memory_bank_bt, # shape: (batch_size * beam_size, padded_src_len, hidden_dim) enc_key_padding_mask=memory_padding_mask, # shape: (batch_size * beam_size, padded_src_len) step=step, ) # shape: (batch_size, tgt_len-1, hidden_dim) From b57e1bf7b5f28de08c00fc0bdc88534648e2d9a9 Mon Sep 17 00:00:00 2001 From: Krzysztof Maziarz Date: Tue, 2 Jun 2026 08:48:18 +0000 Subject: [PATCH 5/8] feat(opennmt): Reuse the already available enc_key_padding_mask --- retrochimera/opennmt/modules/transformer_decoder.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/retrochimera/opennmt/modules/transformer_decoder.py b/retrochimera/opennmt/modules/transformer_decoder.py index 7021bd2..a22c4ea 100644 --- a/retrochimera/opennmt/modules/transformer_decoder.py +++ b/retrochimera/opennmt/modules/transformer_decoder.py @@ -631,14 +631,7 @@ def forward( ) dec_out = tgt - src_len = enc_key_padding_mask.eq(0).sum(dim=1).long() # shape: (batch_size,) - src_max_len = self.state["src"].shape[ - 0 - ] # self.state["src"] shape: (padded_src_len, batch_size, 1) - - src_pad_mask = sequence_mask(src_len, src_max_len).unsqueeze( - 1 - ) # shape: (batch_size, 1, padded_src_len) + src_pad_mask = enc_key_padding_mask.unsqueeze(1) # shape: (batch_size, 1, padded_src_len) tgt_pad_mask = tgt_key_padding_mask.unsqueeze(1) # shape: (batch_size, 1, padded_tgt_len/1) with_align = kwargs.pop("with_align", False) From 95e3dfee0933046e33227628294a2a1949b6967f Mon Sep 17 00:00:00 2001 From: Krzysztof Maziarz Date: Tue, 2 Jun 2026 14:28:43 +0000 Subject: 
[PATCH 6/8] feat(opennmt): Snapshot predictions and scores to CPU once --- retrochimera/opennmt/decode/beam_search.py | 29 ++++++++++++++-------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/retrochimera/opennmt/decode/beam_search.py b/retrochimera/opennmt/decode/beam_search.py index c19b320..2d88de5 100644 --- a/retrochimera/opennmt/decode/beam_search.py +++ b/retrochimera/opennmt/decode/beam_search.py @@ -11,6 +11,7 @@ Modifications: 1. Simplified the structure by removing the `BeamSearchLM` and `GNMTGlobalScorer` class, and merged the `BeamSearchBase` class into the `BeamSearch` class. 2. Introduced the `customised_beam_search` attribute and corresponding logic to the `BeamSearch` class, enabling optimized beam search for retrosynthesis prediction. +3. Snapshot tensors to CPU in `update_finished` once instead of on every access. """ import torch @@ -220,32 +221,40 @@ def update_finished(self): if self.alive_attn is not None else None ) + + # Snapshot to CPU once instead of having this done implicitly on every access. + predictions_cpu = predictions.to("cpu", non_blocking=False) + topk_scores_cpu = self.topk_scores.to("cpu", non_blocking=False) + non_finished_batch = [] for i in range(self.is_finished.size(0)): # Batch level b = self._batch_offset[i] - finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1) + finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1).tolist() # Store finished hypotheses for this batch. for j in finished_hyp: # Beam level: finished beam j in batch i if self.ratio > 0: - s = self.topk_scores[i, j] / (step + 1) + s = topk_scores_cpu[i, j] / (step + 1) if self.best_scores[b] < s: self.best_scores[b] = s if not self.customised_beam_search: self.hypotheses[b].append( ( - self.topk_scores[i, j], - predictions[i, j, 1:], # Ignore start_token. + topk_scores_cpu[i, j], + predictions_cpu[i, j, 1:], # Ignore start_token. 
attention[:, i, j, : self.memory_lengths[i]] if attention is not None else None, ) ) else: - if predictions[i, j, 1:].size(-1) == 0 or predictions[i, j, -2] != self.eos: + if ( + predictions_cpu[i, j, 1:].size(-1) == 0 + or predictions_cpu[i, j, -2].item() != self.eos + ): self.hypotheses[b].append( ( - self.topk_scores[i, j], - predictions[i, j, 1:], # Ignore start_token. + topk_scores_cpu[i, j], + predictions_cpu[i, j, 1:], # Ignore start_token. attention[:, i, j, : self.memory_lengths[i]] if attention is not None else None, @@ -256,13 +265,13 @@ def update_finished(self): if self.ratio > 0: pred_len = self.memory_lengths[i] * self.ratio finish_flag = ( - (self.topk_scores[i, 0] / pred_len) <= self.best_scores[b] - ) or self.is_finished[i].all() + (topk_scores_cpu[i, 0] / pred_len) <= self.best_scores[b] + ) or bool(self.is_finished[i].all()) else: if not self.customised_beam_search: finish_flag = self.top_beam_finished[i] != 0 else: - finish_flag = self.is_finished[i].all() # shape: (beam_size,) + finish_flag = bool(self.is_finished[i].all()) # shape: (beam_size,) if finish_flag and len(self.hypotheses[b]) >= self.n_best: best_hyp = sorted(self.hypotheses[b], key=lambda x: x[0], reverse=True) From 71bf735886798b372259f6ff13f8cd4381f0fb6b Mon Sep 17 00:00:00 2001 From: Krzysztof Maziarz Date: Tue, 2 Jun 2026 14:53:33 +0000 Subject: [PATCH 7/8] chore(opennmt): Reformat with black --- retrochimera/opennmt/decode/beam_search.py | 6 +++--- retrochimera/opennmt/decode/translator.py | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/retrochimera/opennmt/decode/beam_search.py b/retrochimera/opennmt/decode/beam_search.py index 2d88de5..27189cf 100644 --- a/retrochimera/opennmt/decode/beam_search.py +++ b/retrochimera/opennmt/decode/beam_search.py @@ -264,9 +264,9 @@ def update_finished(self): # n_best hypotheses. 
if self.ratio > 0: pred_len = self.memory_lengths[i] * self.ratio - finish_flag = ( - (topk_scores_cpu[i, 0] / pred_len) <= self.best_scores[b] - ) or bool(self.is_finished[i].all()) + finish_flag = ((topk_scores_cpu[i, 0] / pred_len) <= self.best_scores[b]) or bool( + self.is_finished[i].all() + ) else: if not self.customised_beam_search: finish_flag = self.top_beam_finished[i] != 0 diff --git a/retrochimera/opennmt/decode/translator.py b/retrochimera/opennmt/decode/translator.py index c7bd656..1790899 100644 --- a/retrochimera/opennmt/decode/translator.py +++ b/retrochimera/opennmt/decode/translator.py @@ -190,10 +190,9 @@ def _translate_batch_with_strategy( memory_bank_bt = memory_bank.transpose(0, 1).contiguous() src_pad_len = memory_bank_bt.size(1) - memory_padding_mask = ( - torch.arange(0, src_pad_len, device=memory_bank_bt.device) - >= memory_lengths.unsqueeze(1) - ) + memory_padding_mask = torch.arange( + 0, src_pad_len, device=memory_bank_bt.device + ) >= memory_lengths.unsqueeze(1) complete_seq_log_prob = None if self.customised_beam_search: From 4828c578a3635c869786f041d3375bb57bd2563e Mon Sep 17 00:00:00 2001 From: Krzysztof Maziarz Date: Wed, 3 Jun 2026 09:38:36 +0000 Subject: [PATCH 8/8] doc(CHANGELOG): Add an entry for #17 --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c272fc..28abd24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,8 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0. 
### Changed -- Speed up localization model ([#15](https://github.com/microsoft/retrochimera/pull/15), [#16](https://github.com/microsoft/retrochimera/pull/16)) ([@kmaziarz]) +- Speed up localization model loading and inference ([#15](https://github.com/microsoft/retrochimera/pull/15), [#16](https://github.com/microsoft/retrochimera/pull/16)) ([@kmaziarz]) +- Speed up SMILES Transformer model inference ([#17](https://github.com/microsoft/retrochimera/pull/17)) ([@kmaziarz]) - Drop the explicit TensorBoard dependency ([#12](https://github.com/microsoft/retrochimera/pull/12)) ([@kmaziarz]) ### Added