Inference Speed Optimization¶
Fast inference is critical for deploying DNA language models in real-world applications, from large-scale genomic screening to interactive analysis. This guide covers key techniques to accelerate model inference.
1. Use Half-Precision (FP16/BF16)¶
The Problem¶
Running inference in full 32-bit precision (FP32) is often unnecessarily slow and memory-intensive.
How to Optimize¶
Just like in training, using 16-bit floating-point numbers can provide a significant speedup for inference, especially on GPUs with Tensor Cores.
- FP16: Widely supported and gives a strong general-purpose speedup.
- BF16: Same memory savings with a wider dynamic range, so it is more numerically stable; requires Ampere or newer GPUs.
You can load the model directly in half-precision.
import torch
from dnallm.utils.load import load_model_and_tokenizer
model, tokenizer = load_model_and_tokenizer(
    "zhihan1996/DNABERT-2-117M",
    model_type="bert",
    torch_dtype=torch.float16,  # Use float16 for inference
    device_map="auto"           # Automatically move to GPU
)
# The model is now on the GPU in FP16 format
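If you prefer BF16, it helps to check hardware support before loading so you do not silently fall back to an unsupported dtype. Below is a minimal sketch: the model name and the load_model_and_tokenizer call mirror the example above, while the dtype-selection logic is an illustrative assumption rather than part of the library itself.

import torch
from dnallm.utils.load import load_model_and_tokenizer

# Prefer BF16 on GPUs that support it (Ampere and newer), otherwise fall back to FP16
dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)

model, tokenizer = load_model_and_tokenizer(
    "zhihan1996/DNABERT-2-117M",
    model_type="bert",
    torch_dtype=dtype,   # BF16 where available, FP16 otherwise
    device_map="auto"
)
print(f"Loaded model in {dtype} precision")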
2. Batching Inference Requests¶
The Problem¶
Processing sequences one at a time is highly inefficient: the fixed overhead of each forward pass (Python dispatch, kernel launches, host-to-device transfers) dominates the actual computation time, leaving the GPU mostly idle.
How to Optimize¶
Group multiple DNA sequences together and process them as a single batch. This allows the GPU to perform computations in parallel, dramatically increasing throughput.
import torch

dna_sequences = [
    "GATTACA" * 10,
    "ACGT" * 20,
    "TTTAAA" * 15
]

# Tokenize all sequences together with padding
inputs = tokenizer(
    dna_sequences,
    return_tensors="pt",
    padding=True,      # Pad to the length of the longest sequence in the batch
    truncation=True
).to(model.device)

with torch.no_grad():
    outputs = model(**inputs)
    embeddings = outputs.last_hidden_state

print("Processed batch of size:", len(dna_sequences))
print("Output shape:", embeddings.shape)
3. Compile the Model with torch.compile¶
The Problem¶
Standard PyTorch execution involves Python overhead that can slow down model execution.
How to Optimize¶
torch.compile() is a feature in PyTorch 2.0+ that JIT (Just-In-Time) compiles your model into optimized kernel code. It can provide significant speedups (1.3x-2x) with a single line of code.
# Before your inference loop, compile the model
compiled_model = torch.compile(model)

# Use the compiled model for inference
with torch.no_grad():
    outputs = compiled_model(**inputs)
Note: The first run after torch.compile() will be slow as the compilation happens. Subsequent runs will be much faster.
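Because of that one-time compilation cost, it is worth warming the compiled model up before benchmarking or serving traffic. The sketch below assumes the compiled_model and inputs from above; the warm-up and timing iteration counts are arbitrary illustrative values.

import time
import torch

# Warm-up: the first calls trigger JIT compilation and are not representative
with torch.no_grad():
    for _ in range(3):
        compiled_model(**inputs)

# Time steady-state inference
if torch.cuda.is_available():
    torch.cuda.synchronize()  # Finish pending GPU work before starting the timer
start = time.perf_counter()
with torch.no_grad():
    for _ in range(20):
        compiled_model(**inputs)
if torch.cuda.is_available():
    torch.cuda.synchronize()
elapsed = time.perf_counter() - start
print(f"Average latency per batch: {elapsed / 20 * 1000:.2f} ms")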