Add RNN Transducer Loss for CPU #1137
Force-pushed from 9d6589a to e2e6562.
    def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"):
        """RNN Transducer Loss

        Args:
I think the documentation could be improved a bit. It could also be useful to reference the paper.
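One way to flesh out the docstring and cite the paper might look like the sketch below. The argument descriptions and shapes are assumptions based on warp-transducer's API, not taken from this PR:

```python
def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"):
    """Compute the RNN Transducer loss.

    Proposed in "Sequence Transduction with Recurrent Neural Networks"
    [Graves, 2012, https://arxiv.org/abs/1211.3711].

    Args:
        acts: unnormalized joint-network outputs
            (shape assumed: (batch, max_T, max_U, num_labels)).
        labels: padded target label sequences (shape assumed: (batch, max_U - 1)).
        act_lens: valid lengths of each utterance in ``acts``, shape (batch,).
        label_lens: valid lengths of each sequence in ``labels``, shape (batch,).
        blank: index of the blank label (default: 0).
        reduction: "none", "mean", or "sum", applied over the batch.
    """
```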
    super().build_extension(ext)

    _TRANSDUCER_NAME = '_warp_transducer'
This will get installed in global namespace, outside of torchaudio package directory.
Please put it in torchaudio package.
    MESSAGE(STATUS "Building static library with GPU support")

    CUDA_ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cu)
    IF (!Torch_FOUND)
If torch is not found, shouldn't it be failing?
    self.reduction = reduction
    self.loss = _RNNT.apply

    def forward(self, acts, labels, act_lens, label_lens):
If you don't want to copy-paste the docs from the functional, you could reference it here within the documentation.
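A lightweight way to do that is a Sphinx cross-reference in the class docstring. A minimal sketch of the pattern (the real module would subclass torch.nn.Module; this dependency-free version only illustrates the cross-reference idea):

```python
class RNNTLoss:
    """Compute the RNN Transducer loss (module form).

    See :func:`rnnt_loss` for a description of ``acts``, ``labels``,
    ``act_lens``, and ``label_lens``; this class only stores ``blank``
    and ``reduction`` and forwards everything else to the functional.
    """

    def __init__(self, blank=0, reduction="mean"):
        self.blank = blank
        self.reduction = reduction
```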
Force-pushed from b6c4ce8 to ca66151.
Some follow-ups:
Force-pushed from 82b7186 to 456eefc.
    # Test if example provided in README runs
    # https://github.com/HawkAaron/warp-transducer

    acts = torch.FloatTensor(
nit: use the factory function torch.tensor([xyz], dtype=torch.float) instead of the type constructor. Same applies to IntTensor.
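A hedged sketch of the suggested change (the tensor values are made up for illustration):

```python
import torch

# Legacy type constructors (discouraged): dtype is implicit in the class name.
acts_old = torch.FloatTensor([[0.1, 0.6, 0.1]])
labels_old = torch.IntTensor([1, 2])

# Factory function (preferred): dtype is explicit, and the same call
# shape works uniformly for any dtype or device.
acts = torch.tensor([[0.1, 0.6, 0.1]], dtype=torch.float)
labels = torch.tensor([1, 2], dtype=torch.int)

assert acts.dtype == torch.float32
assert labels.dtype == torch.int32
```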
Force-pushed from f96089b to 299310c.
    U = data["tgt_lengths"][b]
    for t in range(gradients.shape[1]):
        for u in range(gradients.shape[2]):
            np.testing.assert_allclose(
self.assertEqual should be preferred
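A minimal sketch of the pattern using plain unittest with an exact whole-array comparison (the data here is made up; PyTorch's own test TestCase additionally makes assertEqual tolerance-aware for tensors, which is what makes it preferable to the per-element loop):

```python
import unittest
import numpy as np

class RNNTGradientTest(unittest.TestCase):
    def test_gradients_match(self):
        gradients = np.array([[0.1, -0.2], [0.3, 0.4]])
        expected = np.array([[0.1, -0.2], [0.3, 0.4]])
        # One assertEqual over the full nested-list values replaces the
        # nested loop of np.testing.assert_allclose calls, and failures are
        # reported through the TestCase machinery.
        self.assertEqual(gradients.tolist(), expected.tolist())

result = unittest.TextTestRunner(verbosity=0).run(
    unittest.defaultTestLoader.loadTestsFromTestCase(RNNTGradientTest)
)
```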
Force-pushed from f18105a to 1d2c5db.
Some more follow-ups:
The error below also happens on master:
Force-pushed from 64c8220 to 32e3398.
    loss = rnnt_loss(acts, labels, act_length, label_length)
    loss.backward()

    def _test_costs_and_gradients(
This could be inlined since it only has one call-site and is pretty small (though that's not necessarily a reason to remove an abstraction).
Force-pushed from 32e3398 to fddfbd1.
    }

    TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
      m.impl("rnnt_loss", &cpu_rnnt_loss);
@vincentqb Can you define a proper namespace, i.e. torchaudio::<something>::rnnt_loss?
I am not sure how you want to move on, but if you have a plan to add different types of rnnt, then a more descriptive name would work better later, like warprnnt.
Adding anonymous namespace in #1159 for the time being.
@vincentqb I updated the follow-up description for things addressed in #1159 and #1161. Please stamp these PRs when you have time.

For the C++ ABI issue, see #880.
This pull request introduces rnnt_loss and RNNTLoss as a prototype in torchaudio.prototype.transducer, using HawkAaron's warp-transducer. Follow-up work is detailed in #1240.
cc @astaff, internal, #1099