predict_action function performs PyTorch conversion using GPU #37
demobo-com wants to merge 1 commit into Interbotix:feat/improve-predict from
Conversation
Pull Request Overview
This PR optimizes predict_action by moving the tensor device transfer to before data transformations so that normalization and reshaping occur on the GPU, improving inference performance under CPU load.
- Moves `observation[name].to(device)` to the start of the conversion loop and removes the redundant transfer at the end
- Ensures all type casting, normalization, permute, and unsqueeze operations happen on the CUDA device
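The change described above can be sketched as follows. This is a reconstruction for illustration, not the PR's exact diff; the dict keys and tensor shapes (HWC `uint8` images, 1-D float state vectors) are assumptions:

```python
import torch

def convert_observation(observation: dict, device: torch.device) -> dict:
    """Sketch of the conversion loop after this PR: device transfer first."""
    for name in observation:
        # Move to the target device FIRST, so the casts and reshapes
        # below execute on the GPU instead of the (possibly loaded) CPU.
        observation[name] = observation[name].to(device)
        if "image" in name:
            observation[name] = observation[name].type(torch.float32) / 255
            observation[name] = observation[name].permute(2, 0, 1).contiguous()
        observation[name] = observation[name].unsqueeze(0)
    return observation
```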
Comments suppressed due to low confidence (1)

lerobot/common/robot_devices/control_utils.py:112

Add a unit test to verify that after `predict_action`, all observation tensors are on the specified device and have the correct shape and dtype.

```python
observation[name] = observation[name].to(device)
```
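A minimal sketch of the suggested unit test. `_convert` is a hypothetical stand-in for the conversion loop inside `predict_action` (the real function lives in `lerobot/common/robot_devices/control_utils.py`), and the observation keys and shapes are illustrative assumptions:

```python
import torch

def _convert(observation, device):
    # Hypothetical stand-in for the conversion loop inside predict_action.
    for name in observation:
        observation[name] = observation[name].to(device)
        if "image" in name:
            observation[name] = observation[name].type(torch.float32) / 255
            observation[name] = observation[name].permute(2, 0, 1).contiguous()
        observation[name] = observation[name].unsqueeze(0)
    return observation

def test_predict_action_observation_conversion():
    # Fall back to CPU so the test also runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    obs = {
        "observation.images.top": torch.randint(0, 256, (48, 64, 3), dtype=torch.uint8),
        "observation.state": torch.zeros(6),
    }
    obs = _convert(obs, device)
    for tensor in obs.values():
        assert tensor.device.type == device.type
    assert obs["observation.images.top"].dtype == torch.float32
    assert obs["observation.images.top"].shape == (1, 3, 48, 64)
    assert obs["observation.state"].shape == (1, 6)
```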
```python
observation[name] = observation[name].to(device)
if "image" in name:
    observation[name] = observation[name].type(torch.float32) / 255
    observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
```
You can combine device transfer, dtype conversion, and normalization into a single chained call to reduce intermediate allocations, e.g. `observation[name] = observation[name].to(device=device, dtype=torch.float32).div(255).permute(2, 0, 1).unsqueeze(0)`.
```diff
-observation[name] = observation[name].to(device)
-if "image" in name:
-    observation[name] = observation[name].type(torch.float32) / 255
-    observation[name] = observation[name].permute(2, 0, 1).contiguous()
-observation[name] = observation[name].unsqueeze(0)
+if "image" in name:
+    observation[name] = observation[name].to(device=device, dtype=torch.float32).div(255).permute(2, 0, 1).unsqueeze(0)
+else:
+    observation[name] = observation[name].to(device).unsqueeze(0)
```
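A quick equivalence check for the chained rewrite suggested above. The tensor shape is illustrative, and the check falls back to CPU when CUDA is unavailable:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img = torch.randint(0, 256, (48, 64, 3), dtype=torch.uint8)

# Step-by-step version from the PR.
a = img.to(device)
a = a.type(torch.float32) / 255
a = a.permute(2, 0, 1).contiguous()
a = a.unsqueeze(0)

# Chained version from the review suggestion: one fused to() call,
# fewer intermediate tensors materialized.
b = img.to(device=device, dtype=torch.float32).div(255).permute(2, 0, 1).unsqueeze(0)
```

Both paths produce the same values and shape; the chained form simply avoids the separate `uint8` copy on the device before the cast.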
```python
for name in observation:
    observation[name] = observation[name].to(device)
    if "image" in name:
        observation[name] = observation[name].type(torch.float32) / 255
```
[nitpick] Consider using the `.float()` alias instead of `.type(torch.float32)` for readability and consistency with common PyTorch code style.
```diff
-observation[name] = observation[name].type(torch.float32) / 255
+observation[name] = observation[name].float() / 255
```
What this does
(⚡️ Performance)
Optimizes the `predict_action` function's observation conversions by performing the PyTorch tensor transforms on the CUDA GPU instead of the CPU.
How to check out & try? (for the reviewer)
Run the same ACT policy evaluation in these 3 different scenarios:
Examples: