diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c272fc..28abd24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,8 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0. ### Changed -- Speed up localization model ([#15](https://github.com/microsoft/retrochimera/pull/15), [#16](https://github.com/microsoft/retrochimera/pull/16)) ([@kmaziarz]) +- Speed up localization model loading and inference ([#15](https://github.com/microsoft/retrochimera/pull/15), [#16](https://github.com/microsoft/retrochimera/pull/16)) ([@kmaziarz]) +- Speed up SMILES Transformer model inference ([#17](https://github.com/microsoft/retrochimera/pull/17)) ([@kmaziarz]) - Drop the explicit TensorBoard dependency ([#12](https://github.com/microsoft/retrochimera/pull/12)) ([@kmaziarz]) ### Added diff --git a/retrochimera/opennmt/decode/beam_search.py b/retrochimera/opennmt/decode/beam_search.py index c19b320..27189cf 100644 --- a/retrochimera/opennmt/decode/beam_search.py +++ b/retrochimera/opennmt/decode/beam_search.py @@ -11,6 +11,7 @@ Modifications: 1. Simplified the structure by removing the `BeamSearchLM` and `GNMTGlobalScorer` class, and merged the `BeamSearchBase` class into the `BeamSearch` class. 2. Introduced the `customised_beam_search` attribute and corresponding logic to the `BeamSearch` class, enabling optimized beam search for retrosynthesis prediction. +3. Snapshot tensors to CPU in `update_finished` once instead of per each access. """ import torch @@ -220,32 +221,40 @@ def update_finished(self): if self.alive_attn is not None else None ) + + # Snapshot to CPU once instead of having this done implicitly per each access. + predictions_cpu = predictions.to("cpu", non_blocking=False) + topk_scores_cpu = self.topk_scores.to("cpu", non_blocking=False) + non_finished_batch = [] for i in range(self.is_finished.size(0)): # Batch level b = self._batch_offset[i] - finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1) + finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1).tolist() # Store finished hypotheses for this batch. for j in finished_hyp: # Beam level: finished beam j in batch i if self.ratio > 0: - s = self.topk_scores[i, j] / (step + 1) + s = topk_scores_cpu[i, j] / (step + 1) if self.best_scores[b] < s: self.best_scores[b] = s if not self.customised_beam_search: self.hypotheses[b].append( ( - self.topk_scores[i, j], - predictions[i, j, 1:], # Ignore start_token. + topk_scores_cpu[i, j], + predictions_cpu[i, j, 1:], # Ignore start_token. attention[:, i, j, : self.memory_lengths[i]] if attention is not None else None, ) ) else: - if predictions[i, j, 1:].size(-1) == 0 or predictions[i, j, -2] != self.eos: + if ( + predictions_cpu[i, j, 1:].size(-1) == 0 + or predictions_cpu[i, j, -2].item() != self.eos + ): self.hypotheses[b].append( ( - self.topk_scores[i, j], - predictions[i, j, 1:], # Ignore start_token. + topk_scores_cpu[i, j], + predictions_cpu[i, j, 1:], # Ignore start_token. attention[:, i, j, : self.memory_lengths[i]] if attention is not None else None, @@ -255,14 +264,14 @@ def update_finished(self): # n_best hypotheses. if self.ratio > 0: pred_len = self.memory_lengths[i] * self.ratio - finish_flag = ( - (self.topk_scores[i, 0] / pred_len) <= self.best_scores[b] - ) or self.is_finished[i].all() + finish_flag = ((topk_scores_cpu[i, 0] / pred_len) <= self.best_scores[b]) or bool( + self.is_finished[i].all() + ) else: if not self.customised_beam_search: finish_flag = self.top_beam_finished[i] != 0 else: - finish_flag = self.is_finished[i].all() # shape: (beam_size,) + finish_flag = bool(self.is_finished[i].all()) # shape: (beam_size,) if finish_flag and len(self.hypotheses[b]) >= self.n_best: best_hyp = sorted(self.hypotheses[b], key=lambda x: x[0], reverse=True) diff --git a/retrochimera/opennmt/decode/translator.py b/retrochimera/opennmt/decode/translator.py index 3d55a66..1790899 100644 --- a/retrochimera/opennmt/decode/translator.py +++ b/retrochimera/opennmt/decode/translator.py @@ -13,6 +13,8 @@ 1. Simplified the structure by removing the `build_translator` function, and merged the `Inference` class into the `Translator` class. 2. Refactored the `_run_encoder`, `_decode_and_generate`, and `_translate_batch_with_strategy` methods to align with the `retrochimera.models.smiles_transformer.SmilesTransformerModel` class. 3. Introduced the `customised_beam_search` attribute and corresponding logic to the `Translator` class, enabling optimized beam search for retrosynthesis prediction. +4. Switched `torch.no_grad()` to `torch.inference_mode()` for better performance. +5. Optimized code for more efficient tensor operations and memory usage. """ from typing import Any, Optional @@ -114,7 +116,7 @@ def translate_batch(self, batch: dict[str, Any], attn_debug: bool = False) -> di - src_lengths (torch.Tensor): shape of src_lengths: (batch_size,) - batch_size: int """ - with torch.no_grad(): + with torch.inference_mode(): # TODO: support these blacklisted features decode_strategy = BeamSearch( pad=self._tgt_pad_idx, @@ -186,6 +188,20 @@ def _translate_batch_with_strategy( if fn_map_state is not None: self.model.decoder.map_state(fn_map_state, only_map_src=True) + memory_bank_bt = memory_bank.transpose(0, 1).contiguous() + src_pad_len = memory_bank_bt.size(1) + memory_padding_mask = torch.arange( + 0, src_pad_len, device=memory_bank_bt.device + ) >= memory_lengths.unsqueeze(1) + + complete_seq_log_prob = None + if self.customised_beam_search: + vocab_size = self._tgt_vocab_len + complete_seq_log_prob = torch.full( + (1, vocab_size), -1e5, device=memory_bank_bt.device, dtype=torch.float32 + ) + complete_seq_log_prob[:, self._tgt_eos_idx] = 0.0 + # (3) Begin decoding step by step: for step in range(decode_strategy.max_length): decoder_input = decode_strategy.current_predictions.view( @@ -195,9 +211,10 @@ def _translate_batch_with_strategy( log_probs, attn = self._decode_and_generate( decoder_input, - memory_bank, + memory_bank_bt, batch, memory_lengths=memory_lengths, + memory_padding_mask=memory_padding_mask, src_map=src_map, step=step, batch_offset=decode_strategy.batch_offset, # type: ignore @@ -205,21 +222,13 @@ def _translate_batch_with_strategy( if self.customised_beam_search: # modify the log_probs for finished sentences. - _, vocab_size = tuple(log_probs.shape) - bad_token_log_prob = -1e5 - - complete_seq_log_prob = (torch.ones((1, vocab_size)) * bad_token_log_prob).to( - log_probs.device - ) # shape: (1, vocab_size) - complete_seq_log_prob[:, self._tgt_eos_idx] = 0.0 # shape: (1, vocab_size) - # Use this vector in the output for sequences which are complete. is_end_token = ( decoder_input.squeeze() == self._tgt_eos_idx ) # shape: (batch_size * beam_size,) log_prob_mask = is_end_token.unsqueeze(1) # shape: (batch_size * beam_size, 1) - log_probs = (log_prob_mask * complete_seq_log_prob) + ( - ~log_prob_mask * log_probs + log_probs = torch.where( + log_prob_mask, complete_seq_log_prob, log_probs ) # shape: (batch_size * beam_size, vocab_size) assert log_probs.dim() == 2, f"Expected 2D tensor, got {log_probs.dim()}" @@ -241,12 +250,15 @@ def _translate_batch_with_strategy( if any_finished: # Reorder states. - if isinstance(memory_bank, tuple): - memory_bank = tuple(x.index_select(1, select_indices) for x in memory_bank) + if isinstance(memory_bank_bt, tuple): + memory_bank_bt = tuple( + x.index_select(0, select_indices) for x in memory_bank_bt + ) else: - memory_bank = memory_bank.index_select(1, select_indices) + memory_bank_bt = memory_bank_bt.index_select(0, select_indices) memory_lengths = memory_lengths.index_select(0, select_indices) + memory_padding_mask = memory_padding_mask.index_select(0, select_indices) if src_map is not None: src_map = src_map.index_select(1, select_indices) @@ -302,9 +314,10 @@ def _run_encoder(self, batch: dict[str, Any]) -> tuple: def _decode_and_generate( self, decoder_in: torch.Tensor, - memory_bank: torch.Tensor, + memory_bank_bt: torch.Tensor, batch: dict[str, Any], memory_lengths: torch.Tensor, + memory_padding_mask: torch.Tensor, src_map=None, step: Optional[int] = None, batch_offset: torch.LongTensor = None, @@ -313,13 +326,14 @@ def _decode_and_generate( Args: decoder_in (torch.Tensor): shape: (1, batch_size * beam_size, 1), due to kv_cache mechanism - memory_bank (torch.Tensor): shape: (padded_src_len, batch_size * beam_size, hidden_dim) + memory_bank_bt (torch.Tensor): shape: (batch_size * beam_size, padded_src_len, hidden_dim) batch (dict), keys: - src: Tuple(src, src_lengths) - src (torch.Tensor): shape of src: (padded_src_len, batch_size, 1) - src_lengths (torch.Tensor): shape of src_lengths: (batch_size,) - batch_size: int memory_lengths (torch.Tensor): shape: (batch_size * beam_size,) + memory_padding_mask (torch.Tensor): shape: (batch_size * beam_size, padded_src_len) src_map (torch.Tensor): None step (int): current step batch_offset (int): batch offset @@ -354,20 +368,13 @@ def _decode_and_generate( decoder_padding_mask = decoder_in.eq(self._tgt_pad_idx).squeeze( 2 ) # shape: (1, batch_size * beam_size) - memory_padding_mask = torch.arange( - 0, max(memory_lengths), device=memory_bank.device - ) >= memory_lengths.unsqueeze( - 1 - ) # shape: (batch_size * beam_size, padded_src_len) dec_out, dec_attn = self.model.decoder.forward( tgt=decoder_embedding, # shape: (batch_size * beam_size, 1, hidden_dim) tgt_key_padding_mask=decoder_padding_mask.transpose( 0, 1 ).contiguous(), # shape: (batch_size * beam_size, 1) - enc_out=memory_bank.transpose( - 0, 1 - ).contiguous(), # shape: (batch_size * beam_size, padded_src_len, hidden_dim) + enc_out=memory_bank_bt, # shape: (batch_size * beam_size, padded_src_len, hidden_dim) enc_key_padding_mask=memory_padding_mask, # shape: (batch_size * beam_size, padded_src_len) step=step, ) # shape: (batch_size, tgt_len-1, hidden_dim) diff --git a/retrochimera/opennmt/modules/transformer_decoder.py b/retrochimera/opennmt/modules/transformer_decoder.py index 7021bd2..a22c4ea 100644 --- a/retrochimera/opennmt/modules/transformer_decoder.py +++ b/retrochimera/opennmt/modules/transformer_decoder.py @@ -631,14 +631,7 @@ def forward( ) dec_out = tgt - src_len = enc_key_padding_mask.eq(0).sum(dim=1).long() # shape: (batch_size,) - src_max_len = self.state["src"].shape[ - 0 - ] # self.state["src"] shape: (padded_src_len, batch_size, 1) - - src_pad_mask = sequence_mask(src_len, src_max_len).unsqueeze( - 1 - ) # shape: (batch_size, 1, padded_src_len) + src_pad_mask = enc_key_padding_mask.unsqueeze(1) # shape: (batch_size, 1, padded_src_len) tgt_pad_mask = tgt_key_padding_mask.unsqueeze(1) # shape: (batch_size, 1, padded_tgt_len/1) with_align = kwargs.pop("with_align", False)