Skip to content

inference/predictor API

DNAPredictor

DNA sequence predictor using fine-tuned models.

This class provides functionality for making predictions using DNA language models. It handles model loading, inference, and result processing.

Source code in `dnallm/inference/predictor.py` (lines 55–575):
class DNAPredictor:
    """DNA sequence predictor using fine-tuned models.

    This class provides functionality for making predictions using DNA language models.
    It handles model loading, inference, and result processing.

    State kept between calls:
        sequences: Raw input sequences of the most recently generated dataset
            (empty when ``keep_seqs=False`` was requested).
        labels: Labels of the most recently generated dataset (empty when that
            dataset has no ``labels`` column).
        embeddings: Hidden states / attentions collected by the last
            ``predict_*`` call that requested them (only set in that case).
    """

    def __init__(
        self,
        model: any,
        tokenizer: any,
        config: dict
    ):
        """Initialize the predictor.

        Args:
            model: Fine-tuned model instance. If ``None``, it is not moved to a device.
            tokenizer: Tokenizer for the model.
            config: Configuration dictionary; ``config['task']`` holds the task
                settings and ``config['inference']`` the inference parameters.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.task_config = config['task']
        self.pred_config = config['inference']
        self.device = self._get_device()
        # Compare with None explicitly: container modules (e.g. an empty
        # nn.Sequential) define __len__, so plain truthiness is unreliable.
        if model is not None:
            self.model.to(self.device)
            print(f"Use device: {self.device}")
        self.sequences = []
        self.labels = []

    @staticmethod
    def _xpu_available() -> bool:
        """Return True if this torch build has a ``torch.xpu`` module reporting a device."""
        # ``torch.xpu`` does not exist in older torch releases, so probe for it first.
        xpu = getattr(torch, "xpu", None)
        return xpu is not None and xpu.is_available()

    def _get_device(self):
        """Get the appropriate device for model inference.

        Reads ``self.pred_config.device`` (case-insensitive). An explicitly
        requested accelerator that is unavailable falls back to CPU with a
        warning; ``'auto'`` tries CUDA, then MPS, then XPU, then CPU.

        Returns:
            torch.device: The device to use for model inference.

        Raises:
            ValueError: If the specified device type is not supported.
        """
        device = self.pred_config.device.lower()
        if device == 'cpu':
            return torch.device('cpu')
        if device in ('cuda', 'nvidia'):
            if torch.cuda.is_available():
                return torch.device('cuda')
            warnings.warn("CUDA is not available. Please check your installation. Use CPU instead.")
            return torch.device('cpu')
        if device in ('mps', 'apple', 'mac'):
            if torch.backends.mps.is_available():
                return torch.device('mps')
            warnings.warn("MPS is not available. Please check your installation. Use CPU instead.")
            return torch.device('cpu')
        if device in ('rocm', 'amd'):
            # ROCm builds of torch are driven through the torch.cuda API.
            if torch.cuda.is_available():
                return torch.device('cuda')
            warnings.warn("ROCm is not available. Please check your installation. Use CPU instead.")
            return torch.device('cpu')
        if device in ('tpu', 'xla', 'google'):
            # Probe for torch_xla; without it the 'xla' device cannot be used.
            try:
                import torch_xla.core.xla_model  # noqa: F401
            except Exception:
                warnings.warn("TPU is not available. Please check your installation. Use CPU instead.")
                return torch.device('cpu')
            return torch.device('xla')
        if device in ('xpu', 'intel'):
            if self._xpu_available():
                return torch.device('xpu')
            warnings.warn("XPU is not available. Please check your installation. Use CPU instead.")
            return torch.device('cpu')
        if device == 'auto':
            if torch.cuda.is_available():
                return torch.device('cuda')
            if torch.backends.mps.is_available():
                return torch.device('mps')
            if self._xpu_available():
                return torch.device('xpu')
            return torch.device('cpu')
        raise ValueError(f"Unsupported device type: {device}")

    def _dataset_from_sequences(self, sequences: List[str]):
        """Wrap raw sequences in a DNADataset bound to this predictor's tokenizer."""
        ds = Dataset.from_dict({"sequence": sequences})
        return DNADataset(ds, self.tokenizer, max_length=self.pred_config.max_length)

    def generate_dataset(self, seq_or_path: Union[str, List[str]], batch_size: int=1,
                         seq_col: str="sequence", label_col: str="labels",
                         sep: str = None, fasta_sep: str = "|",
                         multi_label_sep: Union[str, None] = None,
                         uppercase: bool=False, lowercase: bool=False,
                         keep_seqs: bool=True, do_encode: bool=True) -> tuple:
        """Generate dataset from sequences.

        Also refreshes ``self.sequences`` and ``self.labels`` for the new
        dataset, so state from a previous call never leaks into this one.

        Args:
            seq_or_path: Single sequence, list of sequences, or path to a file containing sequences.
            batch_size: Batch size for DataLoader.
            seq_col: Column name for sequences.
            label_col: Column name for labels.
            sep (str, optional): Delimiter for CSV, TSV, or TXT files.
            fasta_sep (str, optional): Delimiter for FASTA files.
            multi_label_sep (str, optional): Delimiter for multi-label sequences.
            uppercase (bool): Whether to convert sequences to uppercase.
            lowercase (bool): Whether to convert sequences to lowercase.
            keep_seqs: Whether to keep the raw sequences in ``self.sequences``.
            do_encode: Whether to encode sequences.

        Returns:
            tuple: A tuple containing:
                - Dataset object
                - DataLoader object

        Raises:
            ValueError: If input is neither a file path nor a list of sequences,
                or if an empty list is given.
        """
        if isinstance(seq_or_path, str):
            suffix = seq_or_path.split(".")[-1]
            if suffix and os.path.isfile(seq_or_path):
                # Existing file: parse it with DNADataset.load_local_data.
                dataset = DNADataset.load_local_data(seq_or_path, seq_col=seq_col, label_col=label_col,
                                                     sep=sep, fasta_sep=fasta_sep, multi_label_sep=multi_label_sep,
                                                     tokenizer=self.tokenizer, max_length=self.pred_config.max_length)
            else:
                # Any other string is treated as a single raw sequence.
                dataset = self._dataset_from_sequences([seq_or_path])
        elif isinstance(seq_or_path, list):
            if not seq_or_path:
                raise ValueError("Input sequence list is empty.")
            dataset = self._dataset_from_sequences(seq_or_path)
        else:
            raise ValueError("Input should be a file path or a list of sequences.")
        # Capture sequences before encoding, which runs with remove_unused_columns=True.
        self.sequences = dataset.dataset["sequence"] if keep_seqs else []
        if do_encode:
            task_type = self.task_config.task_type
            dataset.encode_sequences(remove_unused_columns=True, task=task_type, uppercase=uppercase, lowercase=lowercase)
        # Keep labels only if this dataset provides them; otherwise clear stale ones.
        self.labels = dataset.dataset["labels"] if "labels" in dataset.dataset.features else []
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=self.pred_config.num_workers
        )
        return dataset, dataloader

    def logits_to_preds(self, logits: torch.Tensor) -> tuple[torch.Tensor, list]:
        """Convert model logits to scores and human-readable labels.

        Args:
            logits: Model output logits.

        Returns:
            tuple: A tuple containing:
                - torch.Tensor: Per-class probabilities (softmax for binary,
                  multiclass and token tasks; sigmoid for multilabel). For
                  regression, the raw outputs squeezed on the last axis.
                - list: Human-readable labels (label names; lists of names for
                  multilabel/token tasks; raw values for regression).

        Raises:
            ValueError: If task type is not supported.
        """
        task_type = self.task_config.task_type
        threshold = self.task_config.threshold
        label_names = self.task_config.label_names
        if task_type == "binary":
            probs = torch.softmax(logits, dim=-1)
            preds = (probs[:, 1] > threshold).long()
            labels = [label_names[p] for p in preds.tolist()]
        elif task_type == "multiclass":
            probs = torch.softmax(logits, dim=-1)
            preds = torch.argmax(probs, dim=-1)
            labels = [label_names[p] for p in preds.tolist()]
        elif task_type == "multilabel":
            probs = torch.sigmoid(logits)
            preds = (probs > threshold).long()
            labels = [
                [label_names[i] for i, flag in enumerate(row) if flag == 1]
                for row in preds.tolist()
            ]
        elif task_type == "regression":
            preds = logits.squeeze(-1)
            probs = preds
            labels = preds.tolist()
        elif task_type == "token":
            probs = torch.softmax(logits, dim=-1)
            preds = torch.argmax(logits, dim=-1)
            labels = [[label_names[i] for i in row] for row in preds.tolist()]
        else:
            raise ValueError(f"Unsupported task type: {task_type}")
        return probs, labels

    def format_output(self, predictions: tuple[torch.Tensor, list]) -> dict:
        """Format output predictions.

        Args:
            predictions: ``(probs, labels)`` tuple as returned by ``logits_to_preds``.

        Returns:
            dict: Maps each sample index to a dict with keys ``sequence`` (empty
            string when no sequences were kept), ``label`` and ``scores``.
            ``scores`` is a ``{label_name: prob}`` dict for per-class outputs, a
            list of per-token max probabilities for token tasks, or the raw value
            for single-output regression.
        """
        task_type = self.task_config.task_type
        label_names = self.task_config.label_names
        probs, labels = predictions
        probs = probs.tolist()
        keep_seqs = len(self.sequences) > 0
        formatted_predictions = {}
        for i, label in enumerate(labels):
            prob = probs[i]
            if task_type == "token":
                scores = [max(x) for x in prob]
            elif isinstance(prob, list):
                scores = {label_names[j]: p for j, p in enumerate(prob)}
            else:
                # Single-output regression yields one float per sample.
                scores = prob
            formatted_predictions[i] = {
                'sequence': self.sequences[i] if keep_seqs else '',
                'label': label,
                'scores': scores,
            }
        return formatted_predictions

    @torch.no_grad()
    def batch_predict(self, dataloader: DataLoader, do_pred: bool=True,
                      output_hidden_states: bool=False,
                      output_attentions: bool=False) -> tuple[torch.Tensor, Optional[dict], dict]:
        """Predict for a batch of sequences.

        Args:
            dataloader: DataLoader object containing sequences.
            do_pred: Whether to convert logits into formatted predictions.
            output_hidden_states: Whether to collect hidden states (plus the
                attention mask and labels of each batch, when present).
            output_attentions: Whether to collect attentions.

        Returns:
            tuple: A tuple containing:
                - torch.Tensor: All logits, concatenated along the batch axis
                - dict or None: Predictions dictionary (None if ``do_pred`` is False)
                - dict: Embeddings dictionary (empty unless hidden states or
                  attentions were requested)
        """
        self.model.eval()
        # Request extra outputs per call (only when forward() accepts the flag)
        # instead of mutating model.config, so later calls are unaffected.
        forward_kwargs = {}
        embeddings = {}
        if output_hidden_states or output_attentions:
            import inspect
            params = inspect.signature(self.model.forward).parameters
            if output_hidden_states:
                if 'output_hidden_states' in params:
                    forward_kwargs['output_hidden_states'] = True
                embeddings['hidden_states'] = None
                embeddings['attention_mask'] = []
                embeddings['labels'] = []
            if output_attentions:
                if 'output_attentions' in params:
                    forward_kwargs['output_attentions'] = True
                embeddings['attentions'] = None

        def _stash(buffers, layer_outputs):
            # One list per layer; each list gathers that layer's CPU tensors across batches.
            cpu_layers = [t.cpu().detach() for t in layer_outputs]
            if buffers is None:
                return [[t] for t in cpu_layers]
            for buf, t in zip(buffers, cpu_layers):
                buf.append(t)
            return buffers

        use_fp16 = self.pred_config.use_fp16
        if use_fp16:
            # Cast the weights once up front rather than on every batch.
            self.model = self.model.half()
        all_logits = []
        for batch in tqdm(dataloader, desc="Predicting"):
            inputs = {k: v.to(self.device) for k, v in batch.items()}
            if use_fp16:
                with torch.amp.autocast('cuda'):
                    outputs = self.model(**inputs, **forward_kwargs)
            else:
                outputs = self.model(**inputs, **forward_kwargs)
            all_logits.append(outputs.logits.cpu().detach())
            if output_hidden_states:
                # Outputs may carry hidden_states=None; skip instead of crashing.
                hidden_states = getattr(outputs, 'hidden_states', None)
                if hidden_states is not None:
                    embeddings['hidden_states'] = _stash(embeddings['hidden_states'], hidden_states)
                # Unlabeled inputs have no 'labels'; only real tensors are collected.
                for key in ('attention_mask', 'labels'):
                    value = inputs.get(key)
                    if value is not None:
                        embeddings[key].append(value.cpu().detach())
            if output_attentions:
                attentions = getattr(outputs, 'attentions', None)
                if attentions is not None and len(attentions) > 0:
                    embeddings['attentions'] = _stash(embeddings['attentions'], attentions)
        all_logits = torch.cat(all_logits, dim=0)
        if output_hidden_states:
            if embeddings['hidden_states']:
                embeddings['hidden_states'] = tuple(torch.cat(lst, dim=0) for lst in embeddings['hidden_states'])
            for key in ('attention_mask', 'labels'):
                embeddings[key] = torch.cat(embeddings[key], dim=0) if embeddings[key] else None
        if output_attentions and embeddings['attentions']:
            embeddings['attentions'] = tuple(torch.cat(lst, dim=0) for lst in embeddings['attentions'])
        predictions = None
        if do_pred:
            predictions = self.format_output(self.logits_to_preds(all_logits))
        return all_logits, predictions, embeddings

    @staticmethod
    def _metrics_for_saving(metrics: dict) -> dict:
        """Return a copy of ``metrics`` without its 'curve' and 'scatter' entries."""
        # NOTE(review): these entries look like plot data rather than scalar
        # metrics — confirm against Metrics.
        return {k: v for k, v in metrics.items() if k not in ('curve', 'scatter')}

    def predict_seqs(self, sequences: Union[str, List[str]],
                     evaluate: bool = False,
                     output_hidden_states: bool=False,
                     output_attentions: bool=False,
                     save_to_file: bool = False) -> Union[tuple, dict]:
        """Predict for sequences.

        Args:
            sequences: Single sequence or list of sequences.
            evaluate: Whether to evaluate the predictions (requires labels that
                match the number of predictions; otherwise a warning is issued).
            output_hidden_states: Whether to collect hidden states.
            output_attentions: Whether to collect attentions.
            save_to_file: Whether to save predictions (and metrics) to
                ``pred_config.output_dir``.

        Returns:
            Union[tuple, dict]: Either:
                - Dictionary containing predictions
                - Tuple of (predictions, metrics) if evaluation was performed
        """
        _, dataloader = self.generate_dataset(sequences, batch_size=self.pred_config.batch_size)
        logits, predictions, embeddings = self.batch_predict(dataloader,
                                                             output_hidden_states=output_hidden_states,
                                                             output_attentions=output_attentions)
        if output_hidden_states or output_attentions:
            self.embeddings = embeddings
        if save_to_file and self.pred_config.output_dir:
            save_predictions(predictions, Path(self.pred_config.output_dir))
        if not evaluate:
            return predictions
        if len(self.labels) != len(logits):
            warnings.warn("Labels are missing or do not match the number of predictions; skipping evaluation.")
            return predictions
        metrics = self.calculate_metrics(logits, self.labels)
        if save_to_file and self.pred_config.output_dir:
            save_metrics(self._metrics_for_saving(metrics), Path(self.pred_config.output_dir))
        return predictions, metrics

    def predict_file(self, file_path: str, evaluate: bool = False,
                     output_hidden_states: bool=False,
                     output_attentions: bool=False,
                     seq_col: str="sequence", label_col: str="labels",
                     sep: str = None, fasta_sep: str = "|",
                     multi_label_sep: Union[str, None] = None,
                     uppercase: bool=False, lowercase: bool=False,
                     save_to_file: bool=False, plot_metrics: bool=False) -> Union[tuple, dict]:
        """Predict from a file containing sequences.

        Args:
            file_path: Path to the file containing sequences.
            evaluate: Whether to evaluate the predictions (requires labels that
                match the number of predictions; otherwise a warning is issued).
            output_hidden_states: Whether to collect hidden states.
            output_attentions: Whether to collect attentions.
            seq_col: Column name for sequences.
            label_col: Column name for labels.
            sep (str, optional): Delimiter for CSV, TSV, or TXT files.
            fasta_sep (str, optional): Delimiter for FASTA files.
            multi_label_sep (str, optional): Delimiter for multi-label sequences.
            uppercase (bool): Whether to convert sequences to uppercase.
            lowercase (bool): Whether to convert sequences to lowercase.
            save_to_file: Whether to save predictions (and metrics) to
                ``pred_config.output_dir``.
            plot_metrics: Whether to plot metrics; when True the returned
                metrics keep their 'curve'/'scatter' entries.

        Returns:
            Union[tuple, dict]: Either:
                - Dictionary containing predictions
                - Tuple of (predictions, metrics) if evaluation was performed
        """
        _, dataloader = self.generate_dataset(file_path, seq_col=seq_col, label_col=label_col,
                                              sep=sep, fasta_sep=fasta_sep, multi_label_sep=multi_label_sep,
                                              uppercase=uppercase, lowercase=lowercase,
                                              batch_size=self.pred_config.batch_size)
        if output_attentions:
            warnings.warn("Cautions: output_attentions may consume a lot of memory.\n")
        logits, predictions, embeddings = self.batch_predict(dataloader,
                                                             output_hidden_states=output_hidden_states,
                                                             output_attentions=output_attentions)
        if output_hidden_states or output_attentions:
            self.embeddings = embeddings
        if save_to_file and self.pred_config.output_dir:
            save_predictions(predictions, Path(self.pred_config.output_dir))
        if not evaluate:
            return predictions
        if len(self.labels) != len(logits):
            warnings.warn("Labels are missing or do not match the number of predictions; skipping evaluation.")
            return predictions
        metrics = self.calculate_metrics(logits, self.labels, plot=plot_metrics)
        metrics_save = self._metrics_for_saving(metrics)
        if save_to_file and self.pred_config.output_dir:
            # Save the stripped copy, consistent with predict_seqs.
            save_metrics(metrics_save, Path(self.pred_config.output_dir))
        return (predictions, metrics) if plot_metrics else (predictions, metrics_save)

    def calculate_metrics(self, logits: Union[List, torch.Tensor],
                          labels: Union[List, torch.Tensor], plot: bool=False) -> dict:
        """Calculate evaluation metrics.

        Args:
            logits: Model output logits.
            labels: True labels.
            plot: Forwarded to ``Metrics`` as its ``plot`` option.

        Returns:
            dict: Dictionary containing evaluation metrics.
        """
        compute_metrics = Metrics(self.task_config, plot=plot)
        return compute_metrics((logits, labels))

    @staticmethod
    def _derive_plot_path(save_path: Optional[str], tag: str, default_name: str) -> Optional[str]:
        """Build an output path for a plot from a user-supplied ``save_path``.

        ``'out/plot.pdf'`` becomes ``'out/plot<tag>.pdf'``. A path without an
        extension is treated as a directory, joined with ``default_name``. A
        falsy ``save_path`` yields None.
        """
        if not save_path:
            return None
        root, suffix = os.path.splitext(save_path)
        if suffix:
            # Splice the tag before the extension only (not every occurrence of it).
            return root + tag + suffix
        return os.path.join(save_path, default_name)

    def plot_attentions(self, seq_idx: int = 0, layer: int = -1, head: int = -1,
                        width: int = 800, height: int = 800,
                        save_path: Optional[str] = None) -> object:
        """Plot attention map.

        Args:
            seq_idx: Index of the sequence to plot.
            layer: Layer index to plot.
            head: Head index to plot.
            width: Width of the plot.
            height: Height of the plot.
            save_path: Path to save the plot.

        Returns:
            None: If no attention weights are available.
            object: Attention map visualization if available.
        """
        attentions = getattr(self, 'embeddings', {}).get('attentions')
        if attentions is None:
            print("No attention weights available to plot.")
            return None
        heatmap = self._derive_plot_path(save_path, "_heatmap", "heatmap.pdf")
        return plot_attention_map(attentions, self.sequences, self.tokenizer,
                                  seq_idx=seq_idx, layer=layer, head=head,
                                  width=width, height=height,
                                  save_path=heatmap)

    def plot_hidden_states(self, reducer: str="t-SNE",
                           ncols: int=4, width: int = 300, height: int = 300,
                           save_path: Optional[str] = None) -> object:
        """Embedding visualization.

        Args:
            reducer: Dimensionality reduction method to use.
            ncols: Number of columns in the plot grid.
            width: Width of the plot.
            height: Height of the plot.
            save_path: Path to save the plot.

        Returns:
            None: If no hidden states are available.
            object: Embedding visualization if available.
        """
        embeddings = getattr(self, 'embeddings', {})
        hidden_states = embeddings.get('hidden_states')
        if hidden_states is None:
            print("No hidden states available to plot.")
            return None
        attention_mask = embeddings.get('attention_mask')
        if attention_mask is not None:
            attention_mask = torch.unsqueeze(attention_mask, dim=-1)
        # NOTE(review): attention_mask/labels may be None when the inputs had
        # none; confirm plot_embeddings accepts that.
        labels = embeddings.get('labels')
        embedding = self._derive_plot_path(save_path, "_embedding", "embedding.pdf")
        label_names = self.task_config.label_names
        return plot_embeddings(hidden_states, attention_mask, reducer=reducer,
                               labels=labels, label_names=label_names,
                               ncols=ncols, width=width, height=height,
                               save_path=embedding)

`__init__(model, tokenizer, config)`

Initialize the predictor.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `any` | Fine-tuned model instance. | *required* |
| `tokenizer` | `any` | Tokenizer for the model. | *required* |
| `config` | `dict` | Configuration dictionary containing task settings and inference parameters. | *required* |

Source code in `dnallm/inference/predictor.py` (lines 62–85):
def __init__(
    self,
    model: any,
    tokenizer: any,
    config: dict
):
    """Initialize the predictor.

    Args:
        model: Fine-tuned model instance.
        tokenizer: Tokenizer for the model.
        config: Configuration dictionary containing task settings and inference parameters.
    """

    self.model = model
    self.tokenizer = tokenizer
    self.task_config = config['task']
    self.pred_config = config['inference']
    self.device = self._get_device()
    if model:
        self.model.to(self.device)
        print(f"Use device: {self.device}")
    self.sequences = []
    self.labels = []

`batch_predict(dataloader, do_pred=True, output_hidden_states=False, output_attentions=False)`

Predict for a batch of sequences.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dataloader` | `DataLoader` | DataLoader object containing sequences. | *required* |
| `do_pred` | `bool` | Whether to do prediction. | `True` |
| `output_hidden_states` | `bool` | Whether to output hidden states. | `False` |
| `output_attentions` | `bool` | Whether to output attentions. | `False` |

Returns:

| Type | Description |
|---|---|
| `tuple` | A 3-tuple: the concatenated logits (`torch.Tensor`), the predictions dictionary (`None` when `do_pred=False`), and the embeddings dictionary. |

Source code in `dnallm/inference/predictor.py` (lines 281–372):
@torch.no_grad()
def batch_predict(self, dataloader: DataLoader, do_pred: bool=True,
                  output_hidden_states: bool=False,
                  output_attentions: bool=False) -> tuple[torch.Tensor, Optional[dict], dict]:
    """Predict for a batch of sequences.

    Args:
        dataloader: DataLoader object containing sequences.
        do_pred: Whether to convert logits into formatted predictions.
        output_hidden_states: Whether to collect hidden states (plus the
            attention mask and labels of each batch, when present).
        output_attentions: Whether to collect attentions.

    Returns:
        tuple: A tuple containing:
            - torch.Tensor: All logits, concatenated along the batch axis
            - dict or None: Predictions dictionary (None if ``do_pred`` is False)
            - dict: Embeddings dictionary (empty unless hidden states or
              attentions were requested)
    """
    self.model.eval()
    # Request extra outputs per call (only when forward() accepts the flag)
    # instead of mutating model.config, so later calls are unaffected.
    forward_kwargs = {}
    embeddings = {}
    if output_hidden_states or output_attentions:
        import inspect
        params = inspect.signature(self.model.forward).parameters
        if output_hidden_states:
            if 'output_hidden_states' in params:
                forward_kwargs['output_hidden_states'] = True
            embeddings['hidden_states'] = None
            embeddings['attention_mask'] = []
            embeddings['labels'] = []
        if output_attentions:
            if 'output_attentions' in params:
                forward_kwargs['output_attentions'] = True
            embeddings['attentions'] = None

    def _stash(buffers, layer_outputs):
        # One list per layer; each list gathers that layer's CPU tensors across batches.
        cpu_layers = [t.cpu().detach() for t in layer_outputs]
        if buffers is None:
            return [[t] for t in cpu_layers]
        for buf, t in zip(buffers, cpu_layers):
            buf.append(t)
        return buffers

    use_fp16 = self.pred_config.use_fp16
    if use_fp16:
        # Cast the weights once up front rather than on every batch.
        self.model = self.model.half()
    all_logits = []
    for batch in tqdm(dataloader, desc="Predicting"):
        inputs = {k: v.to(self.device) for k, v in batch.items()}
        if use_fp16:
            with torch.amp.autocast('cuda'):
                outputs = self.model(**inputs, **forward_kwargs)
        else:
            outputs = self.model(**inputs, **forward_kwargs)
        all_logits.append(outputs.logits.cpu().detach())
        if output_hidden_states:
            # Outputs may carry hidden_states=None; skip instead of crashing.
            hidden_states = getattr(outputs, 'hidden_states', None)
            if hidden_states is not None:
                embeddings['hidden_states'] = _stash(embeddings['hidden_states'], hidden_states)
            # Unlabeled inputs have no 'labels'; only real tensors are collected.
            for key in ('attention_mask', 'labels'):
                value = inputs.get(key)
                if value is not None:
                    embeddings[key].append(value.cpu().detach())
        if output_attentions:
            attentions = getattr(outputs, 'attentions', None)
            if attentions is not None and len(attentions) > 0:
                embeddings['attentions'] = _stash(embeddings['attentions'], attentions)
    all_logits = torch.cat(all_logits, dim=0)
    if output_hidden_states:
        if embeddings['hidden_states']:
            embeddings['hidden_states'] = tuple(torch.cat(lst, dim=0) for lst in embeddings['hidden_states'])
        for key in ('attention_mask', 'labels'):
            embeddings[key] = torch.cat(embeddings[key], dim=0) if embeddings[key] else None
    if output_attentions and embeddings['attentions']:
        embeddings['attentions'] = tuple(torch.cat(lst, dim=0) for lst in embeddings['attentions'])
    predictions = None
    if do_pred:
        predictions = self.format_output(self.logits_to_preds(all_logits))
    return all_logits, predictions, embeddings

`calculate_metrics(logits, labels, plot=False)`

Calculate evaluation metrics.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `logits` | `Union[List, Tensor]` | Model output logits. | *required* |
| `labels` | `Union[List, Tensor]` | True labels. | *required* |
| `plot` | `bool` | Whether to plot metrics. | `False` |

Returns:

| Type | Description |
|---|---|
| `dict` | Dictionary containing evaluation metrics. |

Source code in `dnallm/inference/predictor.py` (lines 485–501):
def calculate_metrics(self, logits: Union[List, torch.Tensor],
                      labels: Union[List, torch.Tensor], plot: bool=False) -> dict:
    """Compute evaluation metrics for a batch of predictions.

    A ``Metrics`` object is built from ``self.task_config`` and applied to the
    ``(logits, labels)`` pair.

    Args:
        logits: Model predictions.
        labels: True labels.
        plot: Forwarded to ``Metrics`` as its ``plot`` flag.

    Returns:
        dict: Dictionary containing evaluation metrics.
    """
    scorer = Metrics(self.task_config, plot=plot)
    return scorer((logits, labels))

format_output(predictions)

Format output predictions.

Parameters:

Name Type Description Default
predictions tuple[Tensor, list]

Tuple containing predictions.

required

Returns:

Name Type Description
dict dict

Dictionary containing formatted predictions.

Source code in dnallm/inference/predictor.py
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
def format_output(self, predictions: tuple[torch.Tensor, list]) -> dict:
    """Format output predictions into a per-sample dictionary.

    Args:
        predictions: ``(probs, labels)`` tuple as produced by
            ``logits_to_preds``: per-sample scores and the corresponding
            human-readable labels.

    Returns:
        dict: Mapping from sample index to a dict with keys ``'sequence'``
        (the input sequence, or ``''`` when no sequences were kept),
        ``'label'`` and ``'scores'``. For ``token`` tasks ``'scores'`` is the
        per-token maximum probability; for every other task it maps each
        label name (or the column index when no name is configured) to its
        score.
    """
    task_type = self.task_config.task_type
    label_names = self.task_config.label_names or []
    probs, labels = predictions
    # .tolist() also works for CUDA tensors and tensors that require grad,
    # where .numpy() would raise.
    probs = probs.detach().cpu().tolist()
    keep_seqs = len(self.sequences) > 0

    def _name(j):
        # Column j's label name, falling back to the index itself.
        return label_names[j] if j < len(label_names) else j

    formatted_predictions = {}
    for i, label in enumerate(labels):
        prob = probs[i]
        if task_type == "token":
            scores = [max(x) for x in prob]
        else:
            # Single-output regression logits are squeezed to one float per
            # sample, so wrap the scalar to key it like the multi-score case.
            if not isinstance(prob, list):
                prob = [prob]
            scores = {_name(j): p for j, p in enumerate(prob)}
        formatted_predictions[i] = {
            'sequence': self.sequences[i] if keep_seqs else '',
            'label': label,
            'scores': scores,
        }
    return formatted_predictions

generate_dataset(seq_or_path, batch_size=1, seq_col='sequence', label_col='labels', sep=None, fasta_sep='|', multi_label_sep=None, uppercase=False, lowercase=False, keep_seqs=True, do_encode=True)

Generate dataset from sequences.

Parameters:

Name Type Description Default
seq_or_path Union[str, List[str]]

Single sequence or path to a file containing sequences.

required
batch_size int

Batch size for DataLoader.

1
seq_col str

Column name for sequences.

'sequence'
label_col str

Column name for labels.

'labels'
sep str

Delimiter for CSV, TSV, or TXT files.

None
fasta_sep str

Delimiter for FASTA files.

'|'
multi_label_sep str

Delimiter for multi-label sequences.

None
uppercase bool

Whether to convert sequences to uppercase.

False
lowercase bool

Whether to convert sequences to lowercase.

False
keep_seqs bool

Whether to keep sequences in the dataset.

True
do_encode bool

Whether to encode sequences.

True

Returns:

Name Type Description
tuple tuple

A tuple containing: - Dataset object - DataLoader object

Raises:

Type Description
ValueError

If input is neither a file path nor a list of sequences.

Source code in dnallm/inference/predictor.py
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
def generate_dataset(self, seq_or_path: Union[str, List[str]], batch_size: int=1,
                     seq_col: str="sequence", label_col: str="labels",
                     sep: str = None, fasta_sep: str = "|",
                     multi_label_sep: Union[str, None] = None,
                     uppercase: bool=False, lowercase: bool=False,
                     keep_seqs: bool=True, do_encode: bool=True) -> tuple:
    """Generate dataset from sequences.

    A string that is the path of an existing file is loaded with
    ``DNADataset.load_local_data``. Any other string is treated as a single
    sequence, and a list is treated as a list of sequences.

    Side effects: when ``keep_seqs`` is True, ``self.sequences`` is set to the
    dataset's ``sequence`` column. ``self.labels`` is always set: to the
    ``labels`` column after encoding, or to ``[]`` when there is none, so
    labels from a previous call are never reused.

    Args:
        seq_or_path: Single sequence or path to a file containing sequences.
        batch_size: Batch size for DataLoader.
        seq_col: Column name for sequences.
        label_col: Column name for labels.
        sep (str, optional): Delimiter for CSV, TSV, or TXT files.
        fasta_sep (str, optional): Delimiter for FASTA files.
        multi_label_sep (str, optional): Delimiter for multi-label sequences.
        uppercase (bool): Whether to convert sequences to uppercase.
        lowercase (bool): Whether to convert sequences to lowercase.
        keep_seqs: Whether to keep sequences in the dataset.
        do_encode: Whether to encode sequences.

    Returns:
        tuple: A tuple containing:
            - Dataset object
            - DataLoader object

    Raises:
        ValueError: If input is neither a file path nor a list of sequences,
            or if no sequences were provided.
    """
    dataset = None
    if isinstance(seq_or_path, str):
        # A path whose text after the last '.' is empty is treated as a sequence.
        suffix = seq_or_path.split(".")[-1]
        if suffix and os.path.isfile(seq_or_path):
            dataset = DNADataset.load_local_data(seq_or_path, seq_col=seq_col, label_col=label_col,
                                                 sep=sep, fasta_sep=fasta_sep, multi_label_sep=multi_label_sep,
                                                 tokenizer=self.tokenizer, max_length=self.pred_config.max_length)
        else:
            sequences = [seq_or_path]
    elif isinstance(seq_or_path, list):
        sequences = seq_or_path
    else:
        raise ValueError("Input should be a file path or a list of sequences.")
    if dataset is None:
        # Without this guard an empty list leaves `dataset` unbound.
        if len(sequences) == 0:
            raise ValueError("No sequences provided for prediction.")
        ds = Dataset.from_dict({"sequence": sequences})
        dataset = DNADataset(ds, self.tokenizer, max_length=self.pred_config.max_length)
    # Keep the raw sequences so predictions can be reported next to them
    if keep_seqs:
        self.sequences = dataset.dataset["sequence"]
    # Encode sequences
    if do_encode:
        task_type = self.task_config.task_type
        dataset.encode_sequences(remove_unused_columns=True, task=task_type, uppercase=uppercase, lowercase=lowercase)
    # Keep labels for evaluation; reset them when absent so evaluation never
    # pairs these predictions with labels left over from an earlier call.
    if "labels" in dataset.dataset.features:
        self.labels = dataset.dataset["labels"]
    else:
        self.labels = []
    # Create DataLoader
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=self.pred_config.num_workers
    )

    return dataset, dataloader

logits_to_preds(logits)

Convert model logits to predictions.

Parameters:

Name Type Description Default
logits list

Model output logits.

required

Returns:

Name Type Description
tuple tuple[Tensor, list]

A tuple containing: - torch.Tensor: Model predictions - list: Human-readable labels

Raises:

Type Description
ValueError

If task type is not supported.

Source code in dnallm/inference/predictor.py
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
def logits_to_preds(self, logits: torch.Tensor) -> tuple[torch.Tensor, list]:
    """Convert model logits to predictions.

    Args:
        logits: Model output logits as a tensor.

    Returns:
        tuple: A tuple containing:
            - torch.Tensor: Per-sample scores (softmax/sigmoid probabilities,
              or the squeezed raw outputs for regression)
            - list: Human-readable labels

    Raises:
        ValueError: If task type is not supported.
    """
    task_type = self.task_config.task_type
    threshold = self.task_config.threshold
    label_names = self.task_config.label_names
    # Predicted class indices are converted to Python ints once via .tolist()
    # instead of indexing label_names with 0-d tensors element by element.
    if task_type == "binary":
        probs = torch.softmax(logits, dim=-1)
        # Column 1 is the positive class; compare its probability to the threshold.
        pred_ids = (probs[:, 1] > threshold).long().tolist()
        labels = [label_names[k] for k in pred_ids]
    elif task_type == "multiclass":
        probs = torch.softmax(logits, dim=-1)
        pred_ids = torch.argmax(probs, dim=-1).tolist()
        labels = [label_names[k] for k in pred_ids]
    elif task_type == "multilabel":
        # Independent sigmoid per label; keep every label above the threshold.
        probs = torch.sigmoid(logits)
        flags = (probs > threshold).tolist()
        labels = [[label_names[j] for j, on in enumerate(row) if on] for row in flags]
    elif task_type == "regression":
        preds = logits.squeeze(-1)
        probs = preds
        labels = preds.tolist()
    elif task_type == "token":
        probs = torch.softmax(logits, dim=-1)
        pred_ids = torch.argmax(logits, dim=-1).tolist()
        labels = [[label_names[k] for k in row] for row in pred_ids]
    else:
        raise ValueError(f"Unsupported task type: {task_type}")
    return probs, labels

plot_attentions(seq_idx=0, layer=-1, head=-1, width=800, height=800, save_path=None)

Plot attention map.

Parameters:

Name Type Description Default
seq_idx int

Index of the sequence to plot.

0
layer int

Layer index to plot.

-1
head int

Head index to plot.

-1
width int

Width of the plot.

800
height int

Height of the plot.

800
save_path Optional[str]

Path to save the plot.

None

Returns:

Name Type Description
None None

If no attention weights are available.

object None

Attention map visualization if available.

Source code in dnallm/inference/predictor.py
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
def plot_attentions(self, seq_idx: int = 0, layer: int = -1, head: int = -1,
                    width: int = 800, height: int = 800,
                    save_path: Optional[str] = None) -> Optional[object]:
    """Plot attention map.

    Uses ``self.embeddings['attentions']`` stored by a previous prediction
    call made with ``output_attentions=True``.

    Args:
        seq_idx: Index of the sequence to plot.
        layer: Layer index to plot.
        head: Head index to plot.
        width: Width of the plot.
        height: Height of the plot.
        save_path: Path to save the plot. A path with an extension is saved as
            ``<name>_heatmap<ext>``. A path without an extension is treated as a
            directory, and the plot is saved as ``heatmap.pdf`` inside it.

    Returns:
        None: If no attention weights are available.
        object: Attention map visualization if available.
    """
    embeddings = getattr(self, 'embeddings', None)
    attentions = embeddings.get('attentions') if embeddings else None
    # Embeddings may exist with only hidden states stored (attentions None).
    if attentions is None or len(attentions) == 0:
        print("No attention weights available to plot.")
        return None
    if save_path:
        root, suffix = os.path.splitext(save_path)
        if suffix:
            # Rewrite only the final extension; str.replace would also hit any
            # earlier occurrence of the same text elsewhere in the path.
            heatmap = root + "_heatmap" + suffix
        else:
            heatmap = os.path.join(save_path, "heatmap.pdf")
    else:
        heatmap = None
    # Plot attention map
    attn_map = plot_attention_map(attentions, self.sequences, self.tokenizer,
                                  seq_idx=seq_idx, layer=layer, head=head,
                                  width=width, height=height,
                                  save_path=heatmap)
    return attn_map

plot_hidden_states(reducer='t-SNE', ncols=4, width=300, height=300, save_path=None)

Embedding visualization.

Parameters:

Name Type Description Default
reducer str

Dimensionality reduction method to use.

't-SNE'
ncols int

Number of columns in the plot grid.

4
width int

Width of the plot.

300
height int

Height of the plot.

300
save_path Optional[str]

Path to save the plot.

None

Returns:

Name Type Description
None None

If no hidden states are available.

object None

Embedding visualization if available.

Source code in dnallm/inference/predictor.py
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
def plot_hidden_states(self, reducer: str="t-SNE",
                       ncols: int=4, width: int = 300, height: int = 300,
                       save_path: Optional[str] = None) -> Optional[object]:
    """Embedding visualization.

    Uses ``self.embeddings`` (``hidden_states``, ``attention_mask`` and
    ``labels``) stored by a previous prediction call made with
    ``output_hidden_states=True``.

    Args:
        reducer: Dimensionality reduction method to use.
        ncols: Number of columns in the plot grid.
        width: Width of the plot.
        height: Height of the plot.
        save_path: Path to save the plot. A path with an extension is saved as
            ``<name>_embedding<ext>``. A path without an extension is treated
            as a directory, and the plot is saved as ``embedding.pdf`` inside it.

    Returns:
        None: If no hidden states are available.
        object: Embedding visualization if available.
    """
    embeddings = getattr(self, 'embeddings', None)
    hidden_states = embeddings.get('hidden_states') if embeddings else None
    # Embeddings may exist with only attentions stored (hidden states None).
    if hidden_states is None or len(hidden_states) == 0:
        print("No hidden states available to plot.")
        return None
    # Trailing axis lets the mask be multiplied element-wise with the hidden
    # states inside plot_embeddings.
    attention_mask = torch.unsqueeze(self.embeddings['attention_mask'], dim=-1)
    labels = self.embeddings['labels']
    if save_path:
        root, suffix = os.path.splitext(save_path)
        if suffix:
            # Rewrite only the final extension, not earlier matches in the path.
            embedding = root + "_embedding" + suffix
        else:
            embedding = os.path.join(save_path, "embedding.pdf")
    else:
        embedding = None
    # Plot hidden states
    label_names = self.task_config.label_names
    embeddings_vis = plot_embeddings(hidden_states, attention_mask, reducer=reducer,
                                     labels=labels, label_names=label_names,
                                     ncols=ncols, width=width, height=height,
                                     save_path=embedding)
    return embeddings_vis

predict_file(file_path, evaluate=False, output_hidden_states=False, output_attentions=False, seq_col='sequence', label_col='labels', sep=None, fasta_sep='|', multi_label_sep=None, uppercase=False, lowercase=False, save_to_file=False, plot_metrics=False)

Predict from a file containing sequences.

Parameters:

Name Type Description Default
file_path str

Path to the file containing sequences.

required
evaluate bool

Whether to evaluate the predictions.

False
output_hidden_states bool

Whether to output hidden states.

False
output_attentions bool

Whether to output attentions.

False
seq_col str

Column name for sequences.

'sequence'
label_col str

Column name for labels.

'labels'
sep str

Delimiter for CSV, TSV, or TXT files.

None
fasta_sep str

Delimiter for FASTA files.

'|'
multi_label_sep str

Delimiter for multi-label sequences.

None
uppercase bool

Whether to convert sequences to uppercase.

False
lowercase bool

Whether to convert sequences to lowercase.

False
save_to_file bool

Whether to save predictions to file.

False
plot_metrics bool

Whether to plot metrics.

False

Returns:

Type Description
Union[tuple, dict]

Union[tuple, dict]: Either: - Dictionary of predictions keyed by sample index - Tuple of (predictions, metrics) if evaluate=True

Source code in dnallm/inference/predictor.py
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
def predict_file(self, file_path: str, evaluate: bool = False,
                 output_hidden_states: bool=False,
                 output_attentions: bool=False,
                 seq_col: str="sequence", label_col: str="labels",
                 sep: str = None, fasta_sep: str = "|",
                 multi_label_sep: Union[str, None] = None,
                 uppercase: bool=False, lowercase: bool=False,
                 save_to_file: bool=False, plot_metrics: bool=False) -> Union[tuple, dict]:
    """Predict from a file containing sequences.

    Args:
        file_path: Path to the file containing sequences.
        evaluate: Whether to evaluate the predictions against the labels
            loaded from the file.
        output_hidden_states: Whether to output hidden states.
        output_attentions: Whether to output attentions.
        seq_col: Column name for sequences.
        label_col: Column name for labels.
        sep (str, optional): Delimiter for CSV, TSV, or TXT files.
        fasta_sep (str, optional): Delimiter for FASTA files.
        multi_label_sep (str, optional): Delimiter for multi-label sequences.
        uppercase (bool): Whether to convert sequences to uppercase.
        lowercase (bool): Whether to convert sequences to lowercase.
        save_to_file: Whether to save predictions (and metrics) to
            ``pred_config.output_dir``.
        plot_metrics: Whether to plot metrics. When True, the returned metrics
            keep their ``'curve'``/``'scatter'`` entries.

    Returns:
        Union[tuple, dict]: Either:
            - Dictionary of predictions keyed by sample index
            - Tuple of (predictions, metrics) if evaluate=True and labels
              matching the predictions are available
    """
    # Get dataset and dataloader from file
    _, dataloader = self.generate_dataset(file_path, seq_col=seq_col, label_col=label_col,
                                          sep=sep, fasta_sep=fasta_sep, multi_label_sep=multi_label_sep,
                                          uppercase=uppercase, lowercase=lowercase,
                                          batch_size=self.pred_config.batch_size)
    # Do batch prediction
    if output_attentions:
        warnings.warn("Cautions: output_attentions may consume a lot of memory.\n")
    logits, predictions, embeddings = self.batch_predict(dataloader,
                                                         output_hidden_states=output_hidden_states,
                                                         output_attentions=output_attentions)
    # Keep hidden states / attentions for the plot_* helpers
    if output_hidden_states or output_attentions:
        self.embeddings = embeddings
    output_dir = self.pred_config.output_dir
    if save_to_file and output_dir:
        save_predictions(predictions, Path(output_dir))
    if not evaluate:
        return predictions
    if len(self.labels) != len(logits):
        warnings.warn("evaluate=True but no labels matching the predictions are available; "
                      "skipping evaluation.")
        return predictions
    metrics = self.calculate_metrics(logits, self.labels, plot=plot_metrics)
    # Plot data is never written to disk, and is returned only when plots were requested.
    metrics_save = dict(metrics)
    metrics_save.pop('curve', None)
    metrics_save.pop('scatter', None)
    if save_to_file and output_dir:
        save_metrics(metrics_save, Path(output_dir))
    return predictions, (metrics if plot_metrics else metrics_save)

predict_seqs(sequences, evaluate=False, output_hidden_states=False, output_attentions=False, save_to_file=False)

Predict for sequences.

Parameters:

Name Type Description Default
sequences Union[str, List[str]]

Single sequence or list of sequences.

required
evaluate bool

Whether to evaluate the predictions.

False
output_hidden_states bool

Whether to output hidden states (attentions are controlled separately by output_attentions).

False
output_attentions bool

Whether to output attentions.

False
save_to_file bool

Whether to save predictions to file.

False

Returns:

Type Description
Union[tuple, dict]

Union[tuple, dict]: Either: - Dictionary containing predictions - Tuple of (predictions, metrics) if evaluate=True

Source code in dnallm/inference/predictor.py
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
def predict_seqs(self, sequences: Union[str, List[str]],
                 evaluate: bool = False,
                 output_hidden_states: bool=False,
                 output_attentions: bool=False,
                 save_to_file: bool = False) -> Union[tuple, dict]:
    """Predict for sequences.

    Args:
        sequences: Single sequence or list of sequences.
        evaluate: Whether to evaluate the predictions against ``self.labels``
            as populated by ``generate_dataset``.
        output_hidden_states: Whether to output hidden states.
        output_attentions: Whether to output attentions.
        save_to_file: Whether to save predictions (and metrics) to
            ``pred_config.output_dir``.

    Returns:
        Union[tuple, dict]: Either:
            - Dictionary containing predictions
            - Tuple of (predictions, metrics) if evaluate=True and labels
              matching the predictions are available
    """
    # Get dataset and dataloader from sequences
    _, dataloader = self.generate_dataset(sequences, batch_size=self.pred_config.batch_size)
    # Do batch prediction
    logits, predictions, embeddings = self.batch_predict(dataloader,
                                                         output_hidden_states=output_hidden_states,
                                                         output_attentions=output_attentions)
    # Keep hidden states / attentions for the plot_* helpers
    if output_hidden_states or output_attentions:
        self.embeddings = embeddings
    output_dir = self.pred_config.output_dir
    if save_to_file and output_dir:
        save_predictions(predictions, Path(output_dir))
    if not evaluate:
        return predictions
    if len(self.labels) != len(logits):
        warnings.warn("evaluate=True but no labels matching the predictions are available; "
                      "skipping evaluation.")
        return predictions
    metrics = self.calculate_metrics(logits, self.labels)
    # Drop plot data before writing metrics to disk
    metrics_save = dict(metrics)
    metrics_save.pop('curve', None)
    metrics_save.pop('scatter', None)
    if save_to_file and output_dir:
        save_metrics(metrics_save, Path(output_dir))
    return predictions, metrics

plot_attention_map(attentions, sequences, tokenizer, seq_idx=0, layer=-1, head=-1, width=800, height=800, save_path=None)

Plot an attention map. Args: attentions (tuple): tuple containing attention weights; sequences (list): list of sequences; tokenizer: tokenizer object; seq_idx (int): index of the sequence to plot; layer (int): layer index; head (int): head index; width (int): width of the plot; height (int): height of the plot; save_path (str): path to save the plot. Returns: attn_map: Altair chart object.

Source code in dnallm/inference/plot.py
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
def plot_attention_map(attentions: Union[tuple, list], sequences: list, tokenizer,
                       seq_idx: int=0, layer: int=-1, head: int=-1,
                       width: int = 800, height: int = 800,
                       save_path: str = None) -> alt.Chart:
    """
    Plot an attention heatmap for one sequence, layer and head.

    Args:
        attentions (tuple): Per-layer attention weights; ``attentions[layer]``
            is indexed as ``[seq_idx][head]`` to obtain a token-by-token matrix.
        sequences (list): List of sequences.
        tokenizer: Tokenizer object used to re-tokenize the chosen sequence.
        seq_idx (int): Index of the sequence to plot.
        layer (int): Layer index.
        head (int): Head index.
        width (int): Width of the plot.
        height (int): Height of the plot.
        save_path (str): Path to save the plot.
    Returns:
        attn_map: Altair chart object.
    """
    attn_layer = attentions[layer].numpy()
    attn_head = attn_layer[seq_idx][head]
    # Get the tokens
    seq = sequences[seq_idx]
    tokens_id = tokenizer.encode(seq)
    try:
        tokens = tokenizer.convert_ids_to_tokens(tokens_id)
    except Exception:
        # Fallback for tokenizers without convert_ids_to_tokens: decode the
        # token ids (not the raw sequence string) and split on whitespace.
        tokens = tokenizer.decode(tokens_id).split()
    # The re-encoding here is not truncated like the model input may have
    # been, so clip to the attention matrix to avoid indexing past its edge.
    tokens = tokens[:len(attn_head)]
    # Build a long-form table: one row per (query token, key token) pair
    num_tokens = len(tokens)
    flen = len(str(num_tokens))
    df = {"token1": [], 'token2': [], 'attn': []}
    for i, t1 in enumerate(tokens):
        for j, t2 in enumerate(tokens):
            # Zero-padded position prefixes keep repeated tokens distinct and
            # fix the ordinal order; token2 counts down, reversing the y order.
            df["token1"].append(str(i).zfill(flen)+t1)
            df["token2"].append(str(num_tokens-j).zfill(flen)+t2)
            df["attn"].append(attn_head[i][j])
    source = pd.DataFrame(df)
    # Note: this switches Altair's global data transformer.
    alt.data_transformers.enable("vegafusion")
    # Axis labels strip the position prefix again via labelExpr
    attn_map = alt.Chart(source).mark_rect().encode(
        x=alt.X('token1:O', axis=alt.Axis(
                    labelExpr = f"substring(datum.value, {flen}, 100)",
                    labelAngle=-45,
                    )
                ).title(None),
        y=alt.Y('token2:O', axis=alt.Axis(
                    labelExpr = f"substring(datum.value, {flen}, 100)",
                    labelAngle=0,
                    )
                ).title(None),
        color=alt.Color('attn:Q').scale(scheme='viridis'),
    ).properties(
        width=width,
        height=height
    ).configure_axis(grid=False)
    # Save the plot
    if save_path:
        attn_map.save(save_path)
        print(f"Attention map saved to {save_path}")
    return attn_map

plot_bars(data, show_score=True, ncols=3, width=200, height=50, bar_width=30, domain=(0.0, 1.0), save_path=None, separate=False)

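Plot a bar chart for each metric. Args: data (dict): data to be plotted, with a 'models' entry and one entry per metric; show_score (bool): whether to show the score on the plot; ncols (int): number of charts per row; width (int): width of the plot; height (int): height of the plot per model; bar_width (int): width of the bars in the plot; domain (tuple or list): x-axis domain for metrics other than MAE/MSE; save_path (str): path to save the plot; separate (bool): whether to return a dict of per-metric charts instead of the combined chart. Returns: pbars: Altair chart object.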

Source code in dnallm/inference/plot.py
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
def plot_bars(data: dict, show_score: bool = True, ncols: int = 3,
              width: int = 200, height: int = 50, bar_width: int = 30,
              domain: Union[tuple, list]= (0.0, 1.0),
              save_path: str = None, separate: bool=False) -> alt.Chart:
    """
    Plot one horizontal bar chart per metric, comparing models.

    Args:
        data (dict): Data to be plotted. Must contain a ``'models'`` entry;
            every other key is treated as a metric column.
        show_score (bool): Whether to show the score on the plot.
        ncols (int): Number of metric charts per row in the combined plot.
        width (int): Width of each chart.
        height (int): Height per model of each chart.
        bar_width (int): Width of the bars in the plot.
        domain (tuple | list): x-axis domain for every metric except
            ``'mae'``/``'mse'``, which use ``[0, 1.1 * max]``.
        save_path (str): Path to save the combined plot.
        separate (bool): Return a dict of per-metric charts instead of the
            combined chart (the combined chart is still saved if requested).
    Returns:
        pbars: Altair chart object, or a dict of charts if ``separate``.
    Raises:
        ValueError: If ``data`` has no metric besides ``'models'``.
    """
    dbar = pd.DataFrame(data)
    metrics = [x for x in data if x != 'models']
    # Without this check the combined chart below would be unbound
    if not metrics:
        raise ValueError("No metrics to plot: data only contains 'models'.")
    rows = {}
    p_separate = {}
    for n, metric in enumerate(metrics):
        # Error metrics are unbounded, so scale their axis to the data
        if metric in ['mae', 'mse']:
            domain_use = [0, dbar[metric].max()*1.1]
        else:
            domain_use = domain
        bar = alt.Chart(dbar).mark_bar(size=bar_width).encode(
            x=alt.X(metric + ":Q").scale(domain=domain_use),
            y=alt.Y("models").title(None),
            color=alt.Color('models').legend(None),
        ).properties(width=width, height=height*len(dbar['models']))
        if show_score:
            text = alt.Chart(dbar).mark_text(
                dx=-10,
                color='white',
                baseline='middle',
                align='right').encode(
                    x=alt.X(metric + ":Q"),
                    y=alt.Y("models").title(None),
                    text=alt.Text(metric, format='.3f')
                    )
            p = bar + text
        else:
            p = bar
        if separate:
            p_separate[metric] = p.configure_axis(grid=False)
        # Lay charts out ncols per row
        idx = n // ncols
        rows[idx] = p if n % ncols == 0 else rows[idx] | p
    # Stack the rows vertically
    pbars = None
    for row in rows.values():
        pbars = row if pbars is None else pbars & row
    pbars = pbars.configure_axis(grid=False)
    # Save the plot
    if save_path:
        pbars.save(save_path)
        print(f"Metrics bar charts saved to {save_path}")
    if separate:
        return p_separate
    else:
        return pbars

plot_curve(data, show_score=True, width=400, height=400, save_path=None, separate=False)

Plot ROC and precision-recall curves. Args: data (dict): data to be plotted, with 'ROC' and 'PR' entries; show_score (bool): currently unused; width (int): width of each plot; height (int): height of each plot; save_path (str): path to save the plot; separate (bool): whether to return a dict of the ROC and PR charts instead of the combined chart. Returns: plines: Altair chart object.

Source code in dnallm/inference/plot.py
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def plot_curve(data: dict, show_score: bool = True,
               width: int = 400, height: int = 400,
               save_path: str = None, separate: bool=False) -> alt.Chart:
    """
    Plot ROC and precision-recall curves side by side.

    Args:
        data (dict): Must contain a ``'ROC'`` table (columns ``fpr``, ``tpr``,
            ``models``) and a ``'PR'`` table (columns ``recall``,
            ``precision``, ``models``).
        show_score (bool): Currently unused.
        width (int): Width of each curve chart.
        height (int): Height of each curve chart.
        save_path (str): Path to save the combined plot.
        separate (bool): Return ``{'ROC': chart, 'PR': chart}`` instead of
            the combined chart (the combined chart is still saved if requested).
    Returns:
        plines: Altair chart object, or a dict of charts if ``separate``.
    """
    unit_range = (0.0, 1.0)
    # (data key, x column, y column) for each curve, drawn left to right
    curve_specs = (("ROC", "fpr", "tpr"), ("PR", "recall", "precision"))
    p_separate = {}
    charts = []
    for key, x_col, y_col in curve_specs:
        frame = pd.DataFrame(data[key])
        chart = alt.Chart(frame).mark_line().encode(
            x=alt.X(x_col).scale(domain=unit_range),
            y=alt.Y(y_col).scale(domain=unit_range),
            color="models",
        ).properties(width=width, height=height)
        charts.append(chart)
        if separate:
            p_separate[key] = chart
    roc_chart, pr_chart = charts
    plines = (roc_chart | pr_chart).configure_axis(grid=False)
    if save_path:
        plines.save(save_path)
        print(f"ROC curves saved to {save_path}")
    return p_separate if separate else plines

plot_embeddings(hidden_states, attention_mask, reducer='t-SNE', labels=None, label_names=None, ncols=4, width=300, height=300, save_path=None, separate=False)

Visualize embeddings

Parameters:

Name Type Description Default
hidden_states tuple

Tuple containing hidden states.

required
attention_mask tuple

Tuple containing attention mask.

required
reducer str

Dimensionality reduction method. Options: PCA, t-SNE, UMAP.

't-SNE'
labels list

List of labels for the data points.

None
label_names list

List of label names.

None
ncols int

Number of columns in the plot.

4
width int

Width of the plot.

300
height int

Height of the plot.

300
save_path str

Path to save the plot.

None
separate bool

Whether to return separate plots for each layer.

False

Returns: pdots: Altair chart object.

Source code in dnallm/inference/plot.py
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
def plot_embeddings(hidden_states: Union[tuple, list], attention_mask: Union[tuple, list], reducer: str="t-SNE",
                    labels: Union[tuple, list]=None, label_names: Union[str, list]=None,
                    ncols: int=4, width: int=300, height: int=300,
                    save_path: str = None, separate: bool=False) -> alt.Chart:
    '''
    Visualize per-layer sequence embeddings as 2-D scatter plots.

    For each entry of ``hidden_states`` the token embeddings are mean-pooled
    over the positions selected by ``attention_mask`` and projected to two
    dimensions with the chosen reducer. One scatter plot is drawn per layer,
    ``ncols`` plots per row.

    Args:
        hidden_states (tuple): Hidden states, one tensor/array per layer. Each is
            presumably shaped (batch, seq_len, hidden_dim) -- TODO confirm with callers.
        attention_mask (tuple): Attention mask, either shaped like the hidden
            states minus the last axis (e.g. (batch, seq_len)) or already
            expanded with a trailing size-1 axis; the axis is added when missing.
        reducer (str): Dimensionality reduction method (case-insensitive).
            Options: PCA, t-SNE, UMAP.
        labels (list): Per-sample labels. When None or empty, every point is
            labelled "Uncategorized".
        label_names (list): Display names looked up as ``label_names[int(label)]``.
            When None, the raw label values are shown.
        ncols (int): Number of plots per row.
        width (int): Width of each plot.
        height (int): Height of each plot.
        save_path (str): If given, the combined chart is saved to this path.
        separate (bool): If True, return a dict ``{"Layer<k>": chart}`` of
            per-layer charts instead of the combined chart.
    Returns:
        pdots: Combined Altair chart, or the dict of per-layer charts when
        ``separate`` is True.
    Raises:
        ValueError: If ``reducer`` is unsupported or ``hidden_states`` is empty.
    '''
    import torch

    reducer_name = reducer.lower()
    if reducer_name == "pca":
        from sklearn.decomposition import PCA
        dim_reducer = PCA(n_components=2)
    elif reducer_name == "t-sne":
        from sklearn.manifold import TSNE
        dim_reducer = TSNE(n_components=2)
    elif reducer_name == "umap":
        from umap import UMAP
        dim_reducer = UMAP(n_components=2)
    else:
        raise ValueError("Unsupported dim reducer, please try PCA, t-SNE or UMAP.")

    if len(hidden_states) == 0:
        raise ValueError("hidden_states is empty; nothing to plot.")

    # Work on detached CPU tensors so GPU / grad-tracking inputs are accepted.
    mask = torch.as_tensor(attention_mask).detach().cpu()

    pdot = {}
    p_separate = {}
    for layer_idx, hidden in enumerate(hidden_states):
        embeddings = torch.as_tensor(hidden).detach().cpu()
        if embeddings.dtype == torch.bfloat16:
            # NumPy has no bfloat16, so .numpy() below would fail.
            embeddings = embeddings.float()
        layer_mask = mask.to(embeddings.dtype)
        if layer_mask.dim() == embeddings.dim() - 1:
            # Add a trailing axis so the mask broadcasts over the feature axis.
            layer_mask = layer_mask.unsqueeze(-1)
        # Masked mean pooling over the token axis.
        token_counts = layer_mask.sum(dim=1)
        # Avoid 0/0 (NaN) for rows whose mask is entirely zero.
        token_counts = torch.where(token_counts == 0, torch.ones_like(token_counts), token_counts)
        pooled = (layer_mask * embeddings).sum(dim=-2) / token_counts
        coords = dim_reducer.fit_transform(pooled.numpy())

        n_points = coords.shape[0]
        if labels is None or len(labels) == 0:
            point_labels = ["Uncategorized"] * n_points
        elif label_names is None:
            # Unwrap tensor/NumPy scalars so the DataFrame holds plain values.
            point_labels = [lab.item() if hasattr(lab, "item") else lab for lab in labels]
        else:
            point_labels = [label_names[int(lab)] for lab in labels]

        source = pd.DataFrame({
            'Dimension 1': coords[:, 0],
            'Dimension 2': coords[:, 1],
            'labels': point_labels,
        })
        dot = alt.Chart(source, title=f"Layer {layer_idx+1}").mark_point(filled=True).encode(
            x=alt.X("Dimension 1:Q"),
            y=alt.Y("Dimension 2:Q"),
            color=alt.Color("labels:N", legend=alt.Legend(title="Labels")),
        ).properties(width=width, height=height)
        if separate:
            p_separate[f"Layer{layer_idx+1}"] = dot.configure_axis(grid=False)
        # Lay plots out ncols per row: '|' joins within a row.
        row = layer_idx // ncols
        if layer_idx % ncols == 0:
            pdot[row] = dot
        else:
            pdot[row] |= dot
    # Stack the rows vertically with '&'.
    row_charts = list(pdot.values())
    pdots = row_charts[0]
    for row_chart in row_charts[1:]:
        pdots &= row_chart
    # Configure the chart
    pdots = pdots.configure_axis(grid=False)
    # Save the plot
    if save_path:
        pdots.save(save_path)
        print(f"Embeddings visualization saved to {save_path}")
    if separate:
        return p_separate
    else:
        return pdots

plot_muts(data, show_score=False, width=None, height=100, save_path=None)

Visualize mutation effects

Source code in dnallm/inference/plot.py
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
def plot_muts(data: dict, show_score: bool = False,
              width: int = None, height: int = 100,
              save_path: str = None) -> alt.Chart:
    '''
    Visualize mutation effects.

    Builds three vertically stacked panels over the raw sequence:
    a heatmap with one cell per substitution/deletion/insertion at each
    position, a bar chart of the largest-|score| substitution per position
    (coloured by the substituted base), and a line chart of the summed
    positive ("gain") and negated negative ("loss") substitution scores.

    Args:
        data (dict): Must contain ``'raw'`` with a ``'sequence'`` entry and a
            ``'score'`` (all scores, including the raw one, set the colour
            range). Every other key maps to a dict with a ``'score'`` and is
            matched as ``mut_<pos>_<ref>_<alt>``, ``del_<pos>_<...>`` or
            ``ins_<pos>_<...>`` with a 0-based ``pos``.
        show_score (bool): Currently unused.
        width (int): Panel width; derived from ``height`` and the sequence
            length when None.
        height (int): Height of each panel.
        save_path (str): If given, the stacked chart is saved to this path.
    Returns:
        The stacked (heatmap & bar & line) Altair chart, i.e. the chart that
        is saved to ``save_path``.
    Raises:
        ValueError: If the raw sequence is empty.
    '''
    raw_bases = list(data['raw']['sequence'])
    seqlen = len(raw_bases)
    if seqlen == 0:
        raise ValueError("data['raw']['sequence'] is empty; nothing to plot.")
    flen = len(str(seqlen))

    def row_label(pos: int, base: str) -> str:
        # Zero-padded position + base, presumably so the ordinal x axis sorts
        # in sequence order; the base is recovered in labelExpr below.
        return str(pos).zfill(flen) + base

    # Index mutation keys by their "<kind>_<pos>_" prefix once, instead of
    # rescanning every key for every position (quadratic in sequence length).
    keys_by_prefix = {}
    for key in data:
        if key == 'raw':
            continue
        parts = key.split("_")
        if len(parts) >= 3:
            keys_by_prefix.setdefault(f"{parts[0]}_{parts[1]}_", []).append(key)

    dheat = {"base": [], 'mut': [], 'score': []}
    dline = {"x": [row_label(i, x) for i, x in enumerate(raw_bases)] * 2,
             "score": [0.0] * seqlen * 2,
             "type": ["gain"] * seqlen + ["loss"] * seqlen}
    dbar = {"x": [], "score": [], "base": []}
    for i, base1 in enumerate(raw_bases):
        label = row_label(i, base1)
        # Substitutions at this position, plus the reference (base -> itself)
        # slot so every position gets at least one heatmap cell.
        mut_prefix = f"mut_{i}_{base1}_"
        ref = mut_prefix + base1
        candidates = {key for key in keys_by_prefix.get(f"mut_{i}_", [])
                      if key.startswith(mut_prefix)}
        # A set: the reference key may itself be present in data and must
        # not be counted twice.
        candidates.add(ref)
        maxabs = 0.0
        maxscore = 0.0
        maxabs_base = base1
        for mut in sorted(candidates):
            if mut not in data:
                # Only the reference can be missing: neutral placeholder cell.
                dheat["base"].append(label)
                dheat["mut"].append(base1)
                dheat["score"].append(0.0)
                continue
            base2 = mut.split("_")[-1]
            score = data[mut]['score']
            dheat["base"].append(label)
            dheat["mut"].append(base2)
            dheat["score"].append(score)
            # Line chart: gains and (sign-flipped) losses are summed separately.
            if score >= 0:
                dline["score"][i] += score
            elif score < 0:
                dline["score"][i + seqlen] -= score
            # Bar chart: keep the substitution with the largest |score|.
            if abs(score) > maxabs:
                maxabs = abs(score)
                maxscore = score
                maxabs_base = base2
        dbar["x"].append(label)
        dbar["score"].append(maxscore)
        dbar["base"].append(maxabs_base)
        # Deletions, then insertions, in the order they appear in data.
        for kind in ("del", "ins"):
            for mut in keys_by_prefix.get(f"{kind}_{i}_", []):
                dheat["base"].append(label)
                dheat["mut"].append(kind + "_" + mut.split("_")[-1])
                dheat["score"].append(data[mut]['score'])

    # Diverging colour scale symmetric around zero.
    all_scores = [data[key]['score'] for key in data]
    score_extent = max(abs(min(all_scores)), abs(max(all_scores)))
    domain1 = [-score_extent, 0.0, score_extent]
    range1_ = ['#2166ac', '#f7f7f7', '#b2182b']
    domain2 = sorted(set(dbar['base']))
    range2_ = ["#33a02c", "#e31a1c", "#1f78b4", "#ff7f00", "#cab2d6"][:len(domain2)]
    # Enable VegaFusion for Altair
    alt.data_transformers.enable("vegafusion")
    if width is None:
        width = int(height * len(raw_bases) / len(set(dheat['mut'])))
    # Heatmap
    pheat = alt.Chart(pd.DataFrame(dheat)).mark_rect().encode(
        x=alt.X('base:O', axis=alt.Axis(
                labelExpr = f"substring(datum.value, {flen}, {flen}+1)",
                labelAngle=0,
                )
            ).title(None),
        y=alt.Y('mut:O').title("mutation"),
        color=alt.Color('score:Q').scale(domain=domain1, range=range1_),
    ).properties(
        width=width, height=height
    )
    # Gain and loss lines
    pline = alt.Chart(pd.DataFrame(dline)).mark_line().encode(
        x=alt.X('x:O').title(None).axis(labels=False),
        y=alt.Y('score:Q'),
        color=alt.Color('type:N').scale(
            domain=['gain', 'loss'], range=['#b2182b', '#2166ac']
        ),
    ).properties(
        width=width, height=height
    )
    # Strongest substitution per position
    pbar = alt.Chart(pd.DataFrame(dbar)).mark_bar().encode(
        x=alt.X('x:O').title(None).axis(labels=False),
        y=alt.Y('score:Q'),
        color=alt.Color('base:N').scale(
            domain=domain2, range=range2_
        ),
    ).properties(
        width=width, height=height
    )
    pmerge = (pheat & pbar & pline).configure_axis(grid=False)
    # Save the plot
    if save_path:
        pmerge.save(save_path)
        print(f"Mutation effects visualization saved to {save_path}")
    return pmerge

plot_scatter(data, show_score=True, ncols=3, width=400, height=400, save_path=None, separate=False)

Plot scatter chart.

Args:
- data (dict): Data to be plotted.
- show_score (bool): Whether to show the R² score on the plot.
- ncols (int): Number of columns in the plot.
- width (int): Width of the plot.
- height (int): Height of the plot.
- save_path (str): Path to save the plot.
- separate (bool): Whether to return a dict of per-model plots instead of the combined chart.

Returns: pdots: Altair chart object.

Source code in dnallm/inference/plot.py
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
def plot_scatter(data: dict, show_score: bool = True, ncols: int = 3,
                 width: int = 400, height: int = 400,
                 save_path: str = None, separate: bool=False) -> alt.Chart:
    """
    Plot predicted-vs-experiment scatter charts, one per model.

    Args:
        data (dict): ``{model: {'predicted': [...], 'experiment': [...], 'r2': value}}``,
            e.g. the scatter data returned by ``prepare_data`` for regression.
            ``'r2'`` is optional.
        show_score (bool): Whether to annotate each plot with its R² value.
            Models without ``'r2'`` are drawn without the annotation.
        ncols (int): Number of plots per row.
        width (int): Width of each plot.
        height (int): Height of each plot.
        save_path (str): If given, the combined chart is saved to this path.
        separate (bool): If True, return ``{model: chart}`` instead of the
            combined chart.
    Returns:
        pdots: Combined Altair chart, or the dict of per-model charts when
        ``separate`` is True.
    Raises:
        ValueError: If ``data`` is empty.
    """
    if not data:
        raise ValueError("No scatter data to plot.")
    rows = {}
    p_separate = {}
    for n, model in enumerate(data):
        scatter_data = dict(data[model])
        # 'r2' is a scalar annotation, not a column; remove it before building the frame.
        r2 = scatter_data.pop('r2', None)
        ddot = pd.DataFrame(scatter_data)
        dot = alt.Chart(ddot, title=model).mark_point(filled=True).encode(
            x=alt.X("predicted:Q"),
            y=alt.Y("experiment:Q"),
        ).properties(width=width, height=height)
        if show_score and r2 is not None:
            # Anchor the label near the top-left of the data range.
            min_x = ddot['predicted'].min()
            max_y = ddot['experiment'].max()
            text = alt.Chart().mark_text(size=14, align="left", baseline="bottom").encode(
                x=alt.datum(min_x + 0.5),
                y=alt.datum(max_y - 0.5),
                text=alt.datum("R\u00b2=" + str(r2))
            )
            chart = dot + text
        else:
            chart = dot
        if separate:
            p_separate[model] = chart.configure_axis(grid=False)
        # Lay plots out ncols per row: '|' joins within a row.
        row = n // ncols
        if n % ncols == 0:
            rows[row] = chart
        else:
            rows[row] |= chart
    # Stack the rows vertically with '&'.
    row_charts = list(rows.values())
    pdots = row_charts[0]
    for row_chart in row_charts[1:]:
        pdots &= row_chart
    # Configure the chart
    pdots = pdots.configure_axis(grid=False)
    # Save the plot
    if save_path:
        pdots.save(save_path)
        print(f"Metrics scatter plots saved to {save_path}")
    if separate:
        return p_separate
    else:
        return pdots

prepare_data(metrics, task_type='binary')

Prepare data for plotting.

Args:
- metrics (dict): Dictionary containing model metrics.
- task_type (str): Task type.

Returns: tuple: Tuple containing bar data and curve data (or bar data and scatter data for regression tasks).

Source code in dnallm/inference/plot.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def prepare_data(metrics: dict, task_type: str="binary") -> tuple:
    """
    Prepare data for plotting.

    Args:
        metrics (dict): Mapping ``{model_name: {metric_name: value}}``.
            Classification-like tasks may include a ``'curve'`` entry holding
            ``'fpr'``/``'tpr'`` (ROC) and ``'recall'``/``'precision'`` (PR)
            sequences; other curve keys are ignored. Regression may include a
            ``'scatter'`` entry holding ``'predicted'``/``'experiment'``
            sequences (other keys ignored) and an ``'r2'`` metric, which is
            also copied into the scatter data.
        task_type (str): 'binary', 'multiclass', 'multilabel', 'token' or
            'regression'.
    Returns:
        tuple: ``(bars_data, curves_data)`` for classification-like tasks, or
        ``(bars_data, scatter_data)`` for regression. ``bars_data`` holds
        ``'models'`` plus one list per scalar metric, aligned index-by-index
        with ``'models'``; a model lacking a metric contributes ``None``.
    Raises:
        ValueError: If ``task_type`` is not supported.
    """
    if task_type in ['binary', 'multiclass', 'multilabel', 'token']:
        return _collect_bars_data(metrics, skip='curve'), _collect_curves_data(metrics)
    elif task_type == "regression":
        return _collect_bars_data(metrics, skip='scatter'), _collect_scatter_data(metrics)
    else:
        raise ValueError(f"Unsupport task type {task_type} for ploting")


# Curve key -> (curve group, whether this key drives the per-point 'models'
# column). Exactly one key per group extends 'models' so every plotted point
# gets a single model entry.
_CURVE_ROUTES = {
    'fpr': ('ROC', True),
    'tpr': ('ROC', False),
    'recall': ('PR', False),
    'precision': ('PR', True),
}


def _collect_bars_data(metrics: dict, skip: str) -> dict:
    """Collect scalar metrics (all keys except ``skip``) into model-aligned lists."""
    models = list(metrics)
    # First-seen order across all models, without duplicates.
    metric_names = dict.fromkeys(
        metric for model in models for metric in metrics[model] if metric != skip
    )
    bars_data = {'models': models}
    for metric in metric_names:
        # .get keeps positions aligned when a model lacks a metric.
        bars_data[metric] = [metrics[model].get(metric) for model in models]
    return bars_data


def _collect_curves_data(metrics: dict) -> dict:
    """Concatenate every model's ROC and PR curve points into long-form columns."""
    curves_data = {'ROC': {'models': [], 'fpr': [], 'tpr': []},
                   'PR': {'models': [], 'recall': [], 'precision': []}}
    for model, model_metrics in metrics.items():
        for score, values in model_metrics.get('curve', {}).items():
            route = _CURVE_ROUTES.get(score)
            if route is None:
                continue  # e.g. 'thresholds': not plotted
            group_name, tags_models = route
            group = curves_data[group_name]
            if tags_models:
                group['models'].extend([model] * len(values))
            group[score].extend(values)
    return curves_data


def _collect_scatter_data(metrics: dict) -> dict:
    """Build per-model predicted/experiment columns plus the optional 'r2' value."""
    scatter_data = {}
    for model, model_metrics in metrics.items():
        entry = {'predicted': [], 'experiment': []}
        for score, values in model_metrics.get('scatter', {}).items():
            if score in entry:
                entry[score].extend(values)
        if 'r2' in model_metrics:
            entry['r2'] = model_metrics['r2']
        scatter_data[model] = entry
    return scatter_data

save_metrics(metrics, output_dir)

Save metrics to files.

Parameters:

Name Type Description Default
metrics Dict

Dictionary containing metrics.

required
output_dir Path

Directory to save metrics.

required
Source code in dnallm/inference/predictor.py
591
592
593
594
595
596
597
598
599
600
601
602
def save_metrics(metrics: Dict, output_dir: Path) -> None:
    """Save metrics to ``<output_dir>/metrics.json``.

    Args:
        metrics: Dictionary containing metrics; must be JSON-serializable.
        output_dir: Directory to save metrics (a ``str`` is also accepted).
            It is created, including parents, if it does not exist.
    """
    # Accept plain string paths as well as Path objects.
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save metrics
    with open(output_dir / "metrics.json", "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=4)

save_predictions(predictions, output_dir)

Save predictions to files.

Parameters:

Name Type Description Default
predictions Dict

Dictionary containing predictions.

required
output_dir Path

Directory to save predictions.

required
Source code in dnallm/inference/predictor.py
578
579
580
581
582
583
584
585
586
587
588
589
def save_predictions(predictions: Dict, output_dir: Path) -> None:
    """Save predictions to ``<output_dir>/predictions.json``.

    Args:
        predictions: Dictionary containing predictions; must be JSON-serializable.
        output_dir: Directory to save predictions (a ``str`` is also accepted).
            It is created, including parents, if it does not exist.
    """
    # Accept plain string paths as well as Path objects.
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save predictions
    with open(output_dir / "predictions.json", "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=4)