Torch compile caching for inference speed
Blog post from Replicate
Replicate now caches torch.compile artifacts to cut boot times for models that use PyTorch, so models such as black-forest-labs/flux-kontext-dev and prunaai/flux-schnell start 2-3 times faster. torch.compile makes inference faster by tracing and compiling the model's code the first time it is called. That first call pays a one-time compilation cost, and later calls run faster. Without a cache, every new container pays that cost again when it boots. By persisting the compiled artifacts across model container lifecycles, Replicate has reduced cold boot times by up to 62% for some models. The cache behaves much like caching in a CI/CD system: a starting container reuses cached artifacts when they are available, and a container that shuts down gracefully updates the cache when needed. More detailed guidance on using torch.compile is available in Replicate's documentation and in the official PyTorch tutorial.
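To make the CI/CD analogy concrete, here is a minimal, stdlib-only Python sketch of the restore-on-boot / save-on-graceful-shutdown cycle. It is not Replicate's implementation, and several parts are assumptions. The shared cache store is modeled as a local directory keyed by a hypothetical model-version string. The compiled artifacts are treated as opaque files in a per-container cache directory, which in a real setup might be the directory PyTorch Inductor writes to, for example one set with `TORCHINDUCTOR_CACHE_DIR`. `fake_compile` is a hypothetical stand-in for torch.compile's first-call compilation. The rule "upload only if new artifacts appeared" is one plausible reading of "updates the cache when needed".

```python
from __future__ import annotations

import shutil
import tempfile
from pathlib import Path


def list_files(cache_dir: Path) -> set[str]:
    """Relative paths of every file under cache_dir (empty set if it is missing)."""
    if not cache_dir.exists():
        return set()
    return {str(p.relative_to(cache_dir)) for p in cache_dir.rglob("*") if p.is_file()}


def restore_on_boot(remote: Path, local: Path) -> set[str]:
    """On a cache hit, copy stored artifacts into the container's local cache dir.
    Returns the files present locally afterwards, for comparison at shutdown."""
    if remote.exists():
        shutil.copytree(remote, local, dirs_exist_ok=True)
    return list_files(local)


def save_on_shutdown(remote: Path, local: Path, files_at_boot: set[str]) -> bool:
    """On graceful shutdown, upload the local cache only if new artifacts appeared."""
    if list_files(local) == files_at_boot:
        return False  # nothing new was compiled; keep the existing cache entry
    staging = remote.with_name(remote.name + ".staging")
    shutil.rmtree(staging, ignore_errors=True)
    shutil.copytree(local, staging)  # build the new entry next to the old one...
    shutil.rmtree(remote, ignore_errors=True)
    staging.rename(remote)  # ...then replace the old entry with it
    return True


def fake_compile(local: Path, graph: str) -> bool:
    """Hypothetical stand-in for torch.compile: writes an artifact only on a miss.
    Returns True if it had to 'compile' (slow path), False on a cache hit."""
    artifact = local / f"{graph}.bin"
    if artifact.exists():
        return False
    local.mkdir(parents=True, exist_ok=True)
    artifact.write_bytes(b"compiled")
    return True


def boot_container(store: Path, key: str, graph: str) -> tuple[bool, bool]:
    """Simulate one container lifecycle with a fresh local disk.
    Returns (compiled, uploaded)."""
    remote = store / key
    with tempfile.TemporaryDirectory() as tmp:
        local = Path(tmp) / "compile-cache"
        files_at_boot = restore_on_boot(remote, local)
        compiled = fake_compile(local, graph)
        uploaded = save_on_shutdown(remote, local, files_at_boot)
    return compiled, uploaded
```

Calling `boot_container` repeatedly with the same key shows the intended behavior. The first container compiles and uploads, and the next one restores the cache and skips both steps. A container that needs a graph the cache has not seen yet (for example a new input shape) compiles once more and adds it to the cache. Keying the store by model version keeps artifacts from one model build from being reused by another.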