
inference/mutagenesis API

dnallm.inference.mutagenesis

In Silico Mutagenesis Analysis Module.

This module provides tools for evaluating the impact of sequence mutations on model predictions, including single nucleotide polymorphisms (SNPs), deletions, insertions, and other sequence variations.

Classes

Mutagenesis

Mutagenesis(model, tokenizer, config)

Class for evaluating in silico mutagenesis.

This class provides methods to analyze how sequence mutations affect model predictions, including single base substitutions, deletions, and insertions. It can be used to identify important positions in DNA sequences and understand model interpretability.

Attributes:
    model: Fine-tuned model for prediction
    tokenizer: Tokenizer for the model
    config: Configuration object containing task settings and
        inference parameters
    sequences: Dictionary containing original and mutated sequences
    dataloader: DataLoader for batch processing of sequences

Initialize Mutagenesis class.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Any | Fine-tuned model for making predictions | required |
| tokenizer | Any | Tokenizer for encoding DNA sequences | required |
| config | dict | Configuration object containing task settings and inference parameters | required |

Source code in dnallm/inference/mutagenesis.py
def __init__(self, model: Any, tokenizer: Any, config: dict):
    """Initialize Mutagenesis class.

    Args:
        model: Fine-tuned model for making predictions
        tokenizer: Tokenizer for encoding DNA sequences
        config: Configuration object containing task settings and
            inference parameters
    """

    self.model = model
    self.tokenizer = tokenizer
    self.config = config
    self.sequences = None
Functions
clm_evaluate
clm_evaluate(use_last=False)

Calculate sequence log-probability using causal language modeling.

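This method computes the log-probability of each sequence under a causal language model by summing the log probabilities of each token given its preceding context. If use_last is True, only the log-probability of the final token is returned for each sequence instead of the sum.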

Returns:

| Type | Description |
| --- | --- |
| list[float] | List of log-probabilities for each sequence |
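
The shift-and-sum computation can be illustrated without a real model. The following is a minimal sketch, assuming the per-position logits are already available as plain Python lists; `log_softmax` and `sequence_logprob` are hypothetical helper names used only for illustration and are not part of dnallm.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def sequence_logprob(logits, input_ids, use_last=False):
    """Sum log P(token_t | tokens_<t) for t = 1..L-1.

    logits[t] holds the scores produced after reading tokens 0..t,
    so it is paired with the label input_ids[t + 1].
    """
    token_logps = [
        log_softmax(logits[t])[input_ids[t + 1]]
        for t in range(len(input_ids) - 1)
    ]
    return token_logps[-1] if use_last else sum(token_logps)


# Toy example (assumed data): 2-token vocabulary, uniform logits
toy_logits = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
toy_ids = [0, 1, 1]
total = sequence_logprob(toy_logits, toy_ids)
last = sequence_logprob(toy_logits, toy_ids, use_last=True)
```

With uniform logits every token has probability 0.5, so the summed score covers two predicted tokens while the `use_last` variant covers only the final one.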

Source code in dnallm/inference/mutagenesis.py
@torch.no_grad()
def clm_evaluate(self, use_last: bool = False) -> list[float]:
    """Calculate sequence log-probability using causal language modeling.

    This method computes the log-probability of each sequence under a
    causal language model by summing the log probabilities of each token
    given its preceding context.

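    Args:
        use_last: If True, return only the log-probability of the last
            token instead of the sum over all tokens

    Returns:
        List of log-probabilities for each sequence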
    """
    all_logprobs = []
    model = self.model
    tokenizer = self.tokenizer
    for seq in tqdm(self.sequences["sequence"], desc="Inferring"):
        toks = tokenizer(
            seq, return_tensors="pt", add_special_tokens=True
        ).to(model.device)
        input_ids = toks["input_ids"]
        outputs = model(**toks)
        logits = outputs.logits  # (1, L, V)

        # shift for causal LM: predict token t given tokens < t
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
        token_logps = log_probs.gather(
            -1, shift_labels.unsqueeze(-1)
        ).squeeze(-1)  # (1, L-1)
        seq_logp = float(token_logps.sum().item())
        # Get last token logp
        if use_last:
            seq_logp = float(token_logps[0, -1].item())
        all_logprobs.append(seq_logp)
    return all_logprobs
evaluate
evaluate(strategy='last')

Evaluate the impact of mutations on model predictions.

This method runs predictions on all mutated sequences and compares them with the original sequence to calculate mutation effects.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| strategy | str \| int | Strategy for selecting the score from the log fold change (options listed below) | 'last' |

Options for strategy:
  • "first": Use the first log fold change
  • "last": Use the last log fold change
  • "sum": Use the sum of log fold changes
  • "mean": Use the mean of log fold changes
  • "max": Use the index of the maximum raw score to select the log fold change
  • int: Use the log fold change at the specified index

Returns:

| Type | Description |
| --- | --- |
| dict[str, dict] | Dictionary containing predictions and metadata for all sequences, keyed by name |

  • 'raw': Original sequence predictions and metadata
  • mutation names: Individual mutation results with scores and log fold changes

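
How a strategy reduces a vector of per-output values to one score can be sketched in plain Python. This is an illustrative sketch only: `select_score` is a hypothetical name, and the lists below are made-up values rather than real model output.

```python
def select_score(values, strategy="last", raw_score=None):
    """Reduce a list of per-output values to a single score."""
    if strategy == "first":
        return values[0]
    if strategy == "last":
        return values[-1]
    if strategy == "sum":
        return sum(values)
    if strategy == "mean":
        return sum(values) / len(values)
    if strategy == "max":
        # The position of the largest raw score selects the value to report
        ref = values if raw_score is None else raw_score
        return values[ref.index(max(ref))]
    if isinstance(strategy, int):
        return values[strategy]
    # Unrecognized strategies fall back to the mean
    return sum(values) / len(values)


logfc = [0.5, -1.0, 2.0]   # assumed log fold changes for three outputs
raw = [0.7, 0.2, 0.1]      # assumed scores of the original sequence
```

For example, `select_score(logfc, "max", raw)` picks index 0 because the original sequence scores highest there, while an integer strategy such as `1` indexes the list directly.
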
Source code in dnallm/inference/mutagenesis.py
def evaluate(self, strategy: str | int = "last") -> dict[str, dict]:
    """Evaluate the impact of mutations on model predictions.

    This method runs predictions on all mutated sequences and compares them
    with the original sequence to calculate mutation effects.

    Args:
        strategy: Strategy for selecting the score from the log fold\
            change:
            "first": Use the first log fold change
            "last": Use the last log fold change
            "sum": Use the sum of log fold changes
            "mean": Use the mean of log fold changes
            "max": Use the index of the maximum raw score to select\
                the log fold change
            int: Use the log fold change at the specified index

    Returns:
        Dictionary containing predictions and metadata for all sequences:
        - 'raw': Original sequence predictions and metadata
        - mutation names: Individual mutation results with scores and log
            fold changes
    """
    # Load predictor
    inference_engine = self.get_inference_engine(
        self.model, self.tokenizer
    )
    # Special models list
    sp_models = ["evo1", "evo2", "megadna"]
    sp_model_found = False
    # Do prediction
    all_predictions = {}
    for sp_model in sp_models:
        if sp_model in str(self.model).lower():
            logits = inference_engine.scoring(
                self.dataloader,
                reduce_method=strategy if strategy == "sum" else "mean",
            )
            raw_pred = logits[0]["Score"]
            mut_preds = [logit["Score"] for logit in logits[1:]]
            self.config["task"].task_type = "generation"
            sp_model_found = True
            break
    if not sp_model_found:
        if self.config["task"].task_type == "mask":
            logits = self.mlm_evaluate()
        elif self.config["task"].task_type == "generation":
            logits = self.clm_evaluate()
        else:
            logits, _, _ = inference_engine.batch_infer(
                self.dataloader, do_pred=False
            )
        logits = logits[0] if isinstance(logits, tuple) else logits
        # Get the raw predictions
        raw_pred = (
            logits[0].numpy()
            if isinstance(logits, torch.Tensor)
            else logits[0]
        )
        # Get the mutated predictions
        mut_preds = (
            logits[1:].numpy()
            if isinstance(logits, torch.Tensor)
            else logits[1:]
        )

    # Calculate scores
    def get_score(values: np.ndarray, raw_score: np.ndarray | None = None):
        # Reduce a per-output array to one scalar according to `strategy`
        if strategy == "first":
            score = values[0]
        elif strategy == "last":
            score = values[-1]
        elif strategy == "sum":
            score = np.sum(values)
        elif strategy == "mean":
            score = np.mean(values)
        elif strategy == "max":
            # NumPy arrays have no .index(); use argmax, and fall back to
            # `values` itself when no raw score is supplied
            ref = values if raw_score is None else raw_score
            idx = int(np.argmax(ref))
            score = values[idx]
        elif isinstance(strategy, int):
            score = values[strategy]
        else:
            score = np.mean(values)
        return score

    for i, mut_pred in tqdm(
        enumerate(mut_preds), desc="Evaluating mutations"
    ):
        # Get the mutated name
        mut_name = self.sequences["name"][i + 1]
        # Get the mutated sequence
        mut_seq = self.sequences["sequence"][i + 1]
        # Compare the predictions
        raw_score, mut_score, logfc, diff = self.pred_comparison(
            raw_pred, mut_pred
        )
        # Store the results
        if "raw" not in all_predictions:
            all_predictions["raw"] = {
                "sequence": self.sequences["sequence"][0],
                "pred": raw_score,
                "logfc": np.zeros(len(raw_score)),
                "diff": np.zeros(len(raw_score)),
                "score": 0.0,
                "logits": get_score(raw_score),
            }
        all_predictions[mut_name] = {
            "sequence": mut_seq,
            "pred": mut_score,
            "logfc": logfc,
            "diff": diff,
        }
        all_predictions[mut_name]["score"] = get_score(logfc, raw_score)
        all_predictions[mut_name]["score2"] = get_score(diff, raw_score)
        all_predictions[mut_name]["logits"] = get_score(
            mut_score, raw_score
        )

    return all_predictions
find_hotspots
find_hotspots(
    preds,
    strategy="maxabs",
    window_size=10,
    percentile_threshold=90.0,
)

Identify hotspot regions from base-level importance scores.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| preds | Dict[str, Dict] | The raw output from the ISM experiment. | required |
| strategy | str | Strategy to aggregate scores at each position. 'maxabs': Use the score of the mutation with the max absolute effect. 'mean': Use the mean of all mutation scores. | 'maxabs' |
| window_size | int | The size of the sliding window to find hotspots. | 10 |
| percentile_threshold | float | The percentile of window scores to be considered a hotspot. | 90.0 |

Returns:

| Type | Description |
| --- | --- |
| list[tuple[int, int]] | A list of (start, end) tuples for each hotspot. |
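
The smoothing, thresholding, and block-detection stages can be sketched with the standard library alone. This is a simplified sketch under stated assumptions: the helper names are hypothetical, the moving average assumes an odd window, and the later step that widens each block to the window size is omitted.

```python
def rolling_mean(scores, window):
    """Centered moving average that shrinks the window at the edges."""
    half = window // 2
    out = []
    for i in range(len(scores)):
        chunk = scores[max(0, i - half): i + half + 1]
        out.append(sum(chunk) / len(chunk))
    return out


def percentile(values, q):
    """Percentile with linear interpolation between closest ranks."""
    ordered = sorted(values)
    rank = (len(ordered) - 1) * q / 100.0
    lo = int(rank)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (rank - lo)


def contiguous_hotspots(scores, window=3, q=90.0):
    """Return (start, end) blocks, end exclusive, at or above the threshold."""
    smoothed = rolling_mean([abs(s) for s in scores], window)
    threshold = percentile(smoothed, q)
    blocks, start = [], -1
    for i, value in enumerate(smoothed):
        if value >= threshold and start == -1:
            start = i
        elif value < threshold and start != -1:
            blocks.append((start, i))
            start = -1
    if start != -1:
        blocks.append((start, len(smoothed)))
    return blocks


# Assumed per-base scores with one strong region in the middle
scores = [0.0, 0.0, 0.0, 0.0, 3.0, -3.0, 3.0, 0.0, 0.0, 0.0, 0.0]
blocks = contiguous_hotspots(scores, window=3, q=90.0)
```

Because absolute values are used, the negative score at position 5 contributes to the hotspot just as the positive ones do.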

Source code in dnallm/inference/mutagenesis.py
def find_hotspots(
    self,
    preds: dict[str, dict],
    strategy="maxabs",
    window_size: int = 10,
    percentile_threshold: float = 90.0,
) -> list[tuple[int, int]]:
    """
    Identify hotspot regions from base-level importance scores.

    Args:
        preds (Dict[str, Dict]): The raw output from the ISM experiment.
        strategy (str): Strategy to aggregate scores at each position.
                        'maxabs': Use the score of the mutation with
                                  the max absolute effect.
                        'mean': Use the mean of all mutation scores.
        window_size (int): The size of the sliding window to find hotspots.
        percentile_threshold (float): The percentile of window scores to be
                                    considered a hotspot.

    Returns:
        List[Tuple[int, int]]: A list of (start, end) tuples
                               for each hotspot.
    """
    # We care about the magnitude of change, so use absolute scores
    base_scores = self.process_ism_data(preds, strategy=strategy)
    abs_scores = pd.Series(np.abs(base_scores))

    # Calculate rolling average of scores
    rolling_mean = abs_scores.rolling(
        window=window_size, center=True, min_periods=1
    ).mean()

    # Determine the score threshold for a hotspot
    threshold = np.percentile(rolling_mean, percentile_threshold)

    # Find regions above the threshold
    hotspot_mask = rolling_mean >= threshold

    # Find contiguous blocks of 'True'
    hotspots = []
    start = -1
    for i, is_hot in enumerate(hotspot_mask):
        if is_hot and start == -1:
            start = i
        elif not is_hot and start != -1:
            hotspots.append((start, i))
            start = -1
    if start != -1:
        hotspots.append((start, len(hotspot_mask)))

    # Return list of hotspot regions with window size
    hotspots_regioned = []
    for i, (start, end) in enumerate(hotspots):
        mid = (start + end) // 2
        window_start = min(max(0, mid - window_size // 2), start)
        window_end = max(
            min(mid + window_size // 2, len(base_scores)), end
        )
        # if the window is within last detected hotspot,
        # skip the current one
        if i > 0:
            prev_start, prev_end = hotspots_regioned[-1]
            if window_start >= prev_start and window_end <= prev_end:
                continue
        hotspots_regioned.append((window_start, window_end))
    self.hotspots = hotspots_regioned

    return hotspots_regioned
get_inference_engine
get_inference_engine(model, tokenizer)

Create an inference engine object for the model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Any | The model to be used for inference | required |
| tokenizer | Any | The tokenizer to be used for encoding sequences | required |

Returns:

| Type | Description |
| --- | --- |
| DNAInference | The inference engine object configured with the given model and tokenizer |

Source code in dnallm/inference/mutagenesis.py
def get_inference_engine(self, model: Any, tokenizer: Any) -> DNAInference:
    """Create an inference engine object for the model.

    Args:
        model: The model to be used for inference
        tokenizer: The tokenizer to be used for encoding sequences

    Returns:
        DNAInference: The inference engine object configured with the given
            model and tokenizer
    """

    inference_engine = DNAInference(
        model=model, tokenizer=tokenizer, config=self.config
    )

    return inference_engine
mlm_evaluate
mlm_evaluate(return_sum=True)

Calculate pseudo-log-likelihood score using masked token prediction.

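This method computes the pseudo-log-likelihood (PLL) score for each sequence by iteratively masking each token and predicting it using the model. Special tokens are skipped. The PLL score is the sum of the log probabilities of the true tokens given the masked context. If return_sum is False, the method instead returns, for each sequence, a list of (token, log-probability) pairs rather than the summed score.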

Returns:

| Type | Description |
| --- | --- |
| list[float] | List of pseudo-log-likelihood scores for each sequence |
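
The mask-one-token-at-a-time loop can be illustrated with a toy scoring function in place of a real masked language model. Everything here is an assumption for illustration: `toy_log_probs` is a hypothetical stand-in that simply favors repeating the left neighbor, and the token names are made up.

```python
import math

MASK = "[MASK]"
SPECIAL = {"[CLS]", "[SEP]", MASK}


def toy_log_probs(masked_tokens, position):
    """Hypothetical stand-in for a masked LM: log P(base) at `position`."""
    left = masked_tokens[position - 1]
    if left in "ACGT" and len(left) == 1:
        # Favor repeating the base to the left of the masked position
        probs = {base: 0.1 for base in "ACGT"}
        probs[left] = 0.7
    else:
        probs = {base: 0.25 for base in "ACGT"}
    return {base: math.log(p) for base, p in probs.items()}


def pseudo_log_likelihood(tokens):
    """Mask each non-special token in turn and sum log P(true token)."""
    total = 0.0
    for i, tok in enumerate(tokens):
        if tok in SPECIAL:
            continue
        masked = list(tokens)
        masked[i] = MASK
        total += toy_log_probs(masked, i)[tok]
    return total


pll = pseudo_log_likelihood(["[CLS]", "A", "A", "C", "[SEP]"])
```

Each non-special token costs one forward pass, so the number of model calls grows linearly with sequence length for every sequence scored.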

Source code in dnallm/inference/mutagenesis.py
@torch.no_grad()
def mlm_evaluate(
    self, return_sum: bool = True
) -> list[float] | list[list[tuple[str, float]]]:
    """Calculate pseudo-log-likelihood score using masked token prediction.

    This method computes the pseudo-log-likelihood (PLL) score for each
    sequence by iteratively masking each token and predicting it using the
    model. The PLL score is the sum of the log probabilities of the true
    tokens given the masked context.

    Args:
        return_sum: If True, return the summed PLL per sequence; if False,
            return a list of (token, log-probability) pairs per sequence

    Returns:
        List of pseudo-log-likelihood scores for each sequence
    """
    all_logprobs = []
    model = self.model
    tokenizer = self.tokenizer
    for seq in tqdm(self.sequences["sequence"], desc="Inferring"):
        toks = tokenizer(
            seq, return_tensors="pt", add_special_tokens=True
        ).to(model.device)
        input_ids = toks["input_ids"].clone()
        seq_len = input_ids.size(1)
        total = 0.0
        p_values = []

        for i in range(seq_len):
            tok_id = input_ids[0, i].item()
            if tok_id in tokenizer.all_special_ids:
                continue
            masked = input_ids.clone()
            masked[0, i] = tokenizer.mask_token_id
            masked_inputs = {
                "input_ids": masked,
                # "attention_mask": toks["attention_mask"],
            }
            outputs = model(**masked_inputs)
            logits = outputs.logits
            logp = torch.nn.functional.log_softmax(logits[0, i], dim=-1)
            if return_sum:
                total += float(logp[tok_id].item())
            else:
                try:
                    token = tokenizer.decode([tok_id])[0]
                except KeyError:
                    token = tokenizer.convert_ids_to_tokens(tok_id)
                p_values.append((token, float(logp[tok_id].item())))
        if return_sum:
            all_logprobs.append(total)
        else:
            all_logprobs.append(p_values)
    return all_logprobs
mutate_sequence
mutate_sequence(
    sequence,
    batch_size=1,
    replace_mut=True,
    include_n=False,
    delete_size=0,
    cut_size=0,
    fill_gap=False,
    insert_seq=None,
    lowercase=False,
    do_encode=True,
)

Generate dataset from sequences with various mutation types.

This method creates mutated versions of the input sequence including:
  • Single base substitutions (A, C, G, T, optionally N)
  • Deletions of specified size
  • Insertions of specified sequences
  • Case transformations

Parameters:

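| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sequence | str | Single sequence for mutagenesis | required |
| batch_size | int | Batch size for DataLoader; values of 1 or less fall back to the batch size in the inference config | 1 |
| replace_mut | bool | Whether to perform single base substitutions | True |
| include_n | bool | Whether to include N base in substitutions | False |
| delete_size | int | Size of deletions to create (0 for no deletions) | 0 |
| cut_size | int | Step size for truncated variants (0 for none); positive values trim from the start of the sequence, negative values trim from the end | 0 |
| fill_gap | bool | Whether to fill deletion gaps with N bases | False |
| insert_seq | str \| None | Sequence to insert at various positions | None |
| lowercase | bool | Whether to convert sequences to lowercase | False |
| do_encode | bool | Whether to encode sequences for the model | True |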

Returns:

| Type | Description |
| --- | --- |
| None | None (modifies internal state) |
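
The naming scheme for the generated variants can be reproduced in a few lines. This is a standalone sketch that mirrors the substitution, deletion, and insertion naming shown in the source; `enumerate_mutations` is a hypothetical helper, and it skips the dataset and DataLoader construction.

```python
def enumerate_mutations(sequence, delete_size=0, insert_seq=None, include_n=False):
    """Return (name, mutated_sequence) pairs, starting with the original."""
    bases = ["A", "C", "G", "T"] + (["N"] if include_n else [])
    out = [("raw", sequence)]
    # Substitutions: every position, every alternative base
    for i, base in enumerate(sequence):
        for mut_base in bases:
            if mut_base != base:
                mutated = sequence[:i] + mut_base + sequence[i + 1:]
                out.append((f"mut_{i}_{base}_{mut_base}", mutated))
    # Deletions of a fixed size at every valid start position
    if delete_size > 0:
        for i in range(len(sequence) - delete_size + 1):
            mutated = sequence[:i] + sequence[i + delete_size:]
            out.append((f"del_{i}_{delete_size}", mutated))
    # Insertions before every base and after the last one
    if insert_seq is not None:
        for i in range(len(sequence) + 1):
            mutated = sequence[:i] + insert_seq + sequence[i:]
            out.append((f"ins_{i}_{insert_seq}", mutated))
    return out


variants = dict(enumerate_mutations("ACG", delete_size=1, insert_seq="T"))
```

A sequence of length L yields 3L substitutions without N, so the number of model evaluations grows quickly; the first entry is always the unmodified sequence under the name "raw".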

Source code in dnallm/inference/mutagenesis.py
def mutate_sequence(
    self,
    sequence: str,
    batch_size: int = 1,
    replace_mut: bool = True,
    include_n: bool = False,
    delete_size: int = 0,
    cut_size: int = 0,
    fill_gap: bool = False,
    insert_seq: str | None = None,
    lowercase: bool = False,
    do_encode: bool = True,
) -> None:
    """Generate dataset from sequences with various mutation types.

    This method creates mutated versions of the input sequence including:
    - Single base substitutions (A, C, G, T, optionally N)
    - Deletions of specified size
    - Insertions of specified sequences
    - Case transformations

    Args:
        sequence: Single sequence for mutagenesis
        batch_size: Batch size for DataLoader
        replace_mut: Whether to perform single base substitutions
        include_n: Whether to include N base in substitutions
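        delete_size: Size of deletions to create (0 for no deletions)
        cut_size: Step size for truncated variants (0 for none); positive
            values trim from the start, negative values from the end
        fill_gap: Whether to fill deletion gaps with N bases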
        insert_seq: Sequence to insert at various positions
        lowercase: Whether to convert sequences to lowercase
        do_encode: Whether to encode sequences for the model

    Returns:
        None (modifies internal state)
    """
    # Get the inference config
    pred_config = self.config["inference"]
    # Define the dataset
    sequences = {"name": ["raw"], "sequence": [sequence]}
    # Create mutated sequences
    if replace_mut:
        if include_n:
            base_map = ["A", "C", "G", "T", "N"]
        else:
            base_map = ["A", "C", "G", "T"]
        # Mutate sequence
        for i, base in enumerate(sequence):
            for mut_base in base_map:
                if base != mut_base:
                    name = f"mut_{i}_{base}_{mut_base}"
                    mutated_sequence = (
                        sequence[:i] + mut_base + sequence[i + 1 :]
                    )
                    sequences["name"].append(name)
                    sequences["sequence"].append(mutated_sequence)
    # Delete mutations
    if delete_size > 0:
        for i in range(len(sequence) - delete_size + 1):
            name = f"del_{i}_{delete_size}"
            if fill_gap:
                mutated_sequence = (
                    sequence[:i]
                    + "N" * delete_size
                    + sequence[i + delete_size :]
                )
            else:
                mutated_sequence = (
                    sequence[:i] + sequence[i + delete_size :]
                )
            sequences["name"].append(name)
            sequences["sequence"].append(mutated_sequence)
    # Insert mutations
    if insert_seq is not None:
        for i in range(len(sequence) + 1):
            name = f"ins_{i}_{insert_seq}"
            mutated_sequence = sequence[:i] + insert_seq + sequence[i:]
            sequences["name"].append(name)
            sequences["sequence"].append(mutated_sequence)
    # Cut mutations
    if cut_size != 0:
        step = abs(cut_size)
        for i in range(0, len(sequence) - step + 1, step):
            name = f"cut_{i}_{cut_size}"
            if cut_size > 0:
                mutated_sequence = sequence[i:]
            else:
                mutated_sequence = sequence[: len(sequence) - i]
            sequences["name"].append(name)
            sequences["sequence"].append(mutated_sequence)
    # Lowercase sequences
    if lowercase:
        sequences["sequence"] = [
            seq.lower() for seq in sequences["sequence"]
        ]
    # Create dataset
    if len(sequences["sequence"]) > 0:
        ds = Dataset.from_dict(sequences)
        dataset = DNADataset(
            ds, self.tokenizer, max_length=pred_config.max_length
        )
        self.sequences = sequences
    # Encode sequences
    if do_encode:
        dataset.encode_sequences(remove_unused_columns=True)
    # Create DataLoader
    if batch_size <= 1:
        batch_size = pred_config.batch_size
    self.dataloader: DataLoader = DataLoader(
        dataset, batch_size=batch_size, num_workers=pred_config.num_workers
    )
plot
plot(preds, show_score=False, save_path=None)

Plot the mutagenesis analysis results.

This method generates visualizations of mutation effects, typically as heatmaps, bar charts and line plots showing how different mutations affect model predictions at various positions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| preds | dict | Dictionary containing model predicted scores and metadata | required |
| show_score | bool | Whether to show the score values on the plot | False |
| save_path | str \| None | Path to save the plot. If None, plot will be shown interactively | None |

Returns:

| Type | Description |
| --- | --- |
| Any | Plot object returned by plot_muts |

Source code in dnallm/inference/mutagenesis.py
def plot(
    self,
    preds: dict,
    show_score: bool = False,
    save_path: str | None = None,
) -> Any:
    """Plot the mutagenesis analysis results.

    This method generates visualizations of mutation effects, typically
    as heatmaps, bar charts and line plots showing how different mutations
    affect model predictions at various positions.

    Args:
        preds: Dictionary containing model predicted scores and metadata
        show_score: Whether to show the score values on the plot
        save_path: Path to save the plot. If None,
            plot will be shown interactively

    Returns:
        Plot object
    """
    if save_path:
        suffix = os.path.splitext(save_path)[-1]
        if suffix:
            outfile = save_path
        else:
            # No extension given: append ".pdf" rather than joining it
            # as a separate path component
            outfile = save_path + ".pdf"
    else:
        outfile = None
    # Plot heatmap
    pmut = plot_muts(preds, show_score=show_score, save_path=outfile)
    return pmut
pred_comparison
pred_comparison(raw_pred, mut_pred)

Compare raw and mutated predictions.

This method calculates the difference between predictions on the original sequence and mutated sequences, providing insights into mutation effects.

Args:
    raw_pred: Raw predictions from the original sequence
    mut_pred: Predictions from the mutated sequence

Returns:
    Tuple containing (raw_score, mut_score, logfc, diff):
    - raw_score: Processed scores from original sequence
    - mut_score: Processed scores from mutated sequence
    - logfc: Log fold change between mutated and original scores
    - diff: Difference between mutated and original scores

Raises:
    ValueError: If task type is not supported
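
For a binary task, the comparison amounts to passing both logits through a sigmoid and then taking a log2 ratio and a difference. The following is a minimal scalar sketch under that assumption; `sigmoid` and `compare` are hypothetical names, and the input logits are made up.

```python
import math


def sigmoid(x):
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-x))


def compare(raw_logit, mut_logit):
    """Return (raw_score, mut_score, logfc, diff) for one output."""
    raw_score = sigmoid(raw_logit)
    mut_score = sigmoid(mut_logit)
    logfc = math.log2(mut_score / raw_score)   # log2 fold change
    diff = mut_score - raw_score               # absolute change
    return raw_score, mut_score, logfc, diff


# Assumed logits: 0.0 gives p = 0.5, ln(3) gives p = 0.75
raw_score, mut_score, logfc, diff = compare(0.0, math.log(3.0))
```

A positive log fold change means the mutation raised the predicted probability; the difference reports the same change on the probability scale.
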
Source code in dnallm/inference/mutagenesis.py
def pred_comparison(self, raw_pred, mut_pred):
    """Compare raw and mutated predictions.

    This method calculates the difference between predictions on the
    original sequence and mutated sequences, providing insights into
    mutation effects.

    Args:
        raw_pred: Raw predictions from the original sequence
        mut_pred: Predictions from the mutated sequence

    Returns:
        Tuple containing (raw_score, mut_score, logfc, diff):
        - raw_score: Processed scores from original sequence
        - mut_score: Processed scores from mutated sequence
        - logfc: Log fold change between mutated and original scores
        - diff: Difference between mutated and original scores

    Raises:
        ValueError: If task type is not supported
    """
    # Get the task config
    task_config = self.config["task"]
    # Get the predictions
    if task_config.task_type == "binary":
        raw_score = expit(raw_pred)
        mut_score = expit(mut_pred)
    elif task_config.task_type == "multiclass":
        raw_score = softmax(raw_pred)
        mut_score = softmax(mut_pred)
    elif task_config.task_type == "multilabel":
        raw_score = expit(raw_pred)
        mut_score = expit(mut_pred)
    elif task_config.task_type == "regression":
        raw_score = raw_pred
        mut_score = mut_pred
    elif task_config.task_type == "token":
        raw_score = np.argmax(raw_pred, axis=-1)
        mut_score = np.argmax(mut_pred, axis=-1)
    elif task_config.task_type == "generation":
        raw_score = np.array([raw_pred])
        mut_score = np.array([mut_pred])
    elif task_config.task_type == "mask":
        raw_score = np.array([raw_pred])
        mut_score = np.array([mut_pred])
    else:
        raise ValueError(f"Unknown task type: {task_config.task_type}")

    logfc = np.log2(mut_score / raw_score)
    diff = mut_score - raw_score

    return raw_score, mut_score, logfc, diff
prepare_tfmodisco_inputs
prepare_tfmodisco_inputs(ism_results_list)

Prepares inputs required for a TF-MoDISco run from a list of ISM results.
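
The three arrays can be illustrated on a single toy result using nested lists instead of NumPy arrays. This is a sketch under stated assumptions: `tfmodisco_arrays` is a hypothetical helper, and the `toy` dictionary is made-up data that follows the `mut_{position}_{ref}_{alt}` key format.

```python
ACGT = ["A", "C", "G", "T"]
IDX = {base: i for i, base in enumerate(ACGT)}


def tfmodisco_arrays(ism_results):
    """Build one-hot, hypothetical-score and contribution matrices (L x 4)."""
    seq = ism_results["raw"]["sequence"].upper()
    one_hot = [[1.0 if IDX.get(b) == j else 0.0 for j in range(4)] for b in seq]
    hyp = [[0.0] * 4 for _ in seq]
    for key, value in ism_results.items():
        if key.startswith("mut_"):
            parts = key.split("_")
            pos, mut_base = int(parts[1]), parts[-1]
            if mut_base in IDX:
                hyp[pos][IDX[mut_base]] = value["score"]
    # Element-wise product keeps only the score at the observed base
    contrib = [
        [h * o for h, o in zip(hyp_row, hot_row)]
        for hyp_row, hot_row in zip(hyp, one_hot)
    ]
    return one_hot, hyp, contrib


toy = {
    "raw": {"sequence": "AC"},
    "mut_0_A_C": {"score": -0.5},
    "mut_0_A_G": {"score": 0.2},
    "mut_1_C_T": {"score": 1.0},
}
one_hot, hyp, contrib = tfmodisco_arrays(toy)
```

In this sketch no substitution maps a base to itself, so the hypothetical score at the observed base stays 0 and the element-wise product is zero everywhere; whether that is the intended contribution score is worth checking before a TF-MoDISco run.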

Source code in dnallm/inference/mutagenesis.py
def prepare_tfmodisco_inputs(
    self, ism_results_list: list[dict[str, dict]]
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Prepares inputs required for a TF-MoDISco run from
    a list of ISM results.
    """
    print("Preparing inputs for TF-MoDISco...")
    acgt = ["A", "C", "G", "T"]
    acgt_to_idx = {base: i for i, base in enumerate(acgt)}

    all_one_hot, all_hyp_scores = [], []

    for ism_results in ism_results_list:
        raw_seq = ism_results["raw"]["sequence"].upper()
        seq_len = len(raw_seq)

        one_hot = np.zeros((seq_len, 4))
        for i, base in enumerate(raw_seq):
            idx = acgt_to_idx.get(base)
            if idx is not None:
                one_hot[i, idx] = 1
        all_one_hot.append(one_hot)

        hyp_scores = np.zeros((seq_len, 4))
        for key, value in ism_results.items():
            if key.startswith("mut_"):
                parts = key.split("_")
                pos, _, mut_base = int(parts[1]), parts[2], parts[-1]
                if mut_base in acgt_to_idx:
                    hyp_scores[pos, acgt_to_idx[mut_base]] = value["score"]
        all_hyp_scores.append(hyp_scores)

    one_hot_seqs = np.array(all_one_hot)
    hyp_scores = np.array(all_hyp_scores)
    contrib_scores = hyp_scores * one_hot_seqs
    return one_hot_seqs, hyp_scores, contrib_scores
process_ism_data
process_ism_data(ism_results, strategy='maxabs')

Process raw ISM result dictionary to get a single importance score per base.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| ism_results | Dict[str, Dict] | The raw output from the ISM experiment. | required |
| strategy | str | Strategy to aggregate scores at each position. 'maxabs': Use the score of the mutation with the max absolute effect. 'mean': Use the mean of all mutation scores. | 'maxabs' |

Returns:

| Type | Description |
| --- | --- |
| ndarray | A 1D array of importance scores, one per base pair. |
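
The per-position aggregation can be sketched with plain lists. `per_base_scores` is a hypothetical name and the `toy` dictionary is made-up data; only the 'maxabs' and 'mean' strategies described above are covered.

```python
def per_base_scores(ism_results, strategy="maxabs"):
    """Collapse all substitution scores at each position into one value."""
    seq_len = len(ism_results["raw"]["sequence"])
    by_pos = {}
    for key, value in ism_results.items():
        if key.startswith("mut_"):
            pos = int(key.split("_")[1])
            by_pos.setdefault(pos, []).append(value["score"])
    scores = [0.0] * seq_len
    for pos, vals in by_pos.items():
        if strategy == "maxabs":
            # Keep the signed score with the largest magnitude
            scores[pos] = max(vals, key=abs)
        elif strategy == "mean":
            scores[pos] = sum(vals) / len(vals)
        else:
            raise ValueError(f"Unknown strategy: {strategy}")
    return scores


toy = {
    "raw": {"sequence": "AC"},
    "mut_0_A_C": {"score": -0.5},
    "mut_0_A_G": {"score": 0.2},
    "mut_1_C_T": {"score": 1.0},
}
maxabs_scores = per_base_scores(toy)
mean_scores = per_base_scores(toy, strategy="mean")
```

Positions without any substitution entry keep a score of 0, and 'maxabs' preserves the sign of the strongest effect rather than returning its absolute value.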

Source code in dnallm/inference/mutagenesis.py
def process_ism_data(
    self,
    ism_results: dict[str, dict],
    strategy: str = "maxabs",
) -> np.ndarray:
    """
    Process raw ISM result dictionary to get a single importance score
    per base.

    Args:
        ism_results (Dict[str, Dict]): The raw output from
            the ISM experiment.
        strategy (str): Strategy to aggregate scores at each position.
                        'maxabs': Use the score of the mutation with
                                  the max absolute effect.
                        'mean': Use the mean of all mutation scores.

    Returns:
        np.ndarray: A 1D array of importance scores, one per base pair.
    """
    raw_seq = ism_results["raw"]["sequence"]
    seq_len = len(raw_seq)
    base_scores = np.zeros(seq_len)

    # Group mutations by position
    pos_muts = {}
    for key, value in ism_results.items():
        if key.startswith("mut_"):
            parts = key.split("_")
            pos = int(parts[1])
            if pos not in pos_muts:
                pos_muts[pos] = []
            pos_muts[pos].append(value["score"])

    # Apply aggregation strategy
    for pos, scores in pos_muts.items():
        if not scores:
            continue
        if strategy in ["maxabs", "min", "max"]:
            max_abs_idx = np.argmax(np.abs(scores))
            base_scores[pos] = scores[max_abs_idx]
        elif strategy == "mean":
            base_scores[pos] = np.mean(scores)
        else:
            raise ValueError(f"Unknown strategy: {strategy}")

    return base_scores

Functions