inference/mutagenesis API

dnallm.inference.mutagenesis

In Silico Mutagenesis Analysis Module.

This module provides tools for evaluating the impact of sequence mutations on model predictions, including single nucleotide polymorphisms (SNPs), deletions, insertions, and other sequence variations.

Classes

Mutagenesis

Mutagenesis(model, tokenizer, config)

Class for evaluating in silico mutagenesis.

This class provides methods to analyze how sequence mutations affect model predictions, including single base substitutions, deletions, and insertions. It can be used to identify important positions in DNA sequences and understand model interpretability.

Attributes:
    model: Fine-tuned model for prediction
    tokenizer: Tokenizer for the model
    config: Configuration object containing task settings and
        inference parameters
    sequences: Dictionary containing original and mutated sequences
    dataloader: DataLoader for batch processing of sequences

Initialize Mutagenesis class.

Parameters:

Name Type Description Default
model

Fine-tuned model for making predictions

required
tokenizer

Tokenizer for encoding DNA sequences

required
config dict

Configuration object containing task settings and inference parameters

required
Source code in dnallm/inference/mutagenesis.py
def __init__(self, model, tokenizer, config: dict):
    """Initialize Mutagenesis class.

    Args:
        model: Fine-tuned model for making predictions
        tokenizer: Tokenizer for encoding DNA sequences
        config: Configuration object containing task settings and
            inference parameters
    """

    self.model = model
    self.tokenizer = tokenizer
    self.config = config
    self.sequences = None
Functions
clm_evaluate
clm_evaluate()

Calculate sequence log-probability using causal language modeling.

This method computes the log-probability of each sequence under a causal language model by summing the log probabilities of each token given its preceding context.

Returns:

Type Description
list[float]

List of log-probabilities for each sequence
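
The shift-and-sum step described above can be illustrated without a real model. The following is a minimal NumPy sketch, not the library's implementation: `log_softmax` and `sequence_logprob` are hypothetical helper names, and the logits are toy values rather than model output.

```python
import numpy as np


def log_softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def sequence_logprob(logits: np.ndarray, input_ids: np.ndarray) -> float:
    """Sum log P(token_t | tokens_<t) for t = 1 .. L-1.

    logits: (L, V) array; row t scores the token at position t + 1.
    input_ids: (L,) array of integer token ids.
    """
    shift_logits = logits[:-1]    # predictions made at positions 0 .. L-2
    shift_labels = input_ids[1:]  # the tokens those predictions target
    log_probs = log_softmax(shift_logits)
    # Pick the log-probability of each true next token
    token_logps = log_probs[np.arange(len(shift_labels)), shift_labels]
    return float(token_logps.sum())


# Toy input: uniform logits over a 4-token vocabulary, 5 tokens long.
# Each of the 4 predicted tokens then has probability 1/4.
toy_logits = np.zeros((5, 4))
toy_ids = np.array([0, 1, 2, 3, 0])
toy_score = sequence_logprob(toy_logits, toy_ids)
```

The first token has no preceding context, so a sequence of L tokens contributes L-1 terms to the sum.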

Source code in dnallm/inference/mutagenesis.py
@torch.no_grad()
def clm_evaluate(self) -> list[float]:
    """Calculate sequence log-probability using causal language modeling.

    This method computes the log-probability of each sequence under a
    causal language model by summing the log probabilities of each token
    given its preceding context.

    Returns:
        List of log-probabilities for each sequence
    """
    all_logprobs = []
    model = self.model
    tokenizer = self.tokenizer
    for seq in tqdm(self.sequences["sequence"], desc="Inferring"):
        toks = tokenizer(
            seq, return_tensors="pt", add_special_tokens=True
        ).to(model.device)
        input_ids = toks["input_ids"]
        outputs = model(**toks)
        logits = outputs.logits  # (1, L, V)

        # shift for causal LM: predict token t given tokens < t
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
        token_logps = log_probs.gather(
            -1, shift_labels.unsqueeze(-1)
        ).squeeze(-1)  # (1, L-1)
        seq_logp = float(token_logps.sum().item())
        all_logprobs.append(seq_logp)
    return all_logprobs
evaluate
evaluate(strategy='last')

Evaluate the impact of mutations on model predictions.

This method runs predictions on all mutated sequences and compares them with the original sequence to calculate mutation effects.

Parameters:

Name Type Description Default
strategy str | int

Strategy for selecting the score from the log fold change:
- "first": Use the first log fold change
- "last": Use the last log fold change
- "sum": Use the sum of log fold changes
- "mean": Use the mean of log fold changes
- "max": Use the index of the maximum raw score to select the log fold change
- int: Use the log fold change at the specified index

'last'

Returns:

Type Description
dict

Dictionary containing predictions and metadata for all sequences:
  • 'raw': Original sequence predictions and metadata
  • mutation names: Individual mutation results with scores and log fold changes
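
To make the `strategy` options concrete, here is a small sketch of how a single score could be picked from a vector of log fold changes. `select_score` is a hypothetical helper written for illustration, not part of dnallm, and the input arrays are made-up values.

```python
import numpy as np


def select_score(logfc: np.ndarray, raw_score: np.ndarray, strategy="last") -> float:
    """Reduce a vector of log fold changes to one score (illustrative only)."""
    if strategy == "first":
        return float(logfc[0])
    if strategy == "last":
        return float(logfc[-1])
    if strategy == "sum":
        return float(np.sum(logfc))
    if strategy == "mean":
        return float(np.mean(logfc))
    if strategy == "max":
        # Use the class with the highest score on the original sequence
        return float(logfc[int(np.argmax(raw_score))])
    if isinstance(strategy, int):
        return float(logfc[strategy])
    raise ValueError(f"Unknown strategy: {strategy}")


# Made-up three-class example
raw_score = np.array([0.2, 0.7, 0.1])
logfc = np.array([0.5, -1.0, 0.25])

first = select_score(logfc, raw_score, "first")
last = select_score(logfc, raw_score, "last")
total = select_score(logfc, raw_score, "sum")
by_max = select_score(logfc, raw_score, "max")
by_index = select_score(logfc, raw_score, 0)
```

For single-output tasks the log fold change has one element, so "first", "last", "sum" and "mean" all give the same value; the choice matters mainly for multi-class and multi-label outputs.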
Source code in dnallm/inference/mutagenesis.py
def evaluate(self, strategy: str | int = "last") -> dict:
    """Evaluate the impact of mutations on model predictions.

    This method runs predictions on all mutated sequences and compares them
    with the original sequence to calculate mutation effects.

    Args:
        strategy: Strategy for selecting the score from the log fold change
            - "first": Use the first log fold change
            - "last": Use the last log fold change
            - "sum": Use the sum of log fold changes
            - "mean": Use the mean of log fold changes
            - "max": Use the index of the maximum raw score to select the
              log fold change
            - int: Use the log fold change at the specified index

    Returns:
        Dictionary containing predictions and metadata for all sequences:
        - 'raw': Original sequence predictions and metadata
        - mutation names: Individual mutation results with scores and log
            fold changes
    """
    # Load predictor
    inference_engine = self.get_inference_engine(
        self.model, self.tokenizer
    )
    # Special models list
    sp_models = ["evo1", "evo2", "megadna"]
    sp_model_found = False
    # Do prediction
    all_predictions = {}
    for sp_model in sp_models:
        if sp_model in str(self.model).lower():
            logits = inference_engine.scoring(
                self.dataloader,
                reduce_method=strategy if strategy == "sum" else "mean",
            )
            raw_pred = logits[0]["Score"]
            mut_preds = [logit["Score"] for logit in logits[1:]]
            self.config["task"].task_type = "generation"
            sp_model_found = True
            break
    if not sp_model_found:
        if self.config["task"].task_type == "mask":
            logits = self.mlm_evaluate()
        elif self.config["task"].task_type == "generation":
            logits = self.clm_evaluate()
        else:
            logits, _, _ = inference_engine.batch_infer(
                self.dataloader, do_pred=False
            )
        logits = logits[0] if isinstance(logits, tuple) else logits
        # Get the raw predictions
        raw_pred = (
            logits[0].numpy()
            if isinstance(logits, torch.Tensor)
            else logits[0]
        )
        # Get the mutated predictions
        mut_preds = (
            logits[1:].numpy()
            if isinstance(logits, torch.Tensor)
            else logits[1:]
        )
    for i, mut_pred in tqdm(
        enumerate(mut_preds), desc="Evaluating mutations"
    ):
        # Get the mutated name
        mut_name = self.sequences["name"][i + 1]
        # Get the mutated sequence
        mut_seq = self.sequences["sequence"][i + 1]
        # Compare the predictions
        raw_score, mut_score, logfc, diff = self.pred_comparison(
            raw_pred, mut_pred
        )
        # Store the results
        if "raw" not in all_predictions:
            all_predictions["raw"] = {
                "sequence": self.sequences["sequence"][0],
                "pred": raw_score,
                "logfc": np.zeros(len(raw_score)),
                "diff": np.zeros(len(raw_score)),
                "score": 0.0,
            }
        all_predictions[mut_name] = {
            "sequence": mut_seq,
            "pred": mut_score,
            "logfc": logfc,
            "diff": diff,
        }
        # Get final score
        if strategy == "first":
            score = logfc[0]
        elif strategy == "last":
            score = logfc[-1]
        elif strategy == "sum":
            score = np.sum(logfc)
        elif strategy == "mean":
            score = np.mean(logfc)
        elif strategy == "max":
            # np.argmax works on arrays and lists; ndarray has no .index()
            idx = int(np.argmax(raw_score))
            score = logfc[idx]
        elif isinstance(strategy, int):
            score = logfc[strategy]
        else:
            raise ValueError(f"Unknown strategy: {strategy}")
        all_predictions[mut_name]["score"] = score

    return all_predictions
get_inference_engine
get_inference_engine(model, tokenizer)

Create an inference engine object for the model.

Parameters:

Name Type Description Default
model

The model to be used for inference

required
tokenizer

The tokenizer to be used for encoding sequences

required

Returns:

Name Type Description
DNAInference DNAInference

The inference engine object configured with the given model and tokenizer

Source code in dnallm/inference/mutagenesis.py
def get_inference_engine(self, model, tokenizer) -> DNAInference:
    """Create an inference engine object for the model.

    Args:
        model: The model to be used for inference
        tokenizer: The tokenizer to be used for encoding sequences

    Returns:
        DNAInference: The inference engine object configured with the given
            model and tokenizer
    """

    inference_engine = DNAInference(
        model=model, tokenizer=tokenizer, config=self.config
    )

    return inference_engine
mlm_evaluate
mlm_evaluate()

Calculate pseudo-log-likelihood score using masked token prediction.

This method computes the pseudo-log-likelihood (PLL) score for each sequence by iteratively masking each token and predicting it using the model. The PLL score is the sum of the log probabilities of the true tokens given the masked context.

Returns:

Type Description
list[float]

List of pseudo-log-likelihood scores for each sequence
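
The mask-one-token-at-a-time loop can be sketched with a stand-in model. Everything below is an illustrative assumption: `toy_model` returns uniform logits instead of running a masked language model, and the token ids chosen for the mask and special tokens are hypothetical.

```python
import numpy as np

MASK_ID = 4              # hypothetical id of the mask token
SPECIAL_IDS = {4, 5, 6}  # hypothetical ids: mask, start and end tokens
VOCAB_SIZE = 7


def toy_model(input_ids: np.ndarray) -> np.ndarray:
    """Stand-in for a masked LM: uniform logits of shape (L, V)."""
    return np.zeros((len(input_ids), VOCAB_SIZE))


def pseudo_log_likelihood(input_ids, model, mask_id, special_ids) -> float:
    """Sum log P(true token | sequence with that token masked)."""
    total = 0.0
    for i, tok_id in enumerate(input_ids):
        if int(tok_id) in special_ids:
            continue  # special tokens are not scored
        masked = input_ids.copy()
        masked[i] = mask_id
        logits = model(masked)  # one forward pass per scored token
        row = logits[i] - logits[i].max()
        logp = row - np.log(np.exp(row).sum())
        total += float(logp[tok_id])
    return total


# Start token, four nucleotide tokens, end token
ids = np.array([5, 0, 1, 2, 3, 6])
pll = pseudo_log_likelihood(ids, toy_model, MASK_ID, SPECIAL_IDS)
```

Because each scored token needs its own forward pass, the cost grows with sequence length times the number of sequences, which is worth keeping in mind when the mutation set is large.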

Source code in dnallm/inference/mutagenesis.py
@torch.no_grad()
def mlm_evaluate(self) -> list[float]:
    """Calculate pseudo-log-likelihood score using masked token prediction.

    This method computes the pseudo-log-likelihood (PLL) score for each
    sequence by iteratively masking each token and predicting it using the
    model. The PLL score is the sum of the log probabilities of the true
    tokens given the masked context.

    Returns:
        List of pseudo-log-likelihood scores for each sequence
    """
    all_logprobs = []
    model = self.model
    tokenizer = self.tokenizer
    for seq in tqdm(self.sequences["sequence"], desc="Inferring"):
        toks = tokenizer(
            seq, return_tensors="pt", add_special_tokens=True
        ).to(model.device)
        input_ids = toks["input_ids"].clone()
        seq_len = input_ids.size(1)
        total = 0.0

        for i in range(seq_len):
            tok_id = input_ids[0, i].item()
            if tok_id in tokenizer.all_special_ids:
                continue
            masked = input_ids.clone()
            masked[0, i] = tokenizer.mask_token_id
            masked_inputs = {
                "input_ids": masked,
                "attention_mask": toks["attention_mask"],
            }
            outputs = model(**masked_inputs)
            logits = outputs.logits
            logp = torch.nn.functional.log_softmax(logits[0, i], dim=-1)
            total += float(logp[tok_id].item())
        all_logprobs.append(total)
    return all_logprobs
mutate_sequence
mutate_sequence(
    sequence,
    batch_size=1,
    replace_mut=True,
    include_n=False,
    delete_size=0,
    fill_gap=False,
    insert_seq=None,
    lowercase=False,
    do_encode=True,
)

Generate dataset from sequences with various mutation types.

This method creates mutated versions of the input sequence including:
- Single base substitutions (A, C, G, T, optionally N)
- Deletions of specified size
- Insertions of specified sequences
- Case transformations

Parameters:

Name Type Description Default
sequence

Single sequence for mutagenesis

required
batch_size int

Batch size for DataLoader; values of 1 or less fall back to the batch size in the inference config

1
replace_mut bool

Whether to perform single base substitutions

True
include_n bool

Whether to include N base in substitutions

False
delete_size int

Size of deletions to create (0 for no deletions)

0
fill_gap bool

Whether to fill deletion gaps with N bases

False
insert_seq str | None

Sequence to insert at various positions

None
lowercase bool

Whether to convert sequences to lowercase

False
do_encode bool

Whether to encode sequences for the model

True

Returns:

Type Description

None (modifies internal state)

Source code in dnallm/inference/mutagenesis.py
def mutate_sequence(
    self,
    sequence,
    batch_size: int = 1,
    replace_mut: bool = True,
    include_n: bool = False,
    delete_size: int = 0,
    fill_gap: bool = False,
    insert_seq: str | None = None,
    lowercase: bool = False,
    do_encode: bool = True,
):
    """Generate dataset from sequences with various mutation types.

    This method creates mutated versions of the input sequence including:
    - Single base substitutions (A, C, G, T, optionally N)
    - Deletions of specified size
    - Insertions of specified sequences
    - Case transformations

    Args:
        sequence: Single sequence for mutagenesis
        batch_size: Batch size for DataLoader
        replace_mut: Whether to perform single base substitutions
        include_n: Whether to include N base in substitutions
        delete_size: Size of deletions to create (0 for no deletions)
        fill_gap: Whether to fill deletion gaps with N bases
        insert_seq: Sequence to insert at various positions
        lowercase: Whether to convert sequences to lowercase
        do_encode: Whether to encode sequences for the model

    Returns:
        None (modifies internal state)
    """
    # Get the inference config
    pred_config = self.config["inference"]
    # Define the dataset
    sequences = {"name": ["raw"], "sequence": [sequence]}
    # Create mutated sequences
    if replace_mut:
        if include_n:
            base_map = ["A", "C", "G", "T", "N"]
        else:
            base_map = ["A", "C", "G", "T"]
        # Mutate sequence
        for i, base in enumerate(sequence):
            for mut_base in base_map:
                if base != mut_base:
                    name = f"mut_{i}_{base}_{mut_base}"
                    mutated_sequence = (
                        sequence[:i] + mut_base + sequence[i + 1 :]
                    )
                    sequences["name"].append(name)
                    sequences["sequence"].append(mutated_sequence)
    # Delete mutations
    if delete_size > 0:
        for i in range(len(sequence) - delete_size + 1):
            name = f"del_{i}_{delete_size}"
            if fill_gap:
                mutated_sequence = (
                    sequence[:i]
                    + "N" * delete_size
                    + sequence[i + delete_size :]
                )
            else:
                mutated_sequence = (
                    sequence[:i] + sequence[i + delete_size :]
                )
            sequences["name"].append(name)
            sequences["sequence"].append(mutated_sequence)
    # Insert mutations
    if insert_seq is not None:
        for i in range(len(sequence) + 1):
            name = f"ins_{i}_{insert_seq}"
            mutated_sequence = sequence[:i] + insert_seq + sequence[i:]
            sequences["name"].append(name)
            sequences["sequence"].append(mutated_sequence)
    # Lowercase sequences
    if lowercase:
        sequences["sequence"] = [
            seq.lower() for seq in sequences["sequence"]
        ]
    # Create dataset
    if len(sequences["sequence"]) > 0:
        ds = Dataset.from_dict(sequences)
        dataset = DNADataset(
            ds, self.tokenizer, max_length=pred_config.max_length
        )
        self.sequences = sequences
    # Encode sequences
    if do_encode:
        dataset.encode_sequences(remove_unused_columns=True)
    # Create DataLoader; a batch_size of 1 or less falls back to the
    # value from the inference config
    if batch_size <= 1:
        batch_size = pred_config.batch_size
    self.dataloader: DataLoader = DataLoader(
        dataset, batch_size=batch_size, num_workers=pred_config.num_workers
    )
plot
plot(preds, show_score=False, save_path=None)

Plot the mutagenesis analysis results.

This method generates visualizations of mutation effects, typically as heatmaps, bar charts and line plots showing how different mutations affect model predictions at various positions.

Parameters:

Name Type Description Default
preds dict

Dictionary containing model predicted scores and metadata

required
show_score bool

Whether to show the score values on the plot

False
save_path str | None

Path to save the plot. If None, the plot is shown interactively

None

Returns:

Type Description
None

Plot object
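
The internals of `plot_muts` are not shown on this page. As a rough illustration of how substitution results could be arranged for a heatmap, the sketch below builds a base-by-position matrix from a result dictionary keyed by names of the form `mut_{position}_{ref}_{alt}`. `substitution_matrix` and the toy `preds` values are assumptions made for this example only.

```python
import numpy as np

# Row index for each replacement base
ROW = {base: i for i, base in enumerate("ACGT")}


def substitution_matrix(preds: dict, seq_len: int) -> np.ndarray:
    """Return a (4, seq_len) array of scores; NaN where no variant exists."""
    matrix = np.full((len(ROW), seq_len), np.nan)
    for name, result in preds.items():
        parts = name.split("_")
        # Keep only substitution entries with an A/C/G/T replacement base
        if len(parts) != 4 or parts[0] != "mut" or parts[3] not in ROW:
            continue
        position = int(parts[1])
        matrix[ROW[parts[3]], position] = result["score"]
    return matrix


# Made-up results for a 2-base sequence "AC"
toy_preds = {
    "raw": {"score": 0.0},
    "mut_0_A_C": {"score": -0.5},
    "mut_1_C_T": {"score": 0.3},
    "del_0_1": {"score": -1.0},
}
heat = substitution_matrix(toy_preds, seq_len=2)
```

Cells left as NaN correspond to the reference base at each position or to variants that were not generated, which a plotting library would typically render as blank.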

Source code in dnallm/inference/mutagenesis.py
def plot(
    self,
    preds: dict,
    show_score: bool = False,
    save_path: str | None = None,
) -> None:
    """Plot the mutagenesis analysis results.

    This method generates visualizations of mutation effects, typically
    as heatmaps, bar charts and line plots showing how different mutations
    affect model predictions at various positions.

    Args:
        preds: Dictionary containing model predicted scores and metadata
        show_score: Whether to show the score values on the plot
        save_path: Path to save the plot. If None, the plot is shown
            interactively

    Returns:
        Plot object
    """
    if save_path:
        suffix = os.path.splitext(save_path)[-1]
        if suffix:
            outfile = save_path
        else:
            # No extension given: default to PDF by appending the suffix
            outfile = save_path + ".pdf"
    else:
        outfile = None
    # Plot heatmap
    pmut = plot_muts(preds, show_score=show_score, save_path=outfile)
    return pmut
pred_comparison
pred_comparison(raw_pred, mut_pred)

Compare raw and mutated predictions.

This method calculates the difference between predictions on the original sequence and mutated sequences, providing insights into mutation effects.

Parameters:

Name Type Description Default
raw_pred

Raw predictions from the original sequence

required
mut_pred

Predictions from the mutated sequence

required

Returns:

Type Description
tuple

Tuple containing (raw_score, mut_score, logfc, diff):
  • raw_score: Processed scores from original sequence
  • mut_score: Processed scores from mutated sequence
  • logfc: Log fold change (log2) between mutated and original scores
  • diff: Difference between mutated and original scores

Raises:

Type Description
ValueError

If task type is not supported
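
A worked example for the binary case: logits are passed through a sigmoid, then compared by log2 ratio and by plain difference. This is a self-contained NumPy sketch with a hand-written `sigmoid` and a hypothetical `compare_binary` helper; the logits are chosen so the probabilities come out as round numbers.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Logistic function, mapping logits to probabilities."""
    return 1.0 / (1.0 + np.exp(-x))


def compare_binary(raw_logits, mut_logits):
    """Return (raw_score, mut_score, logfc, diff) for binary-task logits."""
    raw_score = sigmoid(np.asarray(raw_logits, dtype=float))
    mut_score = sigmoid(np.asarray(mut_logits, dtype=float))
    logfc = np.log2(mut_score / raw_score)  # log2 fold change
    diff = mut_score - raw_score            # absolute change
    return raw_score, mut_score, logfc, diff


# Logit 0 gives probability 0.5; logit ln(3) gives probability 0.75
raw_score, mut_score, logfc, diff = compare_binary([0.0], [np.log(3.0)])
```

The log2 ratio is only defined when both scores are non-zero and share a sign. Probabilities always satisfy this, but regression outputs or log-likelihood scores may not behave as intuitively, so `diff` can be the safer quantity to inspect in those cases.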
Source code in dnallm/inference/mutagenesis.py
def pred_comparison(self, raw_pred, mut_pred):
    """Compare raw and mutated predictions.

    This method calculates the difference between predictions on the
    original sequence and mutated sequences, providing insights into
    mutation effects.

    Args:
        raw_pred: Raw predictions from the original sequence
        mut_pred: Predictions from the mutated sequence

    Returns:
        Tuple containing (raw_score, mut_score, logfc, diff):
        - raw_score: Processed scores from original sequence
        - mut_score: Processed scores from mutated sequence
        - logfc: Log fold change between mutated and original scores
        - diff: Difference between mutated and original scores

    Raises:
        ValueError: If task type is not supported
    """
    # Get the task config
    task_config = self.config["task"]
    # Get the predictions
    if task_config.task_type == "binary":
        raw_score = expit(raw_pred)
        mut_score = expit(mut_pred)
    elif task_config.task_type == "multiclass":
        raw_score = softmax(raw_pred)
        mut_score = softmax(mut_pred)
    elif task_config.task_type == "multilabel":
        raw_score = expit(raw_pred)
        mut_score = expit(mut_pred)
    elif task_config.task_type == "regression":
        raw_score = raw_pred
        mut_score = mut_pred
    elif task_config.task_type == "token":
        # NumPy's argmax takes `axis`, not `dim`
        raw_score = np.argmax(raw_pred, axis=-1)
        mut_score = np.argmax(mut_pred, axis=-1)
    elif task_config.task_type == "generation":
        raw_score = np.array([raw_pred])
        mut_score = np.array([mut_pred])
    elif task_config.task_type == "mask":
        raw_score = np.array([raw_pred])
        mut_score = np.array([mut_pred])
    else:
        raise ValueError(f"Unknown task type: {task_config.task_type}")

    logfc = np.log2(mut_score / raw_score)
    diff = mut_score - raw_score

    return raw_score, mut_score, logfc, diff

Functions