
TileOPs: Operator Library for LLMs Built on TileLang

Installation | Getting Started | Documents

TileOPs is a high-performance operator library for large language models (LLMs) built on TileLang. It offers efficient, modular, and composable operator implementations for AI workloads, with a particular focus on LLMs.

⚠️ Status: TileOPs is under active and rapid development. APIs and features may change.

What TileOPs is for:

  • Out-of-the-box Operator Library: A growing collection of production-ready operators commonly used in LLM workloads, designed with clear abstractions and modular building blocks. These operators can be used directly or easily extended for custom research and system integration.
  • Efficient Attention Kernels for LLMs: Highly optimized attention implementations, including MHA/GQA (FlashAttention-2 on Ampere-class GPUs and FlashAttention-3 on Hopper), DeepSeek-MLA, and DeepSeek-DSA.
  • Reference Implementation for TileLang: TileOPs acts as a canonical reference implementation for writing performant and maintainable kernels in TileLang. It demonstrates best practices in tiling strategies, memory hierarchy utilization, and warp-/block-level coordination, making it a practical learning resource for compiler and kernel developers.

The core features of TileOPs include:

  • Auto-Tuning: Built-in auto-tuning support to explore tile sizes, pipelines, and scheduling parameters, enabling kernels to adapt efficiently to different GPU architectures and workload characteristics with minimal manual effort.
  • CUDA-Graph and torch.compile Compatibility: TileOPs APIs are fully compatible with CUDA-Graph capture and PyTorch torch.compile, allowing seamless integration into modern training and inference pipelines with reduced launch overhead and improved end-to-end performance.
  • Lightweight Dependencies: TileOPs depends only on TileLang, PyTorch, and einops, keeping the software stack minimal and easy to integrate.
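To illustrate the torch.compile compatibility claimed above, here is a hedged sketch of the integration pattern. A plain matmul stands in for a TileOPs op so the snippet runs anywhere; with TileOPs installed, the function body could call an op such as the GemmOp from the Quick Start in the same way.

```python
import torch

# Hedged sketch of the torch.compile integration pattern. The "eager"
# backend avoids codegen so the example runs without a compiler toolchain;
# in a real pipeline you would typically use the default inductor backend.
@torch.compile(backend="eager")
def step(a, b):
    # With TileOPs installed, this could be `return torch.relu(gemm(a, b))`
    return torch.relu(a @ b)

a = torch.randn(8, 4)
b = torch.randn(4, 8)
out = step(a, b)  # first call triggers tracing, subsequent calls reuse it
```

The same function-level wrapping applies to CUDA-Graph capture: because the ops avoid graph-breaking host-side control flow, they can be captured and replayed with reduced launch overhead.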

📦 Install with pip

Prerequisites

  • Python >= 3.10
  • PyTorch >= 2.1
  • CUDA Toolkit (required — this is a GPU kernel project)
  • A CUDA-capable NVIDIA GPU
    • Tested architectures: Ampere (SM_80, SM_86) and Hopper (SM_90)
    • Other architectures may work but are not tested
  • TileLang == 0.1.8

Method 1: Install from PyPI

pip install tileops

Method 2: Install from source (for development)

git clone https://github.com/tile-ai/TileOPs
cd TileOPs
make install                # installs dev dependencies and pre-commit hooks

Note

If you have CUDA and TileLang already installed system-wide and encounter build issues, try:

PIP_NO_BUILD_ISOLATION=1 pip install -e '.[dev]' -v && pre-commit install

This disables pip's build isolation so pip can find your existing CUDA/TileLang installation.

After installing, verify with a test run:

python -m pytest tests/ -q  # run the test suite (requires a CUDA GPU)

🚀 Quick Start

import torch
from tileops.ops import GemmOp

# Define matrix dimensions: C = A @ B, where A is (M, K) and B is (K, N)
M, N, K = 1024, 1024, 512
dtype = torch.float16

# Instantiate the op
gemm = GemmOp(M, N, K, dtype=dtype)

# Generate inputs
A = torch.randn(M, K, device="cuda", dtype=dtype)
B = torch.randn(K, N, device="cuda", dtype=dtype)

# Run the operator
C = gemm(A, B)
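A natural next step is a correctness check against PyTorch's reference matmul. The sketch below shows the pattern with small CPU tensors and a plain matmul standing in for the op, so it runs anywhere; on a CUDA GPU you would compare C = gemm(A, B) from the snippet above against torch.matmul(A, B) the same way.

```python
import torch

# Hedged sketch of a correctness check (CPU stand-in for the GPU op).
M, N, K = 64, 64, 32
A = torch.randn(M, K)
B = torch.randn(K, N)
C = A @ B                                  # stand-in for C = gemm(A, B)
C_ref = (A.double() @ B.double()).float()  # higher-precision reference
torch.testing.assert_close(C, C_ref, rtol=1e-4, atol=1e-4)
```

For low-precision dtypes such as float16, widen the tolerances accordingly, since tiled accumulation orders differ between implementations.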

Documents

Hierarchical APIs

TileOPs is structured around two key concepts arranged hierarchically, each representing a distinct level of abstraction. Higher-level components are composed from, or delegate execution to, the level below.

  • Op: selects the implementation for a given shape and hardware, dispatching to the appropriate Kernel and providing unit tests and benchmarks. Ops are fully compatible with CUDA-Graph capture and torch.compile.
  • Kernel: TileLang-based kernels with hardware-specific optimizations.
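The two-level design can be sketched in plain Python. These are illustrative classes only, not the actual TileOPs API: an Op picks a Kernel for the given shape and hardware, then delegates execution to it.

```python
# Illustrative sketch only -- NOT the actual TileOPs classes.
class Kernel:
    """Lower level: a concrete, hardware-specific implementation."""
    def run(self, a, b):
        raise NotImplementedError

class NaiveGemmKernel(Kernel):
    """A trivial list-of-lists GEMM standing in for a TileLang kernel."""
    def run(self, a, b):
        return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
                for row in a]

class GemmOpSketch:
    """Higher level: dispatches to a Kernel and exposes a callable API."""
    def __init__(self, m, n, k):
        # Real dispatch would inspect shape and hardware; here we always
        # pick the naive kernel.
        self.kernel = NaiveGemmKernel()

    def __call__(self, a, b):
        return self.kernel.run(a, b)

op = GemmOpSketch(2, 2, 2)
result = op([[1, 2], [3, 4]], [[5, 6], [7, 8]])  # -> [[19, 22], [43, 50]]
```

In TileOPs itself, the Kernel level is written in TileLang with hardware-specific optimizations, while the Op level adds dispatch, testing, and benchmarking on top.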
