
inference/predictor API

DNA Language Model Inference Module.

This module implements core model inference functionality, including:

  1. DNAPredictor class
  2. Model loading and initialization
  3. Batch sequence prediction
  4. Result post-processing
  5. Device management
  6. Half-precision inference support

Core features:

  1. Model state management
  2. Batch prediction
  3. Result merging
  4. Prediction result saving
  5. Memory optimization

Inference optimization:

  1. Batch parallelization
  2. GPU acceleration
  3. Half-precision computation
  4. Memory efficiency optimization
Example
predictor = DNAPredictor(
    model=model,
    tokenizer=tokenizer,
    config=config
)
results = predictor.predict_seqs(sequences)

DNAPredictor

DNA sequence predictor using fine-tuned models.

This class provides comprehensive functionality for making predictions using DNA language models. It handles model loading, inference, result processing, and various output formats including hidden states and attention weights for model interpretability.

Attributes:

    model: Fine-tuned model instance for inference
    tokenizer: Tokenizer for encoding DNA sequences
    task_config: Configuration object containing task settings
    pred_config: Configuration object containing inference parameters
    device: Device (CPU/GPU/MPS) for model inference
    sequences: List of input sequences
    labels: List of true labels (if available)
    embeddings: Dictionary containing hidden states and attention weights
Source code in dnallm/inference/predictor.py
class DNAPredictor:
    """DNA sequence predictor using fine-tuned models.

    This class provides comprehensive functionality for making predictions using DNA language models.
    It handles model loading, inference, result processing, and various output formats including
    hidden states and attention weights for model interpretability.

    Attributes:
        model: Fine-tuned model instance for inference
        tokenizer: Tokenizer for encoding DNA sequences
        task_config: Configuration object containing task settings
        pred_config: Configuration object containing inference parameters
        device: Device (CPU/GPU/MPS) for model inference
        sequences: List of input sequences
        labels: List of true labels (if available)
        embeddings: Dictionary containing hidden states and attention weights
    """

    def __init__(
        self,
        model: Any,
        tokenizer: Any,
        config: dict
    ):
        """Initialize the predictor.

        Args:
            model: Fine-tuned model instance for inference
            tokenizer: Tokenizer for encoding DNA sequences
            config: Configuration dictionary containing task settings and inference parameters
        """

        self.model = model
        self.tokenizer = tokenizer
        self.task_config = config['task']
        self.pred_config = config['inference']
        self.device = self._get_device()
        if model:
            self.model.to(self.device)
            print(f"Use device: {self.device}")
        self.sequences = []
        self.labels = []

    def _get_device(self) -> torch.device:
        """Get the appropriate device for model inference.

        This method automatically detects and selects the best available device for inference,
        supporting CPU, CUDA (NVIDIA), MPS (Apple Silicon), ROCm (AMD), TPU, and XPU (Intel).

        Returns:
            torch.device: The device to use for model inference

        Raises:
            ValueError: If the specified device type is not supported
        """
        # Get the device type
        device = self.pred_config.device.lower()
        if device == 'cpu':
            return torch.device('cpu')
        elif device in ['cuda', 'nvidia']:
            if not torch.cuda.is_available():
                warnings.warn("CUDA is not available. Please check your installation. Use CPU instead.")
                return torch.device('cpu')
            else:
                return torch.device('cuda')
        elif device in ['mps', 'apple', 'mac']:
            if not torch.backends.mps.is_available():
                warnings.warn("MPS is not available. Please check your installation. Use CPU instead.")
                return torch.device('cpu')
            else:
                return torch.device('mps')
        elif device in ['rocm', 'amd']:
            if not torch.cuda.is_available():
                warnings.warn("ROCm is not available. Please check your installation. Use CPU instead.")
                return torch.device('cpu')
            else:
                return torch.device('cuda')
        elif device in ['tpu', 'xla', 'google']:
            try:
                import torch_xla.core.xla_model as xm
                return torch.device('xla')
            except ImportError:
                warnings.warn("TPU is not available. Please check your installation. Use CPU instead.")
                return torch.device('cpu')
        elif device in ['xpu', 'intel']:
            if not torch.xpu.is_available():
                warnings.warn("XPU is not available. Please check your installation. Use CPU instead.")
                return torch.device('cpu')
            else:
                return torch.device('xpu')
        elif device == 'auto':
            if torch.cuda.is_available():
                return torch.device('cuda')
            elif torch.backends.mps.is_available():
                return torch.device('mps')
            elif torch.xpu.is_available():
                return torch.device('xpu')
            else:
                return torch.device('cpu')
        else:
            raise ValueError(f"Unsupported device type: {device}")

    def generate_dataset(self, seq_or_path: Union[str, List[str]], batch_size: int = 1,
                         seq_col: str = "sequence", label_col: str = "labels",
                         sep: str = None, fasta_sep: str = "|",
                         multi_label_sep: Union[str, None] = None,
                         uppercase: bool = False, lowercase: bool = False,
                         keep_seqs: bool = True, do_encode: bool = True) -> Tuple[DNADataset, DataLoader]:
        """Generate dataset from sequences or file path.

        This method creates a DNADataset and DataLoader from either a list of sequences
        or a file path, supporting various file formats and preprocessing options.

        Args:
            seq_or_path: Single sequence, list of sequences, or path to a file containing sequences
            batch_size: Batch size for DataLoader
            seq_col: Column name for sequences in the file
            label_col: Column name for labels in the file
            sep: Delimiter for CSV, TSV, or TXT files
            fasta_sep: Delimiter for FASTA files
            multi_label_sep: Delimiter for multi-label sequences
            uppercase: Whether to convert sequences to uppercase
            lowercase: Whether to convert sequences to lowercase
            keep_seqs: Whether to keep sequences in the dataset for later use
            do_encode: Whether to encode sequences for the model

        Returns:
            Tuple containing:
                - DNADataset: Dataset object with sequences and labels
                - DataLoader: DataLoader object for batch processing

        Raises:
            ValueError: If input is neither a file path nor a list of sequences
        """
        if isinstance(seq_or_path, str):
            suffix = seq_or_path.split(".")[-1]
            if suffix and os.path.isfile(seq_or_path):
                sequences = []
                dataset = DNADataset.load_local_data(seq_or_path, seq_col=seq_col, label_col=label_col,
                                                     sep=sep, fasta_sep=fasta_sep, multi_label_sep=multi_label_sep,
                                                     tokenizer=self.tokenizer, max_length=self.pred_config.max_length)
            else:
                sequences = [seq_or_path]
        elif isinstance(seq_or_path, list):
            sequences = seq_or_path
        else:
            raise ValueError("Input should be a file path or a list of sequences.")
        if len(sequences) > 0:
            ds = Dataset.from_dict({"sequence": sequences})
            dataset = DNADataset(ds, self.tokenizer, max_length=self.pred_config.max_length)
        # If labels are provided, keep labels
        if keep_seqs:
            self.sequences = dataset.dataset["sequence"]
        # Encode sequences
        if do_encode:
            task_type = self.task_config.task_type
            dataset.encode_sequences(remove_unused_columns=True, task=task_type, uppercase=uppercase, lowercase=lowercase)
        if "labels" in dataset.dataset.features:
            self.labels = dataset.dataset["labels"]
        # Create DataLoader
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=self.pred_config.num_workers
        )

        return dataset, dataloader

    def logits_to_preds(self, logits: torch.Tensor) -> Tuple[torch.Tensor, List]:
        """Convert model logits to predictions and human-readable labels.

        This method processes raw model outputs based on the task type to generate
        appropriate predictions and convert them to human-readable labels.

        Args:
            logits: Model output logits tensor

        Returns:
            Tuple containing:
                - torch.Tensor: Model predictions (probabilities or raw values)
                - List: Human-readable labels corresponding to predictions

        Raises:
            ValueError: If task type is not supported
        """
        # Get task type and threshold from config
        task_type = self.task_config.task_type
        threshold = self.task_config.threshold
        label_names = self.task_config.label_names
        # Convert logits to predictions based on task type
        if task_type == "binary":
            probs = torch.softmax(logits, dim=-1)
            preds = (probs[:, 1] > threshold).long()
            labels = [label_names[pred] for pred in preds]
        elif task_type == "multiclass":
            probs = torch.softmax(logits, dim=-1)
            preds = torch.argmax(probs, dim=-1)
            labels = [label_names[pred] for pred in preds]
        elif task_type == "multilabel":
            probs = torch.sigmoid(logits)
            preds = (probs > threshold).long()
            labels = []
            for pred in preds:
                label = [label_names[i] for i in range(len(pred)) if pred[i] == 1]
                labels.append(label)
        elif task_type == "regression":
            preds = logits.squeeze(-1)
            probs = preds
            labels = preds.tolist()
        elif task_type == "token":
            probs = torch.softmax(logits, dim=-1)
            preds = torch.argmax(logits, dim=-1)
            labels = []
            for pred in preds:
                label = [label_names[pred[i]] for i in range(len(pred))]
                labels.append(label)
        else:
            raise ValueError(f"Unsupported task type: {task_type}")
        return probs, labels

    def format_output(self, predictions: Tuple[torch.Tensor, List]) -> Dict:
        """Format output predictions into a structured dictionary.

        This method converts raw predictions into a user-friendly format with
        sequences, labels, and confidence scores.

        Args:
            predictions: Tuple containing (probabilities, labels)

        Returns:
            Dictionary containing formatted predictions with structure:
            {index: {'sequence': str, 'label': str/list, 'scores': dict/list}}
        """
        # Get task type from config
        task_type = self.task_config.task_type
        formatted_predictions = {}
        probs, labels = predictions
        probs = probs.numpy().tolist()
        keep_seqs = True if len(self.sequences) else False
        label_names = self.task_config.label_names
        for i, label in enumerate(labels):
            prob = probs[i]
            formatted_predictions[i] = {
                'sequence': self.sequences[i] if keep_seqs else '',
                'label': label,
                'scores': {label_names[j]: p for j, p in enumerate(prob)} if task_type != "token"
                          else [max(x) for x in prob],
            }
        return formatted_predictions

    @torch.no_grad()
    def batch_predict(self, dataloader: DataLoader, do_pred: bool = True,
                      output_hidden_states: bool = False,
                      output_attentions: bool = False) -> Tuple[torch.Tensor, Optional[Dict], Dict]:
        """Perform batch prediction on sequences.

        This method runs inference on batches of sequences and optionally extracts
        hidden states and attention weights for model interpretability.

        Args:
            dataloader: DataLoader object containing sequences for inference
            do_pred: Whether to convert logits to predictions
            output_hidden_states: Whether to output hidden states from all layers
            output_attentions: Whether to output attention weights from all layers

        Returns:
            Tuple containing:
                - torch.Tensor: All logits from the model
                - Optional[Dict]: Predictions dictionary if do_pred=True, otherwise None
                - Dict: Embeddings dictionary containing hidden states and/or attention weights

        Note:
            Setting output_hidden_states or output_attentions to True will consume
            significant memory, especially for long sequences or large models.
        """
        # Set model to evaluation mode
        self.model.eval()
        all_logits = []
        # Whether or not to output hidden states
        params = None
        embeddings = {}
        if output_hidden_states:
            import inspect
            sig = inspect.signature(self.model.forward)
            params = sig.parameters
            if 'output_hidden_states' in params:
                self.model.config.output_hidden_states = True
            embeddings['hidden_states'] = None
            embeddings['attention_mask'] = []
            embeddings['labels'] = []
        if output_attentions:
            if not params:
                import inspect
                sig = inspect.signature(self.model.forward)
                params = sig.parameters
            if 'output_attentions' in params:
                self.model.config.output_attentions = True
            embeddings['attentions'] = None
        # Iterate over batches
        for batch in tqdm(dataloader, desc="Predicting"):
            inputs = {k: v.to(self.device) for k, v in batch.items()}
            if self.pred_config.use_fp16:
                self.model = self.model.half()
                with torch.amp.autocast('cuda'):
                    outputs = self.model(**inputs)
            else:
                outputs = self.model(**inputs)
            # Get logits
            logits = outputs.logits.cpu().detach()
            all_logits.append(logits)
            # Get hidden states
            if output_hidden_states:
                hidden_states = [h.cpu().detach() for h in outputs.hidden_states] if hasattr(outputs, 'hidden_states') else None
                if embeddings['hidden_states'] is None:
                    embeddings['hidden_states'] = [[h] for h in hidden_states]
                else:
                    for i, h in enumerate(hidden_states):
                        embeddings['hidden_states'][i].append(h)
                attention_mask = inputs['attention_mask'].cpu().detach() if 'attention_mask' in inputs else None
                embeddings['attention_mask'].append(attention_mask)
                labels = inputs['labels'].cpu().detach() if 'labels' in inputs else None
                embeddings['labels'].append(labels)
            # Get attentions
            if output_attentions:
                attentions = [a.cpu().detach() for a in outputs.attentions] if hasattr(outputs, 'attentions') else None
                if attentions:
                    if embeddings['attentions'] is None:
                        embeddings['attentions'] = [[a] for a in attentions]
                    else:
                        for i, a in enumerate(attentions):
                            embeddings['attentions'][i].append(a)
        # Concatenate logits
        all_logits = torch.cat(all_logits, dim=0)
        if output_hidden_states:
            if embeddings['hidden_states']:
                embeddings['hidden_states'] = tuple(torch.cat(lst, dim=0) for lst in embeddings['hidden_states'])
            if embeddings['attention_mask']:
                embeddings['attention_mask'] = torch.cat(embeddings['attention_mask'], dim=0)
            if embeddings['labels']:
                embeddings['labels'] = torch.cat(embeddings['labels'], dim=0)
        if output_attentions:
            if embeddings['attentions']:
                embeddings['attentions'] = tuple(torch.cat(lst, dim=0) for lst in embeddings['attentions'])
        # Get predictions
        predictions = None
        if do_pred:
            predictions = self.logits_to_preds(all_logits)
            predictions = self.format_output(predictions)
        return all_logits, predictions, embeddings

    def predict_seqs(self, sequences: Union[str, List[str]],
                     evaluate: bool = False,
                     output_hidden_states: bool = False,
                     output_attentions: bool = False,
                     save_to_file: bool = False) -> Union[Dict, Tuple[Dict, Dict]]:
        """Predict for a list of sequences.

        This method provides a convenient interface for predicting on sequences,
        with optional evaluation and saving capabilities.

        Args:
            sequences: Single sequence or list of sequences for prediction
            evaluate: Whether to evaluate predictions against true labels
            output_hidden_states: Whether to output hidden states for visualization
            output_attentions: Whether to output attention weights for visualization
            save_to_file: Whether to save predictions to output directory

        Returns:
            Either:
                - Dict: Dictionary containing predictions
                - Tuple[Dict, Dict]: (predictions, metrics) if evaluate=True

        Note:
            Evaluation requires that labels are available in the dataset
        """
        # Get dataset and dataloader from sequences
        _, dataloader = self.generate_dataset(sequences, batch_size=self.pred_config.batch_size)
        # Do batch prediction
        logits, predictions, embeddings = self.batch_predict(dataloader,
                                                             output_hidden_states=output_hidden_states,
                                                             output_attentions=output_attentions)
        # Keep hidden states
        if output_hidden_states or output_attentions:
            self.embeddings = embeddings
        # Save predictions
        if save_to_file and self.pred_config.output_dir:
            save_predictions(predictions, Path(self.pred_config.output_dir))
        # Do evaluation
        if len(self.labels) == len(logits) and evaluate:
            metrics = self.calculate_metrics(logits, self.labels)
            metrics_save = dict(metrics)
            if 'curve' in metrics_save:
                del metrics_save['curve']
            if 'scatter' in metrics_save:
                del metrics_save['scatter']
            if save_to_file and self.pred_config.output_dir:
                save_metrics(metrics_save, Path(self.pred_config.output_dir))
            return predictions, metrics

        return predictions


    def predict_file(self, file_path: str, evaluate: bool = False,
                     output_hidden_states: bool = False,
                     output_attentions: bool = False,
                     seq_col: str = "sequence", label_col: str = "labels",
                     sep: str = None, fasta_sep: str = "|",
                     multi_label_sep: Union[str, None] = None,
                     uppercase: bool = False, lowercase: bool = False,
                     save_to_file: bool = False, plot_metrics: bool = False) -> Union[Dict, Tuple[Dict, Dict]]:
        """Predict from a file containing sequences.

        This method loads sequences from a file and performs prediction,
        with optional evaluation, visualization, and saving capabilities.

        Args:
            file_path: Path to the file containing sequences
            evaluate: Whether to evaluate predictions against true labels
            output_hidden_states: Whether to output hidden states for visualization
            output_attentions: Whether to output attention weights for visualization
            seq_col: Column name for sequences in the file
            label_col: Column name for labels in the file
            sep: Delimiter for CSV, TSV, or TXT files
            fasta_sep: Delimiter for FASTA files
            multi_label_sep: Delimiter for multi-label sequences
            uppercase: Whether to convert sequences to uppercase
            lowercase: Whether to convert sequences to lowercase
            save_to_file: Whether to save predictions and metrics to output directory
            plot_metrics: Whether to generate metric plots

        Returns:
            Either:
                - Dict: Dictionary containing predictions
                - Tuple[Dict, Dict]: (predictions, metrics) if evaluate=True

        Note:
            Setting output_attentions=True may consume significant memory
        """
        # Get dataset and dataloader from file
        _, dataloader = self.generate_dataset(file_path, seq_col=seq_col, label_col=label_col,
                                              sep=sep, fasta_sep=fasta_sep, multi_label_sep=multi_label_sep,
                                              uppercase=uppercase, lowercase=lowercase,
                                              batch_size=self.pred_config.batch_size)
        # Do batch prediction
        if output_attentions:
            warnings.warn("Cautions: output_attentions may consume a lot of memory.\n")
        logits, predictions, embeddings = self.batch_predict(dataloader,
                                                             output_hidden_states=output_hidden_states,
                                                             output_attentions=output_attentions)
        # Keep hidden states
        if output_hidden_states or output_attentions:
            self.embeddings = embeddings
        # Save predictions
        if save_to_file and self.pred_config.output_dir:
            save_predictions(predictions, Path(self.pred_config.output_dir))
        # Do evaluation
        if len(self.labels) == len(logits) and evaluate:
            metrics = self.calculate_metrics(logits, self.labels, plot=plot_metrics)
            metrics_save = dict(metrics)
            if 'curve' in metrics_save:
                del metrics_save['curve']
            if 'scatter' in metrics_save:
                del metrics_save['scatter']
            if save_to_file and self.pred_config.output_dir:
                save_metrics(metrics, Path(self.pred_config.output_dir))
            # Whether to plot metrics
            if plot_metrics:
                return predictions, metrics
            else:
                return predictions, metrics_save

        return predictions

    def calculate_metrics(self, logits: Union[List, torch.Tensor],
                          labels: Union[List, torch.Tensor], plot: bool = False) -> Dict:
        """Calculate evaluation metrics for model predictions.

        This method computes task-specific evaluation metrics using the configured
        metrics computation module.

        Args:
            logits: Model predictions (logits or probabilities)
            labels: True labels for evaluation
            plot: Whether to generate metric plots

        Returns:
            Dictionary containing evaluation metrics for the task
        """
        # Calculate metrics based on task type
        compute_metrics = Metrics(self.task_config, plot=plot)
        metrics = compute_metrics((logits, labels))

        return metrics

    def plot_attentions(self, seq_idx: int = 0, layer: int = -1, head: int = -1,
                        width: int = 800, height: int = 800,
                        save_path: Optional[str] = None) -> Optional[Any]:
        """Plot attention map visualization.

        This method creates a heatmap visualization of attention weights between tokens
        in a sequence, showing how the model attends to different parts of the input.

        Args:
            seq_idx: Index of the sequence to plot, default 0
            layer: Layer index to visualize, default -1 (last layer)
            head: Attention head index to visualize, default -1 (last head)
            width: Width of the plot
            height: Height of the plot
            save_path: Path to save the plot. If None, plot will be shown interactively

        Returns:
            Attention map visualization if available, otherwise None

        Note:
            This method requires that attention weights were collected during inference
            by setting output_attentions=True in prediction methods
        """
        if hasattr(self, 'embeddings'):
            attentions = self.embeddings['attentions']
            if save_path:
                suffix = os.path.splitext(save_path)[-1]
                if suffix:
                    heatmap = save_path.replace(suffix, "_heatmap" + suffix)
                else:
                    heatmap = os.path.join(save_path, "heatmap.pdf")
            else:
                heatmap = None
            # Plot attention map
            attn_map = plot_attention_map(attentions, self.sequences, self.tokenizer,
                                          seq_idx=seq_idx, layer=layer, head=head,
                                          width=width, height=height,
                                          save_path=heatmap)
            return attn_map
        else:
            print("No attention weights available to plot.")

    def plot_hidden_states(self, reducer: str = "t-SNE",
                           ncols: int = 4, width: int = 300, height: int = 300,
                           save_path: Optional[str] = None) -> Optional[Any]:
        """Visualize embeddings using dimensionality reduction.

        This method creates 2D visualizations of high-dimensional embeddings from
        different model layers using PCA, t-SNE, or UMAP dimensionality reduction.

        Args:
            reducer: Dimensionality reduction method to use ('PCA', 't-SNE', 'UMAP')
            ncols: Number of columns in the plot grid
            width: Width of each plot
            height: Height of each plot
            save_path: Path to save the plot. If None, plot will be shown interactively

        Returns:
            Embedding visualization if available, otherwise None

        Note:
            This method requires that hidden states were collected during inference
            by setting output_hidden_states=True in prediction methods
        """
        if hasattr(self, 'embeddings'):
            hidden_states = self.embeddings['hidden_states']
            attention_mask = torch.unsqueeze(self.embeddings['attention_mask'], dim=-1)
            labels = self.embeddings['labels']
            if save_path:
                suffix = os.path.splitext(save_path)[-1]
                if suffix:
                    embedding = save_path.replace(suffix, "_embedding" + suffix)
                else:
                    embedding = os.path.join(save_path, "embedding.pdf")
            else:
                embedding = None
            # Plot hidden states
            label_names = self.task_config.label_names
            embeddings_vis = plot_embeddings(hidden_states, attention_mask, reducer=reducer,
                                             labels=labels, label_names=label_names,
                                             ncols=ncols, width=width, height=height,
                                             save_path=embedding)
            return embeddings_vis
        else:
            print("No hidden states available to plot.")

__init__(model, tokenizer, config)

Initialize the predictor.

Parameters:

    model (Any): Fine-tuned model instance for inference. Required.
    tokenizer (Any): Tokenizer for encoding DNA sequences. Required.
    config (dict): Configuration dictionary containing task settings and inference parameters. Required.
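
A minimal sketch of the expected config layout, using only the field names this class reads; in practice the dictionary would normally come from the project's configuration loader rather than being built by hand:

from types import SimpleNamespace

config = {
    "task": SimpleNamespace(task_type="binary", threshold=0.5,
                            label_names=["negative", "positive"]),
    "inference": SimpleNamespace(device="auto", batch_size=16, max_length=512,
                                 num_workers=0, use_fp16=False, output_dir="./results"),
}
predictor = DNAPredictor(model=model, tokenizer=tokenizer, config=config)
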
Source code in dnallm/inference/predictor.py
def __init__(
    self,
    model: Any,
    tokenizer: Any,
    config: dict
):
    """Initialize the predictor.

    Args:
        model: Fine-tuned model instance for inference
        tokenizer: Tokenizer for encoding DNA sequences
        config: Configuration dictionary containing task settings and inference parameters
    """

    self.model = model
    self.tokenizer = tokenizer
    self.task_config = config['task']
    self.pred_config = config['inference']
    self.device = self._get_device()
    if model:
        self.model.to(self.device)
        print(f"Use device: {self.device}")
    self.sequences = []
    self.labels = []

batch_predict(dataloader, do_pred=True, output_hidden_states=False, output_attentions=False)

Perform batch prediction on sequences.

This method runs inference on batches of sequences and optionally extracts hidden states and attention weights for model interpretability.

Parameters:

    dataloader (DataLoader): DataLoader object containing sequences for inference. Required.
    do_pred (bool): Whether to convert logits to predictions. Default: True.
    output_hidden_states (bool): Whether to output hidden states from all layers. Default: False.
    output_attentions (bool): Whether to output attention weights from all layers. Default: False.

Returns:

    Tuple[Tensor, Optional[Dict], Dict]: A tuple containing:
        - torch.Tensor: All logits from the model
        - Optional[Dict]: Predictions dictionary if do_pred=True, otherwise None
        - Dict: Embeddings dictionary containing hidden states and/or attention weights

Note

Setting output_hidden_states or output_attentions to True will consume significant memory, especially for long sequences or large models.
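
A rough usage sketch, assuming a predictor and a list of sequences are already available:

_, dataloader = predictor.generate_dataset(sequences, batch_size=16)
logits, predictions, embeddings = predictor.batch_predict(
    dataloader,
    output_hidden_states=True,   # memory heavy, see note above
    output_attentions=False,
)
print(logits.shape)              # (num_sequences, num_labels)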

Source code in dnallm/inference/predictor.py
@torch.no_grad()
def batch_predict(self, dataloader: DataLoader, do_pred: bool = True,
                  output_hidden_states: bool = False,
                  output_attentions: bool = False) -> Tuple[torch.Tensor, Optional[Dict], Dict]:
    """Perform batch prediction on sequences.

    This method runs inference on batches of sequences and optionally extracts
    hidden states and attention weights for model interpretability.

    Args:
        dataloader: DataLoader object containing sequences for inference
        do_pred: Whether to convert logits to predictions
        output_hidden_states: Whether to output hidden states from all layers
        output_attentions: Whether to output attention weights from all layers

    Returns:
        Tuple containing:
            - torch.Tensor: All logits from the model
            - Optional[Dict]: Predictions dictionary if do_pred=True, otherwise None
            - Dict: Embeddings dictionary containing hidden states and/or attention weights

    Note:
        Setting output_hidden_states or output_attentions to True will consume
        significant memory, especially for long sequences or large models.
    """
    # Set model to evaluation mode
    self.model.eval()
    all_logits = []
    # Whether or not to output hidden states
    params = None
    embeddings = {}
    if output_hidden_states:
        import inspect
        sig = inspect.signature(self.model.forward)
        params = sig.parameters
        if 'output_hidden_states' in params:
            self.model.config.output_hidden_states = True
        embeddings['hidden_states'] = None
        embeddings['attention_mask'] = []
        embeddings['labels'] = []
    if output_attentions:
        if not params:
            import inspect
            sig = inspect.signature(self.model.forward)
            params = sig.parameters
        if 'output_attentions' in params:
            self.model.config.output_attentions = True
        embeddings['attentions'] = None
    # Iterate over batches
    for batch in tqdm(dataloader, desc="Predicting"):
        inputs = {k: v.to(self.device) for k, v in batch.items()}
        if self.pred_config.use_fp16:
            self.model = self.model.half()
            with torch.amp.autocast('cuda'):
                outputs = self.model(**inputs)
        else:
            outputs = self.model(**inputs)
        # Get logits
        logits = outputs.logits.cpu().detach()
        all_logits.append(logits)
        # Get hidden states
        if output_hidden_states:
            hidden_states = [h.cpu().detach() for h in outputs.hidden_states] if hasattr(outputs, 'hidden_states') else None
            if embeddings['hidden_states'] is None:
                embeddings['hidden_states'] = [[h] for h in hidden_states]
            else:
                for i, h in enumerate(hidden_states):
                    embeddings['hidden_states'][i].append(h)
            attention_mask = inputs['attention_mask'].cpu().detach() if 'attention_mask' in inputs else None
            embeddings['attention_mask'].append(attention_mask)
            labels = inputs['labels'].cpu().detach() if 'labels' in inputs else None
            embeddings['labels'].append(labels)
        # Get attentions
        if output_attentions:
            attentions = [a.cpu().detach() for a in outputs.attentions] if hasattr(outputs, 'attentions') else None
            if attentions:
                if embeddings['attentions'] is None:
                    embeddings['attentions'] = [[a] for a in attentions]
                else:
                    for i, a in enumerate(attentions):
                        embeddings['attentions'][i].append(a)
    # Concatenate logits
    all_logits = torch.cat(all_logits, dim=0)
    if output_hidden_states:
        if embeddings['hidden_states']:
            embeddings['hidden_states'] = tuple(torch.cat(lst, dim=0) for lst in embeddings['hidden_states'])
        if embeddings['attention_mask']:
            embeddings['attention_mask'] = torch.cat(embeddings['attention_mask'], dim=0)
        if embeddings['labels']:
            embeddings['labels'] = torch.cat(embeddings['labels'], dim=0)
    if output_attentions:
        if embeddings['attentions']:
            embeddings['attentions'] = tuple(torch.cat(lst, dim=0) for lst in embeddings['attentions'])
    # Get predictions
    predictions = None
    if do_pred:
        predictions = self.logits_to_preds(all_logits)
        predictions = self.format_output(predictions)
    return all_logits, predictions, embeddings

calculate_metrics(logits, labels, plot=False)

Calculate evaluation metrics for model predictions.

This method computes task-specific evaluation metrics using the configured metrics computation module.

Parameters:

    logits (Union[List, Tensor]): Model predictions (logits or probabilities). Required.
    labels (Union[List, Tensor]): True labels for evaluation. Required.
    plot (bool): Whether to generate metric plots. Default: False.

Returns:

    Dict: Dictionary containing evaluation metrics for the task
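
A short sketch, assuming a dataloader was built with generate_dataset so that true labels were kept on the predictor:

logits, _, _ = predictor.batch_predict(dataloader, do_pred=False)
metrics = predictor.calculate_metrics(logits, predictor.labels, plot=False)
print(metrics)   # e.g. accuracy, F1, AUROC, depending on the task type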

Source code in dnallm/inference/predictor.py
def calculate_metrics(self, logits: Union[List, torch.Tensor],
                      labels: Union[List, torch.Tensor], plot: bool = False) -> Dict:
    """Calculate evaluation metrics for model predictions.

    This method computes task-specific evaluation metrics using the configured
    metrics computation module.

    Args:
        logits: Model predictions (logits or probabilities)
        labels: True labels for evaluation
        plot: Whether to generate metric plots

    Returns:
        Dictionary containing evaluation metrics for the task
    """
    # Calculate metrics based on task type
    compute_metrics = Metrics(self.task_config, plot=plot)
    metrics = compute_metrics((logits, labels))

    return metrics

format_output(predictions)

Format output predictions into a structured dictionary.

This method converts raw predictions into a user-friendly format with sequences, labels, and confidence scores.

Parameters:

    predictions (Tuple[Tensor, List]): Tuple containing (probabilities, labels). Required.

Returns:

    Dict: Dictionary containing formatted predictions with structure:
        {index: {'sequence': str, 'label': str/list, 'scores': dict/list}}
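
A brief sketch of how this pairs with logits_to_preds, assuming logits were collected with batch_predict; the score values in the comment are illustrative only:

probs_and_labels = predictor.logits_to_preds(logits)
results = predictor.format_output(probs_and_labels)
# results[0] -> {'sequence': 'ATCG...', 'label': 'positive',
#                'scores': {'negative': 0.08, 'positive': 0.92}}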

Source code in dnallm/inference/predictor.py
def format_output(self, predictions: Tuple[torch.Tensor, List]) -> Dict:
    """Format output predictions into a structured dictionary.

    This method converts raw predictions into a user-friendly format with
    sequences, labels, and confidence scores.

    Args:
        predictions: Tuple containing (probabilities, labels)

    Returns:
        Dictionary containing formatted predictions with structure:
        {index: {'sequence': str, 'label': str/list, 'scores': dict/list}}
    """
    # Get task type from config
    task_type = self.task_config.task_type
    formatted_predictions = {}
    probs, labels = predictions
    probs = probs.numpy().tolist()
    keep_seqs = True if len(self.sequences) else False
    label_names = self.task_config.label_names
    for i, label in enumerate(labels):
        prob = probs[i]
        formatted_predictions[i] = {
            'sequence': self.sequences[i] if keep_seqs else '',
            'label': label,
            'scores': {label_names[j]: p for j, p in enumerate(prob)} if task_type != "token"
                      else [max(x) for x in prob],
        }
    return formatted_predictions

generate_dataset(seq_or_path, batch_size=1, seq_col='sequence', label_col='labels', sep=None, fasta_sep='|', multi_label_sep=None, uppercase=False, lowercase=False, keep_seqs=True, do_encode=True)

Generate dataset from sequences or file path.

This method creates a DNADataset and DataLoader from either a list of sequences or a file path, supporting various file formats and preprocessing options.

Parameters:

    seq_or_path (Union[str, List[str]]): Single sequence, list of sequences, or path to a file containing sequences. Required.
    batch_size (int): Batch size for DataLoader. Default: 1.
    seq_col (str): Column name for sequences in the file. Default: 'sequence'.
    label_col (str): Column name for labels in the file. Default: 'labels'.
    sep (str): Delimiter for CSV, TSV, or TXT files. Default: None.
    fasta_sep (str): Delimiter for FASTA files. Default: '|'.
    multi_label_sep (Union[str, None]): Delimiter for multi-label sequences. Default: None.
    uppercase (bool): Whether to convert sequences to uppercase. Default: False.
    lowercase (bool): Whether to convert sequences to lowercase. Default: False.
    keep_seqs (bool): Whether to keep sequences in the dataset for later use. Default: True.
    do_encode (bool): Whether to encode sequences for the model. Default: True.

Returns:

    Tuple[DNADataset, DataLoader]: Tuple containing:
        - DNADataset: Dataset object with sequences and labels
        - DataLoader: DataLoader object for batch processing

Raises:

    ValueError: If input is neither a file path nor a list of sequences
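
A usage sketch with a hypothetical FASTA file path:

dataset, dataloader = predictor.generate_dataset(
    "promoters.fa",     # single sequence, list of sequences, or a file path
    batch_size=16,
    uppercase=True,     # normalize soft-masked lowercase bases
)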

Source code in dnallm/inference/predictor.py
def generate_dataset(self, seq_or_path: Union[str, List[str]], batch_size: int = 1,
                     seq_col: str = "sequence", label_col: str = "labels",
                     sep: str = None, fasta_sep: str = "|",
                     multi_label_sep: Union[str, None] = None,
                     uppercase: bool = False, lowercase: bool = False,
                     keep_seqs: bool = True, do_encode: bool = True) -> Tuple[DNADataset, DataLoader]:
    """Generate dataset from sequences or file path.

    This method creates a DNADataset and DataLoader from either a list of sequences
    or a file path, supporting various file formats and preprocessing options.

    Args:
        seq_or_path: Single sequence, list of sequences, or path to a file containing sequences
        batch_size: Batch size for DataLoader
        seq_col: Column name for sequences in the file
        label_col: Column name for labels in the file
        sep: Delimiter for CSV, TSV, or TXT files
        fasta_sep: Delimiter for FASTA files
        multi_label_sep: Delimiter for multi-label sequences
        uppercase: Whether to convert sequences to uppercase
        lowercase: Whether to convert sequences to lowercase
        keep_seqs: Whether to keep sequences in the dataset for later use
        do_encode: Whether to encode sequences for the model

    Returns:
        Tuple containing:
            - DNADataset: Dataset object with sequences and labels
            - DataLoader: DataLoader object for batch processing

    Raises:
        ValueError: If input is neither a file path nor a list of sequences
    """
    if isinstance(seq_or_path, str):
        suffix = seq_or_path.split(".")[-1]
        if suffix and os.path.isfile(seq_or_path):
            sequences = []
            dataset = DNADataset.load_local_data(seq_or_path, seq_col=seq_col, label_col=label_col,
                                                 sep=sep, fasta_sep=fasta_sep, multi_label_sep=multi_label_sep,
                                                 tokenizer=self.tokenizer, max_length=self.pred_config.max_length)
        else:
            sequences = [seq_or_path]
    elif isinstance(seq_or_path, list):
        sequences = seq_or_path
    else:
        raise ValueError("Input should be a file path or a list of sequences.")
    if len(sequences) > 0:
        ds = Dataset.from_dict({"sequence": sequences})
        dataset = DNADataset(ds, self.tokenizer, max_length=self.pred_config.max_length)
    # If labels are provided, keep labels
    if keep_seqs:
        self.sequences = dataset.dataset["sequence"]
    # Encode sequences
    if do_encode:
        task_type = self.task_config.task_type
        dataset.encode_sequences(remove_unused_columns=True, task=task_type, uppercase=uppercase, lowercase=lowercase)
    if "labels" in dataset.dataset.features:
        self.labels = dataset.dataset["labels"]
    # Create DataLoader
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=self.pred_config.num_workers
    )

    return dataset, dataloader

logits_to_preds(logits)

Convert model logits to predictions and human-readable labels.

This method processes raw model outputs based on the task type to generate appropriate predictions and convert them to human-readable labels.

Parameters:

    logits (Tensor): Model output logits tensor. Required.

Returns:

    Tuple[Tensor, List]: Tuple containing:
        - torch.Tensor: Model predictions (probabilities or raw values)
        - List: Human-readable labels corresponding to predictions

Raises:

    ValueError: If task type is not supported
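
A worked sketch for a binary task with threshold 0.5 and label_names = ['negative', 'positive']:

import torch

logits = torch.tensor([[0.2, 2.1], [1.5, -0.3]])
probs = torch.softmax(logits, dim=-1)    # ~[[0.13, 0.87], [0.86, 0.14]]
preds = (probs[:, 1] > 0.5).long()       # tensor([1, 0])
# logits_to_preds(logits) would therefore return (probs, ['positive', 'negative'])
# for this configuration.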

Source code in dnallm/inference/predictor.py
def logits_to_preds(self, logits: torch.Tensor) -> Tuple[torch.Tensor, List]:
    """Convert model logits to predictions and human-readable labels.

    This method processes raw model outputs based on the task type to generate
    appropriate predictions and convert them to human-readable labels.

    Args:
        logits: Model output logits tensor

    Returns:
        Tuple containing:
            - torch.Tensor: Model predictions (probabilities or raw values)
            - List: Human-readable labels corresponding to predictions

    Raises:
        ValueError: If task type is not supported
    """
    # Get task type and threshold from config
    task_type = self.task_config.task_type
    threshold = self.task_config.threshold
    label_names = self.task_config.label_names
    # Convert logits to predictions based on task type
    if task_type == "binary":
        probs = torch.softmax(logits, dim=-1)
        preds = (probs[:, 1] > threshold).long()
        labels = [label_names[pred] for pred in preds]
    elif task_type == "multiclass":
        probs = torch.softmax(logits, dim=-1)
        preds = torch.argmax(probs, dim=-1)
        labels = [label_names[pred] for pred in preds]
    elif task_type == "multilabel":
        probs = torch.sigmoid(logits)
        preds = (probs > threshold).long()
        labels = []
        for pred in preds:
            label = [label_names[i] for i in range(len(pred)) if pred[i] == 1]
            labels.append(label)
    elif task_type == "regression":
        preds = logits.squeeze(-1)
        probs = preds
        labels = preds.tolist()
    elif task_type == "token":
        probs = torch.softmax(logits, dim=-1)
        preds = torch.argmax(logits, dim=-1)
        labels = []
        for pred in preds:
            label = [label_names[pred[i]] for i in range(len(pred))]
            labels.append(label)
    else:
        raise ValueError(f"Unsupported task type: {task_type}")
    return probs, labels

plot_attentions(seq_idx=0, layer=-1, head=-1, width=800, height=800, save_path=None)

Plot attention map visualization.

This method creates a heatmap visualization of attention weights between tokens in a sequence, showing how the model attends to different parts of the input.

Parameters:

    seq_idx (int): Index of the sequence to plot. Default: 0.
    layer (int): Layer index to visualize. Default: -1 (last layer).
    head (int): Attention head index to visualize. Default: -1 (last head).
    width (int): Width of the plot. Default: 800.
    height (int): Height of the plot. Default: 800.
    save_path (Optional[str]): Path to save the plot. If None, plot will be shown interactively. Default: None.

Returns:

    Optional[Any]: Attention map visualization if available, otherwise None

Note

This method requires that attention weights were collected during inference by setting output_attentions=True in prediction methods
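
A usage sketch; the output file name is hypothetical:

predictor.predict_seqs(sequences, output_attentions=True)
attn_map = predictor.plot_attentions(seq_idx=0, layer=-1, head=-1,
                                     save_path="attention.pdf")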

Source code in dnallm/inference/predictor.py
def plot_attentions(self, seq_idx: int = 0, layer: int = -1, head: int = -1,
                    width: int = 800, height: int = 800,
                    save_path: Optional[str] = None) -> Optional[Any]:
    """Plot attention map visualization.

    This method creates a heatmap visualization of attention weights between tokens
    in a sequence, showing how the model attends to different parts of the input.

    Args:
        seq_idx: Index of the sequence to plot, default 0
        layer: Layer index to visualize, default -1 (last layer)
        head: Attention head index to visualize, default -1 (last head)
        width: Width of the plot
        height: Height of the plot
        save_path: Path to save the plot. If None, plot will be shown interactively

    Returns:
        Attention map visualization if available, otherwise None

    Note:
        This method requires that attention weights were collected during inference
        by setting output_attentions=True in prediction methods
    """
    if hasattr(self, 'embeddings'):
        attentions = self.embeddings['attentions']
        if save_path:
            suffix = os.path.splitext(save_path)[-1]
            if suffix:
                heatmap = save_path.replace(suffix, "_heatmap" + suffix)
            else:
                heatmap = os.path.join(save_path, "heatmap.pdf")
        else:
            heatmap = None
        # Plot attention map
        attn_map = plot_attention_map(attentions, self.sequences, self.tokenizer,
                                      seq_idx=seq_idx, layer=layer, head=head,
                                      width=width, height=height,
                                      save_path=heatmap)
        return attn_map
    else:
        print("No attention weights available to plot.")

plot_hidden_states(reducer='t-SNE', ncols=4, width=300, height=300, save_path=None)

Visualize embeddings using dimensionality reduction.

This method creates 2D visualizations of high-dimensional embeddings from different model layers using PCA, t-SNE, or UMAP dimensionality reduction.

Parameters:

    reducer (str): Dimensionality reduction method to use ('PCA', 't-SNE', 'UMAP'). Default: 't-SNE'.
    ncols (int): Number of columns in the plot grid. Default: 4.
    width (int): Width of each plot. Default: 300.
    height (int): Height of each plot. Default: 300.
    save_path (Optional[str]): Path to save the plot. If None, plot will be shown interactively. Default: None.

Returns:

    Optional[Any]: Embedding visualization if available, otherwise None

Note

This method requires that hidden states were collected during inference by setting output_hidden_states=True in prediction methods
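
A usage sketch; the output file name is hypothetical:

predictor.predict_seqs(sequences, output_hidden_states=True)
fig = predictor.plot_hidden_states(reducer="UMAP", ncols=4,
                                   save_path="embeddings.pdf")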

Source code in dnallm/inference/predictor.py
def plot_hidden_states(self, reducer: str = "t-SNE",
                       ncols: int = 4, width: int = 300, height: int = 300,
                       save_path: Optional[str] = None) -> Optional[Any]:
    """Visualize embeddings using dimensionality reduction.

    This method creates 2D visualizations of high-dimensional embeddings from
    different model layers using PCA, t-SNE, or UMAP dimensionality reduction.

    Args:
        reducer: Dimensionality reduction method to use ('PCA', 't-SNE', 'UMAP')
        ncols: Number of columns in the plot grid
        width: Width of each plot
        height: Height of each plot
        save_path: Path to save the plot. If None, plot will be shown interactively

    Returns:
        Embedding visualization if available, otherwise None

    Note:
        This method requires that hidden states were collected during inference
        by setting output_hidden_states=True in prediction methods
    """
    if hasattr(self, 'embeddings'):
        hidden_states = self.embeddings['hidden_states']
        attention_mask = torch.unsqueeze(self.embeddings['attention_mask'], dim=-1)
        labels = self.embeddings['labels']
        if save_path:
            suffix = os.path.splitext(save_path)[-1]
            if suffix:
                embedding = save_path.replace(suffix, "_embedding" + suffix)
            else:
                embedding = os.path.join(save_path, "embedding.pdf")
        else:
            embedding = None
        # Plot hidden states
        label_names = self.task_config.label_names
        embeddings_vis = plot_embeddings(hidden_states, attention_mask, reducer=reducer,
                                         labels=labels, label_names=label_names,
                                         ncols=ncols, width=width, height=height,
                                         save_path=embedding)
        return embeddings_vis
    else:
        print("No hidden states available to plot.")

predict_file(file_path, evaluate=False, output_hidden_states=False, output_attentions=False, seq_col='sequence', label_col='labels', sep=None, fasta_sep='|', multi_label_sep=None, uppercase=False, lowercase=False, save_to_file=False, plot_metrics=False)

Predict from a file containing sequences.

This method loads sequences from a file and performs prediction, with optional evaluation, visualization, and saving capabilities.

Parameters:

    file_path (str): Path to the file containing sequences. Required.
    evaluate (bool): Whether to evaluate predictions against true labels. Default: False.
    output_hidden_states (bool): Whether to output hidden states for visualization. Default: False.
    output_attentions (bool): Whether to output attention weights for visualization. Default: False.
    seq_col (str): Column name for sequences in the file. Default: 'sequence'.
    label_col (str): Column name for labels in the file. Default: 'labels'.
    sep (str): Delimiter for CSV, TSV, or TXT files. Default: None.
    fasta_sep (str): Delimiter for FASTA files. Default: '|'.
    multi_label_sep (Union[str, None]): Delimiter for multi-label sequences. Default: None.
    uppercase (bool): Whether to convert sequences to uppercase. Default: False.
    lowercase (bool): Whether to convert sequences to lowercase. Default: False.
    save_to_file (bool): Whether to save predictions and metrics to output directory. Default: False.
    plot_metrics (bool): Whether to generate metric plots. Default: False.

Returns:

    Union[Dict, Tuple[Dict, Dict]]: Either:
        - Dict: Dictionary containing predictions
        - Tuple[Dict, Dict]: (predictions, metrics) if evaluate=True
Note

Setting output_attentions=True may consume significant memory
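
A usage sketch with a hypothetical labelled CSV file:

predictions, metrics = predictor.predict_file(
    "test_set.csv",
    seq_col="sequence",
    label_col="labels",
    evaluate=True,
    save_to_file=True,
)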

Source code in dnallm/inference/predictor.py
def predict_file(self, file_path: str, evaluate: bool = False,
                 output_hidden_states: bool = False,
                 output_attentions: bool = False,
                 seq_col: str = "sequence", label_col: str = "labels",
                 sep: str = None, fasta_sep: str = "|",
                 multi_label_sep: Union[str, None] = None,
                 uppercase: bool = False, lowercase: bool = False,
                 save_to_file: bool = False, plot_metrics: bool = False) -> Union[Dict, Tuple[Dict, Dict]]:
    """Predict from a file containing sequences.

    This method loads sequences from a file and performs prediction,
    with optional evaluation, visualization, and saving capabilities.

    Args:
        file_path: Path to the file containing sequences
        evaluate: Whether to evaluate predictions against true labels
        output_hidden_states: Whether to output hidden states for visualization
        output_attentions: Whether to output attention weights for visualization
        seq_col: Column name for sequences in the file
        label_col: Column name for labels in the file
        sep: Delimiter for CSV, TSV, or TXT files
        fasta_sep: Delimiter for FASTA files
        multi_label_sep: Delimiter for multi-label sequences
        uppercase: Whether to convert sequences to uppercase
        lowercase: Whether to convert sequences to lowercase
        save_to_file: Whether to save predictions and metrics to output directory
        plot_metrics: Whether to generate metric plots

    Returns:
        Either:
            - Dict: Dictionary containing predictions
            - Tuple[Dict, Dict]: (predictions, metrics) if evaluate=True

    Note:
        Setting output_attentions=True may consume significant memory
    """
    # Get dataset and dataloader from file
    _, dataloader = self.generate_dataset(file_path, seq_col=seq_col, label_col=label_col,
                                          sep=sep, fasta_sep=fasta_sep, multi_label_sep=multi_label_sep,
                                          uppercase=uppercase, lowercase=lowercase,
                                          batch_size=self.pred_config.batch_size)
    # Do batch prediction
    if output_attentions:
        warnings.warn("Caution: output_attentions may consume a lot of memory.\n")
    logits, predictions, embeddings = self.batch_predict(dataloader,
                                                         output_hidden_states=output_hidden_states,
                                                         output_attentions=output_attentions)
    # Keep hidden states
    if output_hidden_states or output_attentions:
        self.embeddings = embeddings
    # Save predictions
    if save_to_file and self.pred_config.output_dir:
        save_predictions(predictions, Path(self.pred_config.output_dir))
    # Do evaluation
    if len(self.labels) == len(logits) and evaluate:
        metrics = self.calculate_metrics(logits, self.labels, plot=plot_metrics)
        metrics_save = dict(metrics)
        if 'curve' in metrics_save:
            del metrics_save['curve']
        if 'scatter' in metrics_save:
            del metrics_save['scatter']
        if save_to_file and self.pred_config.output_dir:
            save_metrics(metrics_save, Path(self.pred_config.output_dir))
        # Whether to plot metrics
        if plot_metrics:
            return predictions, metrics
        else:
            return predictions, metrics_save

    return predictions
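
Example (a minimal usage sketch, assuming predictor is an initialized DNAPredictor and data/promoters.csv is a hypothetical CSV file whose columns match the defaults seq_col="sequence" and label_col="labels"):

predictions, metrics = predictor.predict_file(
    "data/promoters.csv",   # hypothetical input file
    evaluate=True,          # compare predictions against the labels column
    sep=",",                # explicit delimiter for CSV input
    save_to_file=True,      # write predictions.json and metrics.json to pred_config.output_dir
)
print(metrics)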

predict_seqs(sequences, evaluate=False, output_hidden_states=False, output_attentions=False, save_to_file=False)

Predict for a list of sequences.

This method provides a convenient interface for predicting on sequences, with optional evaluation and saving capabilities.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `sequences` | `Union[str, List[str]]` | Single sequence or list of sequences for prediction | *required* |
| `evaluate` | `bool` | Whether to evaluate predictions against true labels | `False` |
| `output_hidden_states` | `bool` | Whether to output hidden states for visualization | `False` |
| `output_attentions` | `bool` | Whether to output attention weights for visualization | `False` |
| `save_to_file` | `bool` | Whether to save predictions to the output directory | `False` |

Returns:

| Name | Type | Description |
|------|------|-------------|
| Either | `Union[Dict, Tuple[Dict, Dict]]` | `Dict`: dictionary containing predictions; `Tuple[Dict, Dict]`: `(predictions, metrics)` if `evaluate=True` |

Note

Evaluation requires that labels are available in the dataset.

Source code in dnallm/inference/predictor.py
def predict_seqs(self, sequences: Union[str, List[str]],
                 evaluate: bool = False,
                 output_hidden_states: bool = False,
                 output_attentions: bool = False,
                 save_to_file: bool = False) -> Union[Dict, Tuple[Dict, Dict]]:
    """Predict for a list of sequences.

    This method provides a convenient interface for predicting on sequences,
    with optional evaluation and saving capabilities.

    Args:
        sequences: Single sequence or list of sequences for prediction
        evaluate: Whether to evaluate predictions against true labels
        output_hidden_states: Whether to output hidden states for visualization
        output_attentions: Whether to output attention weights for visualization
        save_to_file: Whether to save predictions to output directory

    Returns:
        Either:
            - Dict: Dictionary containing predictions
            - Tuple[Dict, Dict]: (predictions, metrics) if evaluate=True

    Note:
        Evaluation requires that labels are available in the dataset
    """
    # Get dataset and dataloader from sequences
    _, dataloader = self.generate_dataset(sequences, batch_size=self.pred_config.batch_size)
    # Do batch prediction
    logits, predictions, embeddings = self.batch_predict(dataloader,
                                                         output_hidden_states=output_hidden_states,
                                                         output_attentions=output_attentions)
    # Keep hidden states
    if output_hidden_states or output_attentions:
        self.embeddings = embeddings
    # Save predictions
    if save_to_file and self.pred_config.output_dir:
        save_predictions(predictions, Path(self.pred_config.output_dir))
    # Do evaluation
    if len(self.labels) == len(logits) and evaluate:
        metrics = self.calculate_metrics(logits, self.labels)
        metrics_save = dict(metrics)
        if 'curve' in metrics_save:
            del metrics_save['curve']
        if 'scatter' in metrics_save:
            del metrics_save['scatter']
        if save_to_file and self.pred_config.output_dir:
            save_metrics(metrics_save, Path(self.pred_config.output_dir))
        return predictions, metrics

    return predictions
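
Example (a minimal usage sketch, assuming predictor is an initialized DNAPredictor; the sequences are toy inputs):

seqs = ["ATGCGTACGTTAGCCTAG", "GGCATTACGGATCCATGC"]
predictions = predictor.predict_seqs(seqs, output_hidden_states=True)
print(predictions)
# Hidden states collected here can then be visualized with the embedding plotting method above.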

generate(dataloader, n_tokens=400, temperature=1.0, top_k=4)

Generate DNA sequences using the model.

This function performs sequence generation tasks using the loaded model, currently supporting EVO2 models for DNA sequence generation.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `dataloader` | `DataLoader` | DataLoader containing prompt sequences | *required* |
| `n_tokens` | `int` | Number of tokens to generate | `400` |
| `temperature` | `float` | Sampling temperature for generation | `1.0` |
| `top_k` | `int` | Top-k sampling parameter | `4` |

Returns:

| Type | Description |
|------|-------------|
| `Dict` | Dictionary containing generated sequences |

Note

Currently only supports EVO2 models for sequence generation.

Source code in dnallm/inference/predictor.py
def generate(self, dataloader: DataLoader, n_tokens: int = 400, temperature: float = 1.0,
             top_k: int = 4) -> Dict:
    """Generate DNA sequences using the model.

    This function performs sequence generation tasks using the loaded model,
    currently supporting EVO2 models for DNA sequence generation.

    Args:
        dataloader: DataLoader containing prompt sequences
        n_tokens: Number of tokens to generate, default 400
        temperature: Sampling temperature for generation, default 1.0
        top_k: Top-k sampling parameter, default 4

    Returns:
        Dictionary containing generated sequences

    Note:
        Currently only supports EVO2 models for sequence generation
    """
    if "evo2" in str(self.model):
        for data in tqdm(dataloader, desc="Generating"):
            prompt_seqs = data['sequence']
            if isinstance(prompt_seqs, list):
                prompt_seqs = [seq for seq in prompt_seqs if seq]
            if not prompt_seqs:
                continue
            # Generate sequences
            output = self.model.generate(prompt_seqs=prompt_seqs, n_tokens=n_tokens,
                                         temperature=temperature, top_k=top_k)
            return output
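
Example (a minimal usage sketch, assuming an EVO2 model is loaded and dataloader yields batches with a 'sequence' field, for instance as built by generate_dataset):

generated = predictor.generate(dataloader, n_tokens=200, temperature=0.8, top_k=4)
print(generated)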

plot_attention_map(attentions, sequences, tokenizer, seq_idx=0, layer=-1, head=-1, width=800, height=800, save_path=None)

Plot attention map visualization for transformer models.

This function creates a heatmap visualization of attention weights between tokens in a sequence, showing how the model attends to different parts of the input.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `attentions` | `Union[tuple, list]` | Tuple or list containing attention weights from model layers | *required* |
| `sequences` | `list` | List of input sequences | *required* |
| `tokenizer` | | Tokenizer object for converting tokens to readable text | *required* |
| `seq_idx` | `int` | Index of the sequence to plot | `0` |
| `layer` | `int` | Layer index to visualize (-1 means the last layer) | `-1` |
| `head` | `int` | Attention head index to visualize (-1 means the last head) | `-1` |
| `width` | `int` | Width of the plot | `800` |
| `height` | `int` | Height of the plot | `800` |
| `save_path` | `str` | Path to save the plot. If None, the plot is shown interactively | `None` |

Returns:

| Type | Description |
|------|-------------|
| `Chart` | Altair chart object showing the attention heatmap |

Source code in dnallm/inference/plot.py
def plot_attention_map(attentions: Union[tuple, list], sequences: list, tokenizer,
                       seq_idx: int = 0, layer: int = -1, head: int = -1,
                       width: int = 800, height: int = 800,
                       save_path: str = None) -> alt.Chart:
    """Plot attention map visualization for transformer models.

    This function creates a heatmap visualization of attention weights between tokens
    in a sequence, showing how the model attends to different parts of the input.

    Args:
        attentions: Tuple or list containing attention weights from model layers
        sequences: List of input sequences
        tokenizer: Tokenizer object for converting tokens to readable text
        seq_idx: Index of the sequence to plot, default 0
        layer: Layer index to visualize, default -1 (last layer)
        head: Attention head index to visualize, default -1 (last head)
        width: Width of the plot
        height: Height of the plot
        save_path: Path to save the plot. If None, plot will be shown interactively

    Returns:
        Altair chart object showing the attention heatmap
    """
    # Plot attention map
    attn_layer = attentions[layer].numpy()
    attn_head = attn_layer[seq_idx][head]
    # Get the tokens
    seq = sequences[seq_idx]
    tokens_id = tokenizer.encode(seq)
    try:
        tokens = tokenizer.convert_ids_to_tokens(tokens_id)
    except Exception:
        # Fall back to decoding the token ids and splitting on whitespace
        tokens = tokenizer.decode(tokens_id).split()
    # Create a DataFrame for the attention map
    num_tokens = len(tokens)
    flen = len(str(num_tokens))
    df = {"token1": [], 'token2': [], 'attn': []}
    for i, t1 in enumerate(tokens):
        for j, t2 in enumerate(tokens):
            df["token1"].append(str(i).zfill(flen)+t1)
            df["token2"].append(str(num_tokens-j).zfill(flen)+t2)
            df["attn"].append(attn_head[i][j])
    source = pd.DataFrame(df)
    # Enable VegaFusion for Altair
    alt.data_transformers.enable("vegafusion")
    # Plot the attention map
    attn_map = alt.Chart(source).mark_rect().encode(
        x=alt.X('token1:O', axis=alt.Axis(
                    labelExpr = f"substring(datum.value, {flen}, 100)",
                    labelAngle=-45,
                    )
                ).title(None),
        y=alt.Y('token2:O', axis=alt.Axis(
                    labelExpr = f"substring(datum.value, {flen}, 100)",
                    labelAngle=0,
                    )
                ).title(None),
        color=alt.Color('attn:Q').scale(scheme='viridis'),
    ).properties(
        width=width,
        height=height
    ).configure_axis(grid=False)
    # Save the plot
    if save_path:
        attn_map.save(save_path)
        print(f"Attention map saved to {save_path}")
    return attn_map
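
Example (a minimal usage sketch; the import path follows the source location above, and the "attentions" key of predictor.embeddings is an assumption about how attention weights were stored after predicting with output_attentions=True):

from dnallm.inference.plot import plot_attention_map

attentions = predictor.embeddings["attentions"]   # assumed key name
chart = plot_attention_map(attentions, predictor.sequences, predictor.tokenizer,
                           seq_idx=0, layer=-1, head=-1,
                           save_path="attention_map.html")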

plot_bars(data, show_score=True, ncols=3, width=200, height=50, bar_width=30, domain=(0.0, 1.0), save_path=None, separate=False)

Plot bar charts for model metrics comparison.

This function creates bar charts to compare different metrics across multiple models. It supports automatic layout with multiple columns and optional score labels on bars.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `data` | `dict` | Dictionary containing metrics data with 'models' as the first key | *required* |
| `show_score` | `bool` | Whether to show the score values on the bars | `True` |
| `ncols` | `int` | Number of columns to arrange the plots | `3` |
| `width` | `int` | Width of each individual plot | `200` |
| `height` | `int` | Height of each individual plot | `50` |
| `bar_width` | `int` | Width of the bars in the plot | `30` |
| `domain` | `Union[tuple, list]` | Y-axis domain range for the plots | `(0.0, 1.0)` |
| `save_path` | `str` | Path to save the plot. If None, the plot is shown interactively | `None` |
| `separate` | `bool` | Whether to return separate plots for each metric | `False` |

Returns:

| Type | Description |
|------|-------------|
| `Chart` | Altair chart object (combined or separate plots, depending on the `separate` parameter) |

Source code in dnallm/inference/plot.py
def plot_bars(data: dict, show_score: bool = True, ncols: int = 3,
              width: int = 200, height: int = 50, bar_width: int = 30,
              domain: Union[tuple, list] = (0.0, 1.0),
              save_path: str = None, separate: bool = False) -> alt.Chart:
    """Plot bar charts for model metrics comparison.

    This function creates bar charts to compare different metrics across multiple models.
    It supports automatic layout with multiple columns and optional score labels on bars.

    Args:
        data: Dictionary containing metrics data with 'models' as the first key
        show_score: Whether to show the score values on the bars
        ncols: Number of columns to arrange the plots
        width: Width of each individual plot
        height: Height of each individual plot
        bar_width: Width of the bars in the plot
        domain: Y-axis domain range for the plots, default (0.0, 1.0)
        save_path: Path to save the plot. If None, plot will be shown interactively
        separate: Whether to return separate plots for each metric

    Returns:
        Altair chart object (combined or separate plots based on separate parameter)
    """
    # Plot bar charts
    dbar = pd.DataFrame(data)
    pbar = {}
    p_separate = {}
    for n, metric in enumerate([x for x in data if x != 'models']):
        if metric in ['mae', 'mse']:
            domain_use = [0, dbar[metric].max()*1.1]
        else:
            domain_use = domain
        bar = alt.Chart(dbar).mark_bar(size=bar_width).encode(
            x=alt.X(metric + ":Q").scale(domain=domain_use),
            y=alt.Y("models").title(None),
            color=alt.Color('models').legend(None),
        ).properties(width=width, height=height*len(dbar['models']))
        if show_score:
            text = alt.Chart(dbar).mark_text(
                dx=-10,
                color='white',
                baseline='middle',
                align='right').encode(
                    x=alt.X(metric + ":Q"),
                    y=alt.Y("models").title(None),
                    text=alt.Text(metric, format='.3f')
                    )
            p = bar + text
        else:
            p = bar
        if separate:
            p_separate[metric] = p.configure_axis(grid=False)
        idx = n // ncols
        if n % ncols == 0:
            pbar[idx] = p
        else:
            pbar[idx] |= p
    # Combine the plots
    for i, p in enumerate(pbar):
        if i == 0:
            pbars = pbar[p]
        else:
            pbars &= pbar[p]
    # Configure the chart
    pbars = pbars.configure_axis(grid=False)
    # Save the plot
    if save_path:
        pbars.save(save_path)
        print(f"Metrics bar charts saved to {save_path}")
    if separate:
        return p_separate
    else:
        return pbars
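
Example (a minimal sketch with toy metrics for two hypothetical models; the import path follows the source location above):

from dnallm.inference.plot import plot_bars

bars_data = {
    "models": ["ModelA", "ModelB"],   # 'models' must be the first key
    "accuracy": [0.91, 0.87],
    "f1": [0.89, 0.85],
}
chart = plot_bars(bars_data, ncols=2, save_path="metrics_bars.html")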

plot_curve(data, show_score=True, width=400, height=400, save_path=None, separate=False)

Plot ROC and PR curves for classification tasks.

This function creates ROC (Receiver Operating Characteristic) and PR (Precision-Recall) curves to evaluate model performance on classification tasks.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `data` | `dict` | Dictionary containing ROC and PR curve data with 'ROC' and 'PR' keys | *required* |
| `show_score` | `bool` | Whether to show the score values on the plot (currently not implemented) | `True` |
| `width` | `int` | Width of each plot | `400` |
| `height` | `int` | Height of each plot | `400` |
| `save_path` | `str` | Path to save the plot. If None, the plot is shown interactively | `None` |
| `separate` | `bool` | Whether to return separate plots for ROC and PR curves | `False` |

Returns:

| Type | Description |
|------|-------------|
| `Chart` | Altair chart object (combined or separate plots, depending on the `separate` parameter) |

Source code in dnallm/inference/plot.py
def plot_curve(data: dict, show_score: bool = True,
               width: int = 400, height: int = 400,
               save_path: str = None, separate: bool = False) -> alt.Chart:
    """Plot ROC and PR curves for classification tasks.

    This function creates ROC (Receiver Operating Characteristic) and PR (Precision-Recall)
    curves to evaluate model performance on classification tasks.

    Args:
        data: Dictionary containing ROC and PR curve data with 'ROC' and 'PR' keys
        show_score: Whether to show the score values on the plot (currently not implemented)
        width: Width of each plot
        height: Height of each plot
        save_path: Path to save the plot. If None, plot will be shown interactively
        separate: Whether to return separate plots for ROC and PR curves

    Returns:
        Altair chart object (combined or separate plots based on separate parameter)
    """
    # Plot curves
    pline = {}
    p_separate = {}
    # Plot ROC curve
    roc_data = pd.DataFrame(data['ROC'])
    pline[0] = alt.Chart(roc_data).mark_line().encode(
        x=alt.X("fpr").scale(domain=(0.0, 1.0)),
        y=alt.Y("tpr").scale(domain=(0.0, 1.0)),
        color="models",
    ).properties(width=width, height=height)
    if separate:
        p_separate['ROC'] = pline[0]
    # Plot PR curve
    pr_data = pd.DataFrame(data['PR'])
    pline[1] = alt.Chart(pr_data).mark_line().encode(
        x=alt.X("recall").scale(domain=(0.0, 1.0)),
        y=alt.Y("precision").scale(domain=(0.0, 1.0)),
        color="models",
    ).properties(width=width, height=height)
    if separate:
        p_separate['PR'] = pline[1]
    # Combine the plots
    for i, p in enumerate(pline):
        if i == 0:
            plines = pline[i]
        else:
            plines |= pline[i]
    # Configure the chart
    plines = plines.configure_axis(grid=False)
    # Save the plot
    if save_path:
        plines.save(save_path)
        print(f"ROC and PR curves saved to {save_path}")
    if separate:
        return p_separate
    else:
        return plines
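
Example (a minimal sketch with toy ROC/PR points for one hypothetical model; in practice this dictionary is produced by prepare_data below):

from dnallm.inference.plot import plot_curve

curves_data = {
    "ROC": {"models": ["ModelA"] * 3, "fpr": [0.0, 0.2, 1.0], "tpr": [0.0, 0.8, 1.0]},
    "PR": {"models": ["ModelA"] * 3, "recall": [0.0, 0.5, 1.0], "precision": [1.0, 0.9, 0.6]},
}
chart = plot_curve(curves_data, save_path="curves.html")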

plot_embeddings(hidden_states, attention_mask, reducer='t-SNE', labels=None, label_names=None, ncols=4, width=300, height=300, save_path=None, separate=False)

Visualize embeddings using dimensionality reduction techniques.

This function creates 2D visualizations of high-dimensional embeddings from different model layers using PCA, t-SNE, or UMAP dimensionality reduction methods.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `hidden_states` | `Union[tuple, list]` | Tuple or list containing hidden states from model layers | *required* |
| `attention_mask` | `Union[tuple, list]` | Tuple or list containing attention masks for sequence padding | *required* |
| `reducer` | `str` | Dimensionality reduction method: 'PCA', 't-SNE', or 'UMAP' | `'t-SNE'` |
| `labels` | `Union[tuple, list]` | List of labels for the data points | `None` |
| `label_names` | `Union[str, list]` | List of label names for legend display | `None` |
| `ncols` | `int` | Number of columns to arrange the plots | `4` |
| `width` | `int` | Width of each plot | `300` |
| `height` | `int` | Height of each plot | `300` |
| `save_path` | `str` | Path to save the plot. If None, the plot is shown interactively | `None` |
| `separate` | `bool` | Whether to return separate plots for each layer | `False` |

Returns:

| Type | Description |
|------|-------------|
| `Chart` | Altair chart object (combined or separate plots, depending on the `separate` parameter) |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If an unsupported dimensionality reduction method is specified |

Source code in dnallm/inference/plot.py
def plot_embeddings(hidden_states: Union[tuple, list], attention_mask: Union[tuple, list], reducer: str = "t-SNE",
                    labels: Union[tuple, list] = None, label_names: Union[str, list] = None,
                    ncols: int = 4, width: int = 300, height: int = 300,
                    save_path: str = None, separate: bool = False) -> alt.Chart:
    """Visualize embeddings using dimensionality reduction techniques.

    This function creates 2D visualizations of high-dimensional embeddings from different
    model layers using PCA, t-SNE, or UMAP dimensionality reduction methods.

    Args:
        hidden_states: Tuple or list containing hidden states from model layers
        attention_mask: Tuple or list containing attention masks for sequence padding
        reducer: Dimensionality reduction method. Options: 'PCA', 't-SNE', 'UMAP'
        labels: List of labels for the data points
        label_names: List of label names for legend display
        ncols: Number of columns to arrange the plots
        width: Width of each plot
        height: Height of each plot
        save_path: Path to save the plot. If None, plot will be shown interactively
        separate: Whether to return separate plots for each layer

    Returns:
        Altair chart object (combined or separate plots based on separate parameter)

    Raises:
        ValueError: If unsupported dimensionality reduction method is specified
    """
    import torch
    if reducer.lower() == "pca":
        from sklearn.decomposition import PCA
        dim_reducer = PCA(n_components=2)
    elif reducer.lower() == "t-sne":
        from sklearn.manifold import TSNE
        dim_reducer = TSNE(n_components=2)
    elif reducer.lower() == "umap":
        from umap import UMAP
        dim_reducer = UMAP(n_components=2)
    else:
        raise ValueError("Unsupported dim reducer, please try PCA, t-SNE or UMAP.")

    pdot = {}
    p_separate = {}
    for i, hidden in enumerate(hidden_states):
        embeddings = hidden.numpy()
        mean_sequence_embeddings = torch.sum(attention_mask*embeddings, axis=-2) / torch.sum(attention_mask, axis=1)
        layer_dim_reduced_vectors = dim_reducer.fit_transform(mean_sequence_embeddings.numpy())
        if labels is None or len(labels) == 0:
            label_text = ["Uncategorized"] * layer_dim_reduced_vectors.shape[0]
        else:
            label_text = [label_names[int(lab)] for lab in labels]
        df = {
            'Dimension 1': layer_dim_reduced_vectors[:,0],
            'Dimension 2': layer_dim_reduced_vectors[:,1],
            'labels': label_text
        }
        source = pd.DataFrame(df)
        dot = alt.Chart(source, title=f"Layer {i+1}").mark_point(filled=True).encode(
            x=alt.X("Dimension 1:Q"),
            y=alt.Y("Dimension 2:Q"),
            color=alt.Color("labels:N", legend=alt.Legend(title="Labels")),
        ).properties(width=width, height=height)
        if separate:
            p_separate[f"Layer{i+1}"] = dot.configure_axis(grid=False)
        idx = i // ncols
        if i % ncols == 0:
            pdot[idx] = dot
        else:
            pdot[idx] |= dot
    # Combine the plots
    for i, p in enumerate(pdot):
        if i == 0:
            pdots = pdot[p]
        else:
            pdots &= pdot[p]
    # Configure the chart
    pdots = pdots.configure_axis(grid=False)
    # Save the plot
    if save_path:
        pdots.save(save_path)
        print(f"Embeddings visualization saved to {save_path}")
    if separate:
        return p_separate
    else:
        return pdots
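
Example (a minimal sketch mirroring how the predictor's plotting method above calls this function; it assumes hidden states were collected with output_hidden_states=True):

import torch
from dnallm.inference.plot import plot_embeddings

emb = predictor.embeddings
chart = plot_embeddings(
    emb["hidden_states"],
    torch.unsqueeze(emb["attention_mask"], dim=-1),
    reducer="UMAP",
    labels=emb["labels"],
    label_names=predictor.task_config.label_names,
    save_path="embeddings.pdf",
)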

plot_muts(data, show_score=False, width=None, height=100, save_path=None)

Visualize mutation effects on model predictions.

This function creates comprehensive visualizations of how different mutations affect model predictions, including:

- Heatmap showing mutation effects at each position
- Line plot showing gain/loss of function
- Bar chart showing maximum-effect mutations

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `data` | `dict` | Dictionary containing mutation data with 'raw' and mutation keys | *required* |
| `show_score` | `bool` | Whether to show the score values on the plot (currently not implemented) | `False` |
| `width` | `int` | Width of the plot. If None, automatically calculated from the sequence length | `None` |
| `height` | `int` | Height of the plot | `100` |
| `save_path` | `str` | Path to save the plot. If None, the plot is shown interactively | `None` |

Returns:

| Type | Description |
|------|-------------|
| `Chart` | Altair chart object showing the combined mutation effects visualization |

Source code in dnallm/inference/plot.py
def plot_muts(data: dict, show_score: bool = False,
              width: int = None, height: int = 100,
              save_path: str = None) -> alt.Chart:
    """Visualize mutation effects on model predictions.

    This function creates comprehensive visualizations of how different mutations
    affect model predictions, including:
    - Heatmap showing mutation effects at each position
    - Line plot showing gain/loss of function
    - Bar chart showing maximum effect mutations

    Args:
        data: Dictionary containing mutation data with 'raw' and mutation keys
        show_score: Whether to show the score values on the plot (currently not implemented)
        width: Width of the plot. If None, automatically calculated based on sequence length
        height: Height of the plot
        save_path: Path to save the plot. If None, plot will be shown interactively

    Returns:
        Altair chart object showing the combined mutation effects visualization
    """
    # Create dataframe
    seqlen = len(data['raw']['sequence'])
    flen = len(str(seqlen))
    mut_list = [x for x in data.keys() if x != 'raw']
    raw_bases = [base for base in data['raw']['sequence']]
    dheat = {"base": [], 'mut': [], 'score': []}
    dline = {"x": [str(i).zfill(flen)+x for i,x in enumerate(raw_bases)] * 2,
             "score": [0.0]*seqlen*2,
             "type": ["gain"]*seqlen + ["loss"]*seqlen}
    dbar = {"x": [], "score": [], "base": []}
    # Iterate through mutations
    for i, base1 in enumerate(raw_bases):
        ref = "mut_" + str(i) + "_" + base1 + "_" + base1
        # replacement
        mut_prefix = "mut_" + str(i) + "_" + base1 + "_"
        maxabs = 0.0
        maxscore = 0.0
        maxabs_index = base1
        for mut in sorted([x for x in mut_list if x.startswith(mut_prefix)] + [ref]):
            if mut in data:
                # for heatmap
                base2 = mut.split("_")[-1]
                score = data[mut]['score']
                dheat["base"].append(str(i).zfill(flen)+base1)
                dheat["mut"].append(base2)
                dheat["score"].append(score)
                # for line
                if score >= 0:
                    dline["score"][i] += score
                elif score < 0:
                    dline["score"][i+seqlen] -= score
                # for bar chart
                if abs(score) > maxabs:
                    maxabs = abs(score)
                    maxscore = score
                    maxabs_index = base2
            else:
                dheat["base"].append(str(i).zfill(flen)+base1)
                dheat["mut"].append(base1)
                dheat["score"].append(0.0)
        # for bar chart
        dbar["x"].append(str(i).zfill(flen)+base1)
        dbar["score"].append(maxscore)
        dbar["base"].append(maxabs_index)
        # deletion
        del_prefix = "del_" + str(i) + "_"
        for mut in [x for x in mut_list if x.startswith(del_prefix)]:
            base2 = "del_" + mut.split("_")[-1]
            score = data[mut]['score']
            dheat["base"].append(str(i).zfill(flen)+base1)
            dheat["mut"].append(base2)
            dheat["score"].append(score)
        # insertion
        ins_prefix = "ins_" + str(i) + "_"
        for mut in [x for x in mut_list if x.startswith(ins_prefix)]:
            base2 = "ins_" + mut.split("_")[-1]
            score = data[mut]['score']
            dheat["base"].append(str(i).zfill(flen)+base1)
            dheat["mut"].append(base2)
            dheat["score"].append(score)
    # Set color domain and range
    domain1_min = min([data[mut]['score'] for mut in data])
    domain1_max = max([data[mut]['score'] for mut in data])
    domain1 = [-max([abs(domain1_min), abs(domain1_max)]),
               0.0,
               max([abs(domain1_min), abs(domain1_max)])]
    range1_ = ['#2166ac', '#f7f7f7', '#b2182b']
    domain2 = sorted([x for x in set(dbar['base'])])
    range2_ = ["#33a02c", "#e31a1c", "#1f78b4", "#ff7f00", "#cab2d6"][:len(domain2)]
    # Enable VegaFusion for Altair
    alt.data_transformers.enable("vegafusion")
    # Plot the heatmap
    if width is None:
        width = int(height * len(raw_bases) / len(set(dheat['mut'])))
    if dheat['base']:
        pheat = alt.Chart(pd.DataFrame(dheat)).mark_rect().encode(
            x=alt.X('base:O', axis=alt.Axis(
                    labelExpr = f"substring(datum.value, {flen}, {flen}+1)",
                    labelAngle=0,
                    )
                ).title(None),
            y=alt.Y('mut:O').title("mutation"),
            color=alt.Color('score:Q').scale(domain=domain1, range=range1_),
        ).properties(
            width=width, height=height
        )
        # Plot gain and loss
        pline = alt.Chart(pd.DataFrame(dline)).mark_line().encode(
            x=alt.X('x:O').title(None).axis(labels=False),
            y=alt.Y('score:Q'),
            color=alt.Color('type:N').scale(
                domain=['gain', 'loss'], range=['#b2182b', '#2166ac']
            ),
        ).properties(
            width=width, height=height
        )
        pbar = alt.Chart(pd.DataFrame(dbar)).mark_bar().encode(
            x=alt.X('x:O').title(None).axis(labels=False),
            y=alt.Y('score:Q'),
            color=alt.Color('base:N').scale(
                domain=domain2, range=range2_
            ),
        ).properties(
            width=width, height=height
        )
        pmerge = pheat & pbar & pline
        pmerge = pmerge.configure_axis(grid=False)
        # Save the plot
        if save_path:
            pmerge.save(save_path)
            print(f"Mutation effects visualization saved to {save_path}")
        return pmerge
    return None
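
Example (a minimal sketch with toy scores for a 3-bp sequence; the "mut_<position>_<ref>_<alt>" key format and the 'score' field follow the source above):

from dnallm.inference.plot import plot_muts

muts_data = {
    "raw": {"sequence": "ACG", "score": 0.0},
    "mut_0_A_C": {"score": 0.12},
    "mut_0_A_G": {"score": -0.30},
    "mut_1_C_T": {"score": 0.05},
}
chart = plot_muts(muts_data, save_path="mutation_effects.html")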

plot_scatter(data, show_score=True, ncols=3, width=400, height=400, save_path=None, separate=False)

Plot scatter plots for regression task evaluation.

This function creates scatter plots to compare predicted vs. experimental values for regression tasks, with optional R² score display.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `data` | `dict` | Dictionary containing scatter plot data for each model | *required* |
| `show_score` | `bool` | Whether to show the R² score on the plot | `True` |
| `ncols` | `int` | Number of columns to arrange the plots | `3` |
| `width` | `int` | Width of each plot | `400` |
| `height` | `int` | Height of each plot | `400` |
| `save_path` | `str` | Path to save the plot. If None, the plot is shown interactively | `None` |
| `separate` | `bool` | Whether to return separate plots for each model | `False` |

Returns:

| Type | Description |
|------|-------------|
| `Chart` | Altair chart object (combined or separate plots, depending on the `separate` parameter) |

Source code in dnallm/inference/plot.py
def plot_scatter(data: dict, show_score: bool = True, ncols: int = 3,
                 width: int = 400, height: int = 400,
                 save_path: str = None, separate: bool = False) -> alt.Chart:
    """Plot scatter plots for regression task evaluation.

    This function creates scatter plots to compare predicted vs. experimental values
    for regression tasks, with optional R² score display.

    Args:
        data: Dictionary containing scatter plot data for each model
        show_score: Whether to show the R² score on the plot
        ncols: Number of columns to arrange the plots
        width: Width of each plot
        height: Height of each plot
        save_path: Path to save the plot. If None, plot will be shown interactively
        separate: Whether to return separate plots for each model

    Returns:
        Altair chart object (combined or separate plots based on separate parameter)
    """
    # Plot bar charts
    pdot = {}
    p_separate = {}
    for n, model in enumerate(data):
        scatter_data = dict(data[model])
        r2 = scatter_data['r2']
        del scatter_data['r2']
        ddot = pd.DataFrame(scatter_data)
        dot = alt.Chart(ddot, title=model).mark_point(filled=True).encode(
            x=alt.X("predicted:Q"),
            y=alt.Y("experiment:Q"),
        ).properties(width=width, height=height)
        if show_score:
            min_x = ddot['predicted'].min()
            max_y = ddot['experiment'].max()
            text = alt.Chart().mark_text(size=14, align="left", baseline="bottom").encode(
                x=alt.datum(min_x + 0.5),
                y=alt.datum(max_y - 0.5),
                text=alt.datum("R\u00b2=" + str(r2))
            )
            p = dot + text
        else:
            p = dot
        if separate:
            p_separate[model] = p.configure_axis(grid=False)
        idx = n // ncols
        if n % ncols == 0:
            pdot[idx] = p
        else:
            pdot[idx] |= p
    # Combine the plots
    for i, p in enumerate(pdot):
        if i == 0:
            pdots = pdot[p]
        else:
            pdots &= pdot[p]
    # Configure the chart
    pdots = pdots.configure_axis(grid=False)
    # Save the plot
    if save_path:
        pdots.save(save_path)
        print(f"Metrics scatter plots saved to {save_path}")
    if separate:
        return p_separate
    else:
        return pdots
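
Example (a minimal sketch with toy predicted/experimental values for one hypothetical regression model):

from dnallm.inference.plot import plot_scatter

scatter_data = {
    "ModelA": {"predicted": [0.1, 0.5, 0.9], "experiment": [0.2, 0.4, 1.0], "r2": 0.87},
}
chart = plot_scatter(scatter_data, ncols=1, save_path="scatter.html")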

prepare_data(metrics, task_type='binary')

Prepare data for plotting various types of visualizations.

This function organizes model metrics data into formats suitable for different plot types:

- Bar charts for classification and regression metrics
- ROC and PR curves for classification tasks
- Scatter plots for regression tasks

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `metrics` | `dict` | Dictionary containing model metrics for different models | *required* |
| `task_type` | `str` | Type of task ('binary', 'multiclass', 'multilabel', 'token', 'regression') | `'binary'` |

Returns:

| Type | Description |
|------|-------------|
| `tuple` | Tuple containing `bars_data` (data formatted for bar chart visualization) and `curves_data`/`scatter_data` (data formatted for curve or scatter plot visualization) |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If the task type is not supported for plotting |

Source code in dnallm/inference/plot.py
def prepare_data(metrics: dict, task_type: str = "binary") -> tuple:
    """Prepare data for plotting various types of visualizations.

    This function organizes model metrics data into formats suitable for different plot types:
    - Bar charts for classification and regression metrics
    - ROC and PR curves for classification tasks
    - Scatter plots for regression tasks

    Args:
        metrics: Dictionary containing model metrics for different models
        task_type: Type of task ('binary', 'multiclass', 'multilabel', 'token', 'regression')

    Returns:
        Tuple containing:
        - bars_data: Data formatted for bar chart visualization
        - curves_data/scatter_data: Data formatted for curve or scatter plot visualization

    Raises:
        ValueError: If task type is not supported for plotting
    """
    # Load the data
    bars_data = {'models': []}
    if task_type in ['binary', 'multiclass', 'multilabel', 'token']:
        curves_data = {'ROC': {'models': [], 'fpr': [], 'tpr': []},
                    'PR': {'models': [], 'recall': [], 'precision': []}}
        for model in metrics:
            if model not in bars_data['models']:
                bars_data['models'].append(model)
            for metric in metrics[model]:
                if metric == 'curve':
                    for score in metrics[model][metric]:
                        if score.endswith('pr'):
                            if score == 'fpr':
                                curves_data['ROC']['models'].extend([model] * len(metrics[model][metric][score]))
                            curves_data['ROC'][score].extend(metrics[model][metric][score])
                        else:
                            if score == 'precision':
                                curves_data['PR']['models'].extend([model] * len(metrics[model][metric][score]))
                            curves_data['PR'][score].extend(metrics[model][metric][score])
                else:
                    if metric not in bars_data:
                        bars_data[metric] = []
                    bars_data[metric].append(metrics[model][metric])
        return bars_data, curves_data
    elif task_type == "regression":
        scatter_data = {}
        for model in metrics:
            if model not in bars_data['models']:
                bars_data['models'].append(model)
            scatter_data[model] = {'predicted': [], 'experiment': []}
            for metric in metrics[model]:
                if metric == 'scatter':
                    for score in metrics[model][metric]:
                        scatter_data[model][score].extend(metrics[model][metric][score])
                else:
                    if metric not in bars_data:
                        bars_data[metric] = []
                    bars_data[metric].append(metrics[model][metric])
                    if metric == 'r2':
                        scatter_data[model][metric] = metrics[model][metric]
        return bars_data, scatter_data
    else:
        raise ValueError(f"Unsupported task type {task_type} for plotting")
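
Example (a minimal sketch with a toy metrics dictionary for one hypothetical binary-classification model; the nested 'curve' structure follows the source above):

from dnallm.inference.plot import prepare_data, plot_bars, plot_curve

all_metrics = {
    "ModelA": {
        "accuracy": 0.91,
        "f1": 0.89,
        "curve": {"fpr": [0.0, 0.2, 1.0], "tpr": [0.0, 0.8, 1.0],
                  "recall": [0.0, 0.5, 1.0], "precision": [1.0, 0.9, 0.6]},
    },
}
bars_data, curves_data = prepare_data(all_metrics, task_type="binary")
plot_bars(bars_data, save_path="bars.html")
plot_curve(curves_data, save_path="curves.html")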

save_metrics(metrics, output_dir)

Save evaluation metrics to JSON file.

This function saves computed evaluation metrics in JSON format to the specified output directory.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `metrics` | `Dict` | Dictionary containing metrics to save | *required* |
| `output_dir` | `Path` | Directory path where metrics will be saved | *required* |
Source code in dnallm/inference/predictor.py
def save_metrics(metrics: Dict, output_dir: Path) -> None:
    """Save evaluation metrics to JSON file.

    This function saves computed evaluation metrics in JSON format to the specified output directory.

    Args:
        metrics: Dictionary containing metrics to save
        output_dir: Directory path where metrics will be saved
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save metrics
    with open(output_dir / "metrics.json", "w") as f:
        json.dump(metrics, f, indent=4)
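
Example (a minimal sketch; the import path is assumed from the source location above):

from pathlib import Path
from dnallm.inference.predictor import save_metrics   # assumed import path

save_metrics({"accuracy": 0.91, "f1": 0.89}, Path("results"))
# Creates results/ if needed and writes results/metrics.json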

save_predictions(predictions, output_dir)

Save predictions to JSON file.

This function saves model predictions in JSON format to the specified output directory.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `predictions` | `Dict` | Dictionary containing predictions to save | *required* |
| `output_dir` | `Path` | Directory path where predictions will be saved | *required* |
Source code in dnallm/inference/predictor.py
def save_predictions(predictions: Dict, output_dir: Path) -> None:
    """Save predictions to JSON file.

    This function saves model predictions in JSON format to the specified output directory.

    Args:
        predictions: Dictionary containing predictions to save
        output_dir: Directory path where predictions will be saved
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save predictions
    with open(output_dir / "predictions.json", "w") as f:
        json.dump(predictions, f, indent=4)
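
Example (a minimal sketch; the import path is assumed from the source location above and the predictions dictionary is a toy value):

from pathlib import Path
from dnallm.inference.predictor import save_predictions   # assumed import path

preds = {"seq_0": {"label": "promoter", "score": 0.97}}
save_predictions(preds, Path("results"))
# Creates results/ if needed and writes results/predictions.json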