Claude/refactor pytorch codegen zr oi f #46
Conversation
- Add LayerCodeSpec dataclass for structured code specifications
- Create TemplateManager for Jinja2 template loading and caching
- Add get_pytorch_code_spec() and get_tensorflow_code_spec() methods to NodeDefinition
- Create template directory structure for PyTorch and TensorFlow
- Add base layer templates for both frameworks
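For illustration, a minimal sketch of how a structured code spec and a cached Jinja2 template manager could fit together. The field names and the TemplateManager API shown here are assumptions, not the exact classes introduced by this PR:

```python
# Illustrative sketch only: field names and the TemplateManager API are
# assumptions, not the exact classes introduced by this PR.
from dataclasses import dataclass, field

from jinja2 import Environment, FileSystemLoader, Template


@dataclass
class LayerCodeSpec:
    """Structured description of the code to emit for one layer node."""
    template_name: str                               # e.g. "pytorch/layers/conv2d.jinja2"
    layer_class: str                                 # e.g. "nn.Conv2d"
    init_kwargs: dict = field(default_factory=dict)  # constructor arguments for the layer
    docstring: str = ""                              # documentation rendered into the template


class TemplateManager:
    """Loads Jinja2 templates from a directory and caches compiled templates."""

    def __init__(self, template_dir: str) -> None:
        self._env = Environment(loader=FileSystemLoader(template_dir))
        self._cache: dict[str, Template] = {}

    def get(self, name: str) -> Template:
        # Load each template once and reuse the compiled object afterwards.
        if name not in self._cache:
            self._cache[name] = self._env.get_template(name)
        return self._cache[name]

    def render(self, name: str, **context) -> str:
        return self.get(name).render(**context)
```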
- Create layer templates for conv2d, linear, maxpool, flatten, relu, softmax, dropout, batchnorm, attention, add, concat
- Create file templates for model.py, train.py, dataset.py, config.py
- Templates preserve good coding practices: reusable classes, comprehensive documentation, proper shape annotations
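As a hedged example of the kind of snippet such a layer template might render, including the shape annotation mentioned above (the template text below is illustrative, not a file from this PR):

```python
# Illustrative only: the real templates live under services/nodes/templates/;
# this shows the kind of layer snippet a conv2d template might render.
from jinja2 import Template

conv2d_layer = Template(
    "# {{ name }}: ({{ in_channels }}, H, W) -> ({{ out_channels }}, H, W) with padding=1\n"
    "self.{{ name }} = nn.Conv2d({{ in_channels }}, {{ out_channels }}, "
    "kernel_size={{ kernel_size }}, padding=1)"
)

print(conv2d_layer.render(name="conv1", in_channels=3, out_channels=16, kernel_size=3))
# -> # conv1: (3, H, W) -> (16, H, W) with padding=1
# -> self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
```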
- Add base utilities (topological sort, edge map builder)
- Create PyTorchCodeOrchestrator class
- Implement template-driven code generation
- Support skip connections in forward pass
- Preserve all existing features (adaptive hyperparameters, test code generation)
- Return same API format as original codegen
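A sketch of the base graph utilities named above, assuming the node graph is given as node ids plus (source, target) edge pairs; the function names mirror the commit message but the exact signatures are assumptions:

```python
# Sketch of the base utilities described above (edge map + topological sort);
# names and graph format are assumptions, not the PR's exact code.
from collections import defaultdict, deque


def build_edge_map(edges):
    """Map each source node id to the list of its target node ids."""
    edge_map = defaultdict(list)
    for src, dst in edges:
        edge_map[src].append(dst)
    return edge_map


def topological_sort(nodes, edges):
    """Kahn's algorithm: return node ids in dependency order."""
    edge_map = build_edge_map(edges)
    indegree = {n: 0 for n in nodes}
    for targets in edge_map.values():
        for dst in targets:
            indegree[dst] += 1
    queue = deque(n for n, d in indegree.items() if d == 0)
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for dst in edge_map[node]:
            indegree[dst] -= 1
            if indegree[dst] == 0:
                queue.append(dst)
    if len(order) != len(nodes):
        raise ValueError("Graph contains a cycle")
    return order
```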
- Modify generate_pytorch_code() to delegate to PyTorchCodeOrchestrator
- Preserve exact same function signature and return type
- Maintain 100% backward compatibility with views
- Mark legacy code as preserved for reference
- Complete refactor: template-based, extensible, non-monolithic
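The delegation pattern this commit describes could look roughly like the following; the import path and orchestrator signature are assumptions:

```python
# Sketch of the delegation described above; the module path below is an
# assumption based on the file list (services/codegen/), not a verified import.
from .codegen.pytorch_orchestrator import PyTorchCodeOrchestrator


def generate_pytorch_code(nodes, edges, hyperparameters=None):
    """Public entry point: same signature and return format as before."""
    # Delegate to the template-driven orchestrator. The legacy
    # string-concatenation implementation further down in the module is
    # preserved for reference but never executes because of this early return.
    orchestrator = PyTorchCodeOrchestrator(nodes, edges, hyperparameters)
    return orchestrator.generate()
```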
- Create layer templates for conv2d, linear, maxpool, flatten, dropout, batchnorm, add, concat
- Create file templates for model.py, train.py, dataset.py, config.py
- Adapt templates for TensorFlow/Keras API and NHWC format
- Maintain same documentation quality as PyTorch
- Add get_tensorflow_code_spec() to Conv2D, Linear, MaxPool2D, Flatten, Dropout, BatchNorm2D, Add, Concat
- Adapt for TensorFlow/Keras API (filters vs channels, NHWC format)
- Mirror PyTorch implementation pattern for consistency
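To make the filters-vs-channels and NHWC-vs-NCHW difference concrete, here is a small framework comparison (not code from this PR; it only illustrates the layout and naming the TensorFlow specs target):

```python
# Not from this PR: side-by-side showing why the TensorFlow specs use
# "filters" and NHWC while the PyTorch specs use channel counts and NCHW.
import tensorflow as tf
import torch

x_nchw = torch.zeros(8, 3, 32, 32)   # PyTorch default layout: (N, C, H, W)
x_nhwc = tf.zeros((8, 32, 32, 3))    # TensorFlow default layout: (N, H, W, C)

conv_pt = torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
conv_tf = tf.keras.layers.Conv2D(filters=16, kernel_size=3, padding="same")

print(conv_pt(x_nchw).shape)  # torch.Size([8, 16, 32, 32])
print(conv_tf(x_nhwc).shape)  # (8, 32, 32, 16)
```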
- Add TensorFlowCodeOrchestrator class mirroring PyTorch structure
- Implement template-driven code generation for TensorFlow/Keras
- Support NHWC format and TensorFlow-specific patterns
- Handle skip connections in forward pass
- Return same API format as original TensorFlow codegen
- Modify generate_tensorflow_code() to delegate to TensorFlowCodeOrchestrator
- Preserve exact same function signature and return type
- Maintain 100% backward compatibility with views
- Mark legacy code as preserved for reference
- Complete TensorFlow refactor: template-based, extensible, non-monolithic
- Fix: Remove explicit .forward() and .call() method calls (use __call__ instead)
- Fix: PyTorch layers should be called directly, not via .forward()
- Fix: TensorFlow layers should be called directly, not via .call()
- Add: Base orchestrator class to eliminate WET code (DRY principle)
- Prepare: Foundation for refactoring both orchestrators to inherit from base
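A minimal illustration of the invocation fix: generated code should call layers directly (via __call__) rather than through .forward() or .call(). The snippet below is illustrative, not the generated output itself:

```python
# Illustrative only: how generated code should invoke layers after this fix.
import tensorflow as tf
import torch

linear = torch.nn.Linear(4, 2)
x = torch.randn(1, 4)
y = linear(x)              # correct: __call__ runs forward() plus any registered hooks
# y = linear.forward(x)    # removed style: bypasses hooks registered on the module

dense = tf.keras.layers.Dense(2)
z = dense(tf.zeros((1, 4)))         # correct: __call__ builds the layer, then runs call()
# z = dense.call(tf.zeros((1, 4)))  # removed style: skips the build/weight-creation step
```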
- Fix: Change import from .pytorch_codegen to .enhanced_pytorch_codegen
- The actual file is named enhanced_pytorch_codegen.py, not pytorch_codegen.py
- Remove imports of classes that don't exist (GroupBlockShapeComputer, safe_get_shape_data, etc.)
- These were only used in legacy code that never executes due to early delegation
- Fixes ModuleNotFoundError on server startup
…age and enhance edge mapping for skip connections
Pull request overview
This PR refactors PyTorch code generation to use a template-based architecture, adds new PyTorch layer nodes (LSTM, GRU, Embedding, Conv1D, Conv3D, pooling layers), implements corresponding TensorFlow nodes, and updates the Gemini service import.
- Migrates PyTorch and TensorFlow code generation from string concatenation to Jinja2 templates
- Adds recurrent (LSTM, GRU), embedding, and additional convolution/pooling layers for PyTorch
- Introduces template-based code generation infrastructure with orchestrators for both frameworks
- Updates Gemini AI service import path
Reviewed changes
Copilot reviewed 75 out of 75 changed files in this pull request and generated no comments.
Summary per file:
| File | Description |
|---|---|
| project/frontend/src/lib/nodes/definitions/pytorch/*.ts | New PyTorch layer definitions (LSTM, GRU, Embedding, Conv1D/3D, pooling) |
| project/frontend/src/lib/nodes/definitions/pytorch/index.ts | Exports for new layer nodes |
| project/block_manager/services/tensorflow_codegen.py | Delegates to new orchestrator, removes legacy code |
| project/block_manager/services/nodes/tensorflow/*.py | TensorFlow node definitions with code spec methods |
| project/block_manager/services/nodes/pytorch/*.py | PyTorch node definitions with code spec methods |
| project/block_manager/services/nodes/templates/**/*.jinja2 | Jinja2 templates for layer code generation |
| project/block_manager/services/nodes/templates/manager.py | Template loading and caching infrastructure |
| project/block_manager/services/nodes/base.py | Adds LayerCodeSpec dataclass and code spec methods |
| project/block_manager/services/gemini_service.py | Updates import path for Gemini library |
| project/block_manager/services/enhanced_pytorch_codegen.py | Delegates to new orchestrator |
| project/block_manager/services/codegen/*.py | New orchestrators and generators for template-based code generation |
No description provided.