Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.

### Changed

- Speed up localization model ([#15](https://github.com/microsoft/retrochimera/pull/15), [#16](https://github.com/microsoft/retrochimera/pull/16)) ([@kmaziarz])
- Speed up localization model loading and inference ([#15](https://github.com/microsoft/retrochimera/pull/15), [#16](https://github.com/microsoft/retrochimera/pull/16)) ([@kmaziarz])
- Speed up SMILES Transformer model inference ([#17](https://github.com/microsoft/retrochimera/pull/17)) ([@kmaziarz])
- Drop the explicit TensorBoard dependency ([#12](https://github.com/microsoft/retrochimera/pull/12)) ([@kmaziarz])

### Added
Expand Down
31 changes: 20 additions & 11 deletions retrochimera/opennmt/decode/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Modifications:
1. Simplified the structure by removing the `BeamSearchLM` and `GNMTGlobalScorer` class, and merged the `BeamSearchBase` class into the `BeamSearch` class.
2. Introduced the `customised_beam_search` attribute and corresponding logic to the `BeamSearch` class, enabling optimized beam search for retrosynthesis prediction.
3. Snapshot tensors to CPU in `update_finished` once instead of per each access.
"""
import torch

Expand Down Expand Up @@ -220,32 +221,40 @@ def update_finished(self):
if self.alive_attn is not None
else None
)

# Snapshot to CPU once instead of having this done implicitly per each access.
predictions_cpu = predictions.to("cpu", non_blocking=False)
topk_scores_cpu = self.topk_scores.to("cpu", non_blocking=False)

non_finished_batch = []
for i in range(self.is_finished.size(0)): # Batch level
b = self._batch_offset[i]
finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1)
finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1).tolist()
# Store finished hypotheses for this batch.
for j in finished_hyp: # Beam level: finished beam j in batch i
if self.ratio > 0:
s = self.topk_scores[i, j] / (step + 1)
s = topk_scores_cpu[i, j] / (step + 1)
if self.best_scores[b] < s:
self.best_scores[b] = s
if not self.customised_beam_search:
self.hypotheses[b].append(
(
self.topk_scores[i, j],
predictions[i, j, 1:], # Ignore start_token.
topk_scores_cpu[i, j],
predictions_cpu[i, j, 1:], # Ignore start_token.
attention[:, i, j, : self.memory_lengths[i]]
if attention is not None
else None,
)
)
else:
if predictions[i, j, 1:].size(-1) == 0 or predictions[i, j, -2] != self.eos:
if (
predictions_cpu[i, j, 1:].size(-1) == 0
or predictions_cpu[i, j, -2].item() != self.eos
):
self.hypotheses[b].append(
(
self.topk_scores[i, j],
predictions[i, j, 1:], # Ignore start_token.
topk_scores_cpu[i, j],
predictions_cpu[i, j, 1:], # Ignore start_token.
attention[:, i, j, : self.memory_lengths[i]]
if attention is not None
else None,
Expand All @@ -255,14 +264,14 @@ def update_finished(self):
# n_best hypotheses.
if self.ratio > 0:
pred_len = self.memory_lengths[i] * self.ratio
finish_flag = (
(self.topk_scores[i, 0] / pred_len) <= self.best_scores[b]
) or self.is_finished[i].all()
finish_flag = ((topk_scores_cpu[i, 0] / pred_len) <= self.best_scores[b]) or bool(
self.is_finished[i].all()
)
else:
if not self.customised_beam_search:
finish_flag = self.top_beam_finished[i] != 0
else:
finish_flag = self.is_finished[i].all() # shape: (beam_size,)
finish_flag = bool(self.is_finished[i].all()) # shape: (beam_size,)

if finish_flag and len(self.hypotheses[b]) >= self.n_best:
best_hyp = sorted(self.hypotheses[b], key=lambda x: x[0], reverse=True)
Expand Down
57 changes: 32 additions & 25 deletions retrochimera/opennmt/decode/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
1. Simplified the structure by removing the `build_translator` function, and merged the `Inference` class into the `Translator` class.
2. Refactored the `_run_encoder`, `_decode_and_generate`, and `_translate_batch_with_strategy` methods to align with the `retrochimera.models.smiles_transformer.SmilesTransformerModel` class.
3. Introduced the `customised_beam_search` attribute and corresponding logic to the `Translator` class, enabling optimized beam search for retrosynthesis prediction.
4. Switched `torch.no_grad()` to `torch.inference_mode()` for better performance.
5. Optimized code for more efficient tensor operations and memory usage.
"""
from typing import Any, Optional

Expand Down Expand Up @@ -114,7 +116,7 @@ def translate_batch(self, batch: dict[str, Any], attn_debug: bool = False) -> di
- src_lengths (torch.Tensor): shape of src_lengths: (batch_size,)
- batch_size: int
"""
with torch.no_grad():
with torch.inference_mode():
# TODO: support these blacklisted features
decode_strategy = BeamSearch(
pad=self._tgt_pad_idx,
Expand Down Expand Up @@ -186,6 +188,20 @@ def _translate_batch_with_strategy(
if fn_map_state is not None:
self.model.decoder.map_state(fn_map_state, only_map_src=True)

memory_bank_bt = memory_bank.transpose(0, 1).contiguous()
src_pad_len = memory_bank_bt.size(1)
memory_padding_mask = torch.arange(
0, src_pad_len, device=memory_bank_bt.device
) >= memory_lengths.unsqueeze(1)

complete_seq_log_prob = None
if self.customised_beam_search:
vocab_size = self._tgt_vocab_len
complete_seq_log_prob = torch.full(
(1, vocab_size), -1e5, device=memory_bank_bt.device, dtype=torch.float32
)
complete_seq_log_prob[:, self._tgt_eos_idx] = 0.0

# (3) Begin decoding step by step:
for step in range(decode_strategy.max_length):
decoder_input = decode_strategy.current_predictions.view(
Expand All @@ -195,31 +211,24 @@ def _translate_batch_with_strategy(

log_probs, attn = self._decode_and_generate(
decoder_input,
memory_bank,
memory_bank_bt,
batch,
memory_lengths=memory_lengths,
memory_padding_mask=memory_padding_mask,
src_map=src_map,
step=step,
batch_offset=decode_strategy.batch_offset, # type: ignore
) # (batch_size * beam_size, vocab)

if self.customised_beam_search:
# modify the log_probs for finished sentences.
_, vocab_size = tuple(log_probs.shape)
bad_token_log_prob = -1e5

complete_seq_log_prob = (torch.ones((1, vocab_size)) * bad_token_log_prob).to(
log_probs.device
) # shape: (1, vocab_size)
complete_seq_log_prob[:, self._tgt_eos_idx] = 0.0 # shape: (1, vocab_size)

# Use this vector in the output for sequences which are complete.
is_end_token = (
decoder_input.squeeze() == self._tgt_eos_idx
) # shape: (batch_size * beam_size,)
log_prob_mask = is_end_token.unsqueeze(1) # shape: (batch_size * beam_size, 1)
log_probs = (log_prob_mask * complete_seq_log_prob) + (
~log_prob_mask * log_probs
log_probs = torch.where(
log_prob_mask, complete_seq_log_prob, log_probs
) # shape: (batch_size * beam_size, vocab_size)

assert log_probs.dim() == 2, f"Expected 2D tensor, got {log_probs.dim()}"
Expand All @@ -241,12 +250,15 @@ def _translate_batch_with_strategy(

if any_finished:
# Reorder states.
if isinstance(memory_bank, tuple):
memory_bank = tuple(x.index_select(1, select_indices) for x in memory_bank)
if isinstance(memory_bank_bt, tuple):
memory_bank_bt = tuple(
x.index_select(0, select_indices) for x in memory_bank_bt
)
else:
memory_bank = memory_bank.index_select(1, select_indices)
memory_bank_bt = memory_bank_bt.index_select(0, select_indices)

memory_lengths = memory_lengths.index_select(0, select_indices)
memory_padding_mask = memory_padding_mask.index_select(0, select_indices)

if src_map is not None:
src_map = src_map.index_select(1, select_indices)
Expand Down Expand Up @@ -302,9 +314,10 @@ def _run_encoder(self, batch: dict[str, Any]) -> tuple:
def _decode_and_generate(
self,
decoder_in: torch.Tensor,
memory_bank: torch.Tensor,
memory_bank_bt: torch.Tensor,
batch: dict[str, Any],
memory_lengths: torch.Tensor,
memory_padding_mask: torch.Tensor,
src_map=None,
step: Optional[int] = None,
batch_offset: torch.LongTensor = None,
Expand All @@ -313,13 +326,14 @@ def _decode_and_generate(

Args:
decoder_in (torch.Tensor): shape: (1, batch_size * beam_size, 1), due to kv_cache mechanism
memory_bank (torch.Tensor): shape: (padded_src_len, batch_size * beam_size, hidden_dim)
memory_bank_bt (torch.Tensor): shape: (batch_size * beam_size, padded_src_len, hidden_dim)
batch (dict), keys:
- src: Tuple(src, src_lengths)
- src (torch.Tensor): shape of src: (padded_src_len, batch_size, 1)
- src_lengths (torch.Tensor): shape of src_lengths: (batch_size,)
- batch_size: int
memory_lengths (torch.Tensor): shape: (batch_size * beam_size,)
memory_padding_mask (torch.Tensor): shape: (batch_size * beam_size, padded_src_len)
src_map (torch.Tensor): None
step (int): current step
batch_offset (int): batch offset
Expand Down Expand Up @@ -354,20 +368,13 @@ def _decode_and_generate(
decoder_padding_mask = decoder_in.eq(self._tgt_pad_idx).squeeze(
2
) # shape: (1, batch_size * beam_size)
memory_padding_mask = torch.arange(
0, max(memory_lengths), device=memory_bank.device
) >= memory_lengths.unsqueeze(
1
) # shape: (batch_size * beam_size, padded_src_len)

dec_out, dec_attn = self.model.decoder.forward(
tgt=decoder_embedding, # shape: (batch_size * beam_size, 1, hidden_dim)
tgt_key_padding_mask=decoder_padding_mask.transpose(
0, 1
).contiguous(), # shape: (batch_size * beam_size, 1)
enc_out=memory_bank.transpose(
0, 1
).contiguous(), # shape: (batch_size * beam_size, padded_src_len, hidden_dim)
enc_out=memory_bank_bt, # shape: (batch_size * beam_size, padded_src_len, hidden_dim)
enc_key_padding_mask=memory_padding_mask, # shape: (batch_size * beam_size, padded_src_len)
step=step,
) # shape: (batch_size, tgt_len-1, hidden_dim)
Expand Down
9 changes: 1 addition & 8 deletions retrochimera/opennmt/modules/transformer_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,14 +631,7 @@ def forward(
)

dec_out = tgt
src_len = enc_key_padding_mask.eq(0).sum(dim=1).long() # shape: (batch_size,)
src_max_len = self.state["src"].shape[
0
] # self.state["src"] shape: (padded_src_len, batch_size, 1)

src_pad_mask = sequence_mask(src_len, src_max_len).unsqueeze(
1
) # shape: (batch_size, 1, padded_src_len)
src_pad_mask = enc_key_padding_mask.unsqueeze(1) # shape: (batch_size, 1, padded_src_len)
tgt_pad_mask = tgt_key_padding_mask.unsqueeze(1) # shape: (batch_size, 1, padded_tgt_len/1)

with_align = kwargs.pop("with_align", False)
Expand Down
Loading