Skip to content

[not for land] example of simple FP8 UEX with stateful scaling#44

Open
vkuzo wants to merge 1 commit into
mainfrom
fp8_v2
Open

[not for land] example of simple FP8 UEX with stateful scaling#44
vkuzo wants to merge 1 commit into
mainfrom
fp8_v2

Conversation

@vkuzo

@vkuzo vkuzo commented Apr 28, 2023

Copy link
Copy Markdown
Collaborator

This is an example of implementing basic fp8 support with a Python tensor subclass.

tl;dr;

  1. FP8Tensor is the Python object which contains raw fp8 data (as torch.bits8), a scale, and a flavor (e4m3/e5m2)
  2. FP8Tensor.__torch__dispatch knows how to add gradients, but converts to fp32 for everything else
  3. FP8Linear is a module which can do stateful delayed scaling. User is expected to manually swap their linears to something like this.

Note: E4M3 support has not been numerically validated, and E5M2 support is not there at all
Note: No testing other than the bare bones at the bottom of the PR has been done.
Note: scaling is not implemented, currently it's just scales of 1.0 everywhere

This is an example of implementing basic fp8 support with a Python tensor subclass.

tl;dr;
1. FP8Tensor is the Python object which contains raw fp8 data (as torch.bits8), a scale, and a flavor (e4m3/e5m2)
2. FP8Tensor.__torch__dispatch knows how to add gradients, but converts to fp32 for everything else
3. FP8Linear is a module which can do stateful delayed scaling. User is expected to manually swap their linears to something like this.

Note: E4M3 support has not been numerically validated, and E5M2 support is not there at all
Note: No testing other than the bare bones at the bottom of the PR has been done.
Note: scaling is not implemented, currently it's just scales of 1.0 everywhere
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant