
datasets/data API

DNADataset

Source code in dnallm/datasets/data.py
class DNADataset:
    def __init__(self, ds: Union[Dataset, DatasetDict], tokenizer: PreTrainedTokenizerBase = None, max_length: int = 512):
        """
        Args:
            ds (datasets.Dataset or DatasetDict): A Hugging Face Dataset containing at least 'sequence' and 'labels' fields.
            tokenizer (PreTrainedTokenizerBase, optional): A Hugging Face tokenizer for encoding sequences.
            max_length (int, optional): Maximum length for tokenization.
        """
        self.dataset = ds
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.sep = None
        self.multi_label_sep = None
        self.stats = None

    @classmethod
    def load_local_data(cls, file_paths, seq_col: str = "sequence", label_col: str = "labels",
                        sep: str = None, fasta_sep: str = "|",
                        multi_label_sep: Union[str, None] = None,
                        tokenizer: PreTrainedTokenizerBase = None, max_length: int = 512) -> any:
        """
        Load DNA sequence datasets from one or multiple local files.

        Supports input formats: csv, tsv, json, parquet, arrow, dict, fasta, txt.

        Args:
            file_paths (str, list, or dict):  
                - Single dataset: Provide one file path (e.g., "data.csv").  
                - Pre-split datasets: Provide a dict like `{"train": "train.csv", "test": "test.csv"}`.
            seq_col (str): Column name for DNA sequences.
            label_col (str): Column name for labels.
            sep (str, optional): Delimiter for CSV, TSV, or TXT.
            fasta_sep (str, optional): Delimiter for FASTA files.
            multi_label_sep (str, optional): Delimiter for multi-label sequences.
            tokenizer (PreTrainedTokenizerBase, optional): A tokenizer.
            max_length (int, optional): Max token length.

        Returns:
            DNADataset: An instance wrapping a Dataset or DatasetDict.
        """
        # Set separators
        cls.sep = sep
        cls.multi_label_sep = multi_label_sep
        # Check if input is a list or dict
        if isinstance(file_paths, dict):  # Handling multiple files (pre-split datasets)
            ds_dict = {}
            for split, path in file_paths.items():
                ds_dict[split] = cls._load_single_data(path, seq_col, label_col, sep, fasta_sep, multi_label_sep)
            dataset = DatasetDict(ds_dict)
        else:  # Handling a single file
            dataset = cls._load_single_data(file_paths, seq_col, label_col, sep, fasta_sep, multi_label_sep)

        return cls(dataset, tokenizer=tokenizer, max_length=max_length)

    @classmethod
    def _load_single_data(cls, file_path, seq_col: str = "sequence", label_col: str = "labels",
                          sep: str = None, fasta_sep: str = "|",
                          multi_label_sep: Union[str, None] = None) -> Dataset:
        """
        Load DNA data (sequences and labels) from a local file.

        Supported file types: 
          - For structured formats (CSV, TSV, JSON, Parquet, Arrow, dict), uses load_dataset from datasets.
          - For FASTA and TXT, uses custom parsing.

        Args:
            file_path: For most file types, a path (or pattern) to the file(s). For pickle/dict types, a path to a pickled dictionary.
            seq_col (str): Name of the column containing the DNA sequence.
            label_col (str): Name of the column containing the label.
            sep (str, optional): Delimiter for CSV, TSV, or TXT files.
            fasta_sep (str, optional): Delimiter for FASTA files.
            multi_label_sep (str, optional): Delimiter for multi-label sequences.

        Returns:
            DNADataset: An instance wrapping a datasets.Dataset.
        """
        if isinstance(file_path, list):
            file_path = [os.path.expanduser(fpath) for fpath in file_path]
            file_type = os.path.basename(file_path[0]).split(".")[-1].lower()
        else:
            file_path = os.path.expanduser(file_path)
            file_type = os.path.basename(file_path).split(".")[-1].lower()
        # Define data type
        default_types = ["csv", "tsv", "json", "parquet", "arrow"]
        dict_types = ["pkl", "pickle", "dict"]
        fasta_types = ["fa", "fna", "fas", "fasta"]
        # Check if the file contains a header
        if file_type in ["csv", "tsv", "txt"]:
            if file_type == "csv":
                sep = sep if sep else ","
            with open(file_path, "r") as f:
                header = f.readline().strip()
                if not header or (seq_col not in header and label_col not in header):
                    file_type = "txt"  # Treat as TXT if no header found
        # For structured formats that load via datasets.load_dataset
        if file_type in default_types:
            if file_type in ["csv", "tsv"]:
                sep = sep or ("," if file_type == "csv" else "\t")
                ds = load_dataset("csv", data_files=file_path, split="train", delimiter=sep)
            elif file_type == "json":
                ds = load_dataset("json", data_files=file_path, split="train")
            elif file_type in ["parquet", "arrow"]:
                ds = load_dataset(file_type, data_files=file_path, split="train")
            # Rename columns if needed
            if seq_col != "sequence":
                ds = ds.rename_column(seq_col, "sequence")
            if label_col != "labels":
                ds = ds.rename_column(label_col, "labels")
        elif file_type in dict_types:
            # Here, file_path points to a pickled dictionary.
            import pickle
            data = pickle.load(open(file_path, 'rb'))
            ds = Dataset.from_dict(data)
            if seq_col != "sequence" or label_col != "labels":
                if seq_col in ds.column_names:
                    if "sequence" not in ds.features:
                        ds = ds.rename_column(seq_col, "sequence")
                if label_col in ds.column_names:
                    if "labels" not in ds.features:
                        ds = ds.rename_column(label_col, "labels")
        elif file_type in fasta_types:
            sequences, labels = [], []
            with open(file_path, "r") as f:
                seq = ""
                lab = None
                for line in f:
                    line = line.strip()
                    if line.startswith(">"):
                        if seq and lab is not None:
                            sequences.append(seq)
                            labels.append(lab)
                        lab = line[1:].strip().split(fasta_sep)[-1]  # Assume label is separated by `fasta_sep` in the header
                        seq = ""
                    else:
                        seq += line.strip()
                if seq and lab is not None:
                    sequences.append(seq)
                    labels.append(lab)
            ds = Dataset.from_dict({"sequence": sequences, "labels": labels})
        elif file_type == "txt":
            # Assume each line contains a sequence and a label separated by whitespace or a custom sep.
            sequences, labels = [], []
            ds = None
            with open(file_path, "r") as f:
                for i, line in enumerate(f):
                    if i == 0 and seq_col in line and label_col in line:
                        # The file contains a header; delegate to the csv loader instead
                        ds = load_dataset("csv", data_files=file_path, split="train", delimiter=sep)
                        break
                    record = line.strip().split(sep) if sep else line.strip().split()
                    if len(record) >= 2:
                        sequences.append(record[0])
                        labels.append(record[1])
            if ds is None:
                ds = Dataset.from_dict({"sequence": sequences, "labels": labels})
        else:
            raise ValueError(f"Unsupported file type: {file_type}")
        # Convert string labels to integer
        def format_labels(example):
            labels = example['labels']
            if isinstance(labels, str):
                if multi_label_sep:
                    example['labels'] = [float(x) for x in labels.split(multi_label_sep)]
                else:
                    example['labels'] = float(labels) if '.' in labels else int(labels)
            return example
        if 'labels' in ds.column_names:
            ds = ds.map(format_labels, desc="Format labels")
        # Return processed dataset
        return ds

    @classmethod
    def from_huggingface(cls, dataset_name: str,
                         seq_col: str = "sequence", label_col: str = "labels",
                         data_dir: Union[str, None]=None,
                         tokenizer: PreTrainedTokenizerBase = None, max_length: int = 512) -> any:
        """
        Load a dataset from the Hugging Face Hub.

        Args:
            dataset_name (str): Name of the dataset.
            seq_col (str): Column name for the DNA sequence.
            label_col (str): Column name for the label.
            data_dir (str): Data directory in a dataset.
            tokenizer (PreTrainedTokenizerBase): Tokenizer.
            max_length (int): Max token length.

        Returns:
            DNADataset: An instance wrapping a datasets.Dataset.
        """
        if data_dir:
            ds = load_dataset(dataset_name, data_dir=data_dir)
        else:
            ds = load_dataset(dataset_name)
        # Rename columns if necessary
        if seq_col != "sequence":
            ds = ds.rename_column(seq_col, "sequence")
        if label_col != "labels":
            ds = ds.rename_column(label_col, "labels")
        return cls(ds, tokenizer=tokenizer, max_length=max_length)

    @classmethod
    def from_modelscope(cls, dataset_name: str,
                        seq_col: str = "sequence", label_col: str = "labels",
                        data_dir: Union[str, None]=None,
                        tokenizer: PreTrainedTokenizerBase = None, max_length: int = 512) -> any:
        """
        Load a dataset from the ModelScope.

        Args:
            dataset_name (str): Name of the dataset.
            seq_col (str): Column name for the DNA sequence.
            label_col (str): Column name for the label.
            data_dir (str): Data directory in a dataset.
            tokenizer: Tokenizer.
            max_length: Max token length.

        Returns:
            DNADataset: An instance wrapping a datasets.Dataset.
        """
        from modelscope import MsDataset

        if data_dir:
            ds = MsDataset.load(dataset_name, data_dir=data_dir)
        else:
            ds = MsDataset.load(dataset_name)
        # Rename columns if necessary
        if seq_col != "sequence":
            ds = ds.rename_column(seq_col, "sequence")
        if label_col != "labels":
            ds = ds.rename_column(label_col, "labels")
        return cls(ds, tokenizer=tokenizer, max_length=max_length)

    def encode_sequences(self, padding: str = "max_length", return_tensors: str = "pt",
                         remove_unused_columns: bool = False,
                         uppercase: bool=False, lowercase: bool=False,
                         task: Optional[str] = 'SequenceClassification'):
        """
        Encode all sequences using the provided tokenizer.
        The dataset is mapped to include tokenized fields along with the label,
        making it directly usable with Hugging Face Trainer.

        Args:
            padding (str): Padding strategy for sequences. This can be 'max_length' or 'longest'.
                           Use 'longest' to pad to the longest sequence in the batch, which helps reduce memory usage.
            return_tensors (str | TensorType): Returned tensor type, can be 'pt', 'tf', 'np', or 'jax'.
            remove_unused_columns (bool): Whether to remove the original 'sequence' and other unused columns.
            uppercase (bool): Whether to convert sequences to uppercase.
            lowercase (bool): Whether to convert sequences to lowercase.
            task (str, optional): Task type for the tokenizer. If not provided, defaults to 'SequenceClassification'.
        """
        if self.tokenizer:
            sp_token_map = self.tokenizer.special_tokens_map
            pad_token = sp_token_map['pad_token'] if 'pad_token' in sp_token_map else None
            pad_id = self.tokenizer.encode(pad_token)[-1] if pad_token else None
            cls_token = sp_token_map['cls_token'] if 'cls_token' in sp_token_map else None
            sep_token = sp_token_map['sep_token'] if 'sep_token' in sp_token_map else None
            max_length = self.max_length
        else:
            raise ValueError("Tokenizer not provided.")
        def tokenize_for_sequence_classification(example):
            sequences = example["sequence"]
            if uppercase:
                sequences = [x.upper() for x in sequences]
            if lowercase:
                sequences = [x.lower() for x in sequences]
            tokenized = self.tokenizer(
                sequences,
                truncation=True,
                padding=padding,
                max_length=max_length
            )
            return tokenized
        def tokenize_for_token_classification(examples):

            tokenized_examples = {'sequence': [],
                                  'input_ids': [],
                                  # 'token_type_ids': [],
                                  'attention_mask': []}
            if 'labels' in examples:
                tokenized_examples['labels'] = []
            input_seqs = examples['sequence']
            if isinstance(input_seqs, str):
                input_seqs = input_seqs.split(self.multi_label_sep)
            for i, example_tokens in enumerate(input_seqs):
                all_ids = [x for x in self.tokenizer.encode(example_tokens, is_split_into_words=True) if x>=0]
                if 'labels' in examples:
                    example_ner_tags = examples['labels'][i]
                else:
                    example_ner_tags = [0] * len(example_tokens)
                pad_len = max_length - len(all_ids)
                if pad_len >= 0:
                    all_masks = [1] * len(all_ids) + [0] * pad_len
                    all_ids = all_ids + [pad_id] * pad_len
                    if cls_token:
                        if sep_token:
                            example_tokens = [cls_token] + example_tokens + [sep_token] + [pad_token] * pad_len
                            example_ner_tags = [-100] + example_ner_tags + [-100] * (pad_len + 1)
                        else:
                            example_tokens = [cls_token] + example_tokens + [pad_token] * pad_len
                            example_ner_tags = [-100] + example_ner_tags + [-100] * pad_len
                    else:
                        example_tokens = example_tokens + [pad_token] * pad_len
                        example_ner_tags = example_ner_tags + [-100] * pad_len
                elif pad_len < 0:
                    all_ids = all_ids[:max_length]
                    all_masks = [1] * (max_length)
                    if cls_token:
                        if sep_token:
                            example_tokens = [cls_token] + example_tokens[:max_length - 2] + [sep_token]
                            example_ner_tags = [-100] + example_ner_tags[:max_length - 2] + [-100]
                        else:
                            example_tokens = [cls_token] + example_tokens[:max_length - 1]
                            example_ner_tags = [-100] + example_ner_tags[:max_length - 1]
                    else:
                        example_tokens = example_tokens[:max_length]
                        example_ner_tags = example_ner_tags[:max_length]
                tokenized_examples['sequence'].append(example_tokens)
                tokenized_examples['input_ids'].append(all_ids)
                # tokenized_examples['token_type_ids'].append([0] * max_length)
                tokenized_examples['attention_mask'].append(all_masks)
                if 'labels' in examples:
                    tokenized_examples['labels'].append(example_ner_tags)
            return BatchEncoding(tokenized_examples)
        # Judge the task type
        task = task.lower()
        if task in ['sequenceclassification', 'binary', 'multiclass', 'multilabel', 'regression']:
            self.dataset = self.dataset.map(tokenize_for_sequence_classification, batched=True, desc="Encoding inputs")
        elif task in ['tokenclassification', 'token', 'ner']:
            from transformers.tokenization_utils_base import BatchEncoding
            self.dataset = self.dataset.map(tokenize_for_token_classification, batched=True, desc="Encoding inputs")
        elif task in ['maskedlm', 'mlm', 'mask', 'embedding']:
            self.dataset = self.dataset.map(tokenize_for_sequence_classification, batched=True, desc="Encoding inputs")
        elif task in ['causallm', 'clm', 'causal', 'generation', 'embedding']:
            self.dataset = self.dataset.map(tokenize_for_sequence_classification, batched=True)
        else:
            self.dataset = self.dataset.map(tokenize_for_sequence_classification, batched=True, desc="Encoding inputs")
        if remove_unused_columns:
            used_cols = ['labels', 'input_ids', 'attention_mask']
            if isinstance(self.dataset, DatasetDict):
                for dt in self.dataset:
                    unused_cols = [f for f in self.dataset[dt].features if f not in used_cols]
                    self.dataset[dt] = self.dataset[dt].remove_columns(unused_cols)
            else:
                unused_cols = [f for f in self.dataset.features if f not in used_cols]
                self.dataset = self.dataset.remove_columns(unused_cols)
        if return_tensors == "tf":
            self.dataset.set_format(type="tensorflow")
        elif return_tensors == "jax":
            self.dataset.set_format(type="jax")
        elif return_tensors == "np":
            self.dataset.set_format(type="numpy")
        else:
            self.dataset.set_format(type="torch")

    def split_data(self, test_size: float = 0.2, val_size: float = 0.1, seed: int = None):
        """
        Split the dataset into train, test, and validation sets.

        Args:
            test_size (float): Proportion of the dataset to include in the test split.
            val_size (float): Proportion of the dataset to include in the validation split.
            seed (int): Random seed for reproducibility.
        """
        # First, split off test+validation from training data
        split_result = self.dataset.train_test_split(test_size=test_size + val_size, seed=seed)
        train_ds = split_result['train']
        temp_ds = split_result['test']
        # Further split temp_ds into test and validation sets
        if val_size > 0:
            rel_val_size = val_size / (test_size + val_size)
            temp_split = temp_ds.train_test_split(test_size=rel_val_size, seed=seed)
            test_ds = temp_split['train']
            val_ds = temp_split['test']
            self.dataset = DatasetDict({'train': train_ds, 'test': test_ds, 'val': val_ds})
        else:
            self.dataset = DatasetDict({'train': train_ds, 'test': test_ds})

    def shuffle(self, seed: int = None):
        """
        Shuffle the dataset.

        Args:
            seed (int): Random seed for reproducibility.
        """
        self.dataset = self.dataset.shuffle(seed=seed)

    def validate_sequences(self, minl: int = 20, maxl: int = 6000, gc: tuple = (0,1), valid_chars: str = "ACGTN"):
        """
        Filter the dataset, keeping only sequences with valid characters, allowed length, and GC content within the given range.

        Args:
            minl (int): Minimum length of the sequences.
            maxl (int): Maximum length of the sequences.
            gc (tuple): GC content range between 0 and 1.
            valid_chars (str): Allowed characters in the sequences.
        """
        self.dataset = self.dataset.filter(
            lambda example: check_sequence(example["sequence"], minl, maxl, gc, valid_chars)
        )

    def random_generate(self, minl: int, maxl: int = 0, samples: int = 1,
                              gc: tuple = (0,1), N_ratio: float = 0.0,
                              padding_size: int = 0, seed: int = None,
                              label_func = None, append: bool = False):
        """
        Replace the current dataset with randomly generated DNA sequences.

        Args:
            minl: int, minimum length of the sequences
            maxl: int, maximum length of the sequences, default is the same as minl
            samples: int, number of sequences to generate, default 1
            gc: tuple, GC content range, default (0,1)
            N_ratio: float, include N base in the generated sequence, default 0.0
            padding_size: int, padding size for sequence length, default 0
            seed: int, random seed, default None
            label_func (callable, optional): A function that generates a label from a sequence.
            append: bool, whether to append the randomly generated data to the existing dataset or replace the dataset with the generated data
        """
        def process(minl, maxl, number, gc, N_ratio, padding_size, seed, label_func):
            sequences = random_generate_sequences(minl=minl, maxl=maxl, samples=number,
                                                gc=gc, N_ratio=N_ratio,
                                                padding_size=padding_size, seed=seed)
            labels = []
            for seq in sequences:
                labels.append(label_func(seq) if label_func else 0)
            random_ds = Dataset.from_dict({"sequence": sequences, "labels": labels})
            return random_ds
        if append:
            if isinstance(self.dataset, DatasetDict):
                for dt in self.dataset:
                    number = round(samples * len(self.dataset[dt]) / sum(self.__len__().values()))
                    random_ds = process(minl, maxl, number, gc, N_ratio, padding_size, seed, label_func)
                    self.dataset[dt] = concatenate_datasets([self.dataset[dt], random_ds])
            else:
                random_ds = process(minl, maxl, samples, gc, N_ratio, padding_size, seed, label_func)
                self.dataset = concatenate_datasets([self.dataset, random_ds])
        else:
            self.dataset = process(minl, maxl, samples, gc, N_ratio, padding_size, seed, label_func)

    def process_missing_data(self):
        """
        Filter out samples with missing or empty sequences or labels.
        """
        def non_missing(example):
            return example["sequence"] and example["labels"] is not None and example["sequence"].strip() != ""
        self.dataset = self.dataset.filter(non_missing)

    def raw_reverse_complement(self, ratio: float = 0.5, seed: int = None):
        """
        Reverse-complement a randomly selected fraction of the sequences in the dataset.

        Args:
            ratio (float): Ratio of sequences to reverse complement.
            seed (int): Random seed for reproducibility.
        """
        def process(ds, ratio, seed):
            random.seed(seed)
            number = len(ds["sequence"])
            idxlist = set(random.sample(range(number), int(number * ratio)))
            def concat_fn(example, idx):
                rc = reverse_complement(example["sequence"])
                if idx in idxlist:
                    example["sequence"] = rc
                return example
            # Create a dataset with random reverse complement.
            ds = ds.map(concat_fn, with_indices=True, desc="Reverse complementary")
            return ds
        if isinstance(self.dataset, DatasetDict):
            for dt in self.dataset:
                self.dataset[dt] = process(self.dataset[dt], ratio, seed)
        else:
            self.dataset = process(self.dataset, ratio, seed)

    def augment_reverse_complement(self, reverse=True, complement=True):
        """
        Augment the dataset by adding reverse complement sequences.
        This method doubles the dataset size.

        Args:
            reverse (bool): Whether to do reverse.
            complement (bool): Whether to do complement.
        """
        def process(ds, reverse, complement):
            # Create a dataset with an extra field for the reverse complement.
            def add_rc(example):
                example["rc_sequence"] = reverse_complement(
                    example["sequence"], reverse=reverse, complement=complement
                )
                return example
            ds_with_rc = ds.map(add_rc, desc="Reverse complementary")
            # Build a new dataset where the reverse complement becomes the 'sequence'
            rc_ds = ds_with_rc.map(lambda ex: {"sequence": ex["rc_sequence"], "labels": ex["labels"]}, desc="Data augment")
            # Drop the helper column so both datasets share the same features before concatenation
            rc_ds = rc_ds.remove_columns(["rc_sequence"])
            ds = concatenate_datasets([ds, rc_ds])
            return ds
        if isinstance(self.dataset, DatasetDict):
            for dt in self.dataset:
                self.dataset[dt] = process(self.dataset[dt], reverse, complement)
        else:
            self.dataset = process(self.dataset, reverse, complement)

    def concat_reverse_complement(self, reverse=True, complement=True, sep: str = ""):
        """
        Augment each sample by concatenating the sequence with its reverse complement.

        Args:
            reverse (bool): Whether to do reverse.
            complement (bool): Whether to do complement.
            sep (str): Separator between the original and reverse complement sequences.
        """
        def process(ds, reverse, complement, sep):
            def concat_fn(example):
                rc = reverse_complement(example["sequence"], reverse=reverse, complement=complement)
                example["sequence"] = example["sequence"] + sep + rc
                return example
            ds = ds.map(concat_fn, desc="Data augment")
            return ds
        if isinstance(self.dataset, DatasetDict):
            for dt in self.dataset:
                self.dataset[dt] = process(self.dataset[dt], reverse, complement, sep)
        else:
            self.dataset = process(self.dataset, reverse, complement, sep)

    def sampling(self, ratio: float=1.0, seed: int = None, overwrite: bool=False) -> any:
        """
        Randomly sample a fraction of the dataset.

        Args:
            ratio (float): Fraction of the dataset to sample. Default is 1.0 (no sampling).
            seed (int): Random seed for reproducibility.
            overwrite (bool): Whether to overwrite the original dataset with the sampled one.

        Returns:
            A sampled dataset.
        """
        dataset = self.dataset
        if isinstance(dataset, DatasetDict):
            for dt in dataset.keys():
                random.seed(seed)
                random_idx = random.sample(range(len(dataset[dt])), int(len(dataset[dt]) * ratio))
                dataset[dt] = dataset[dt].select(random_idx)
        else:
            random.seed(seed)
            random_idx = random.sample(range(len(dataset)), int(len(dataset) * ratio))
            dataset = dataset.select(random_idx)
        if overwrite:
            self.dataset = dataset
        else:
            return dataset

    def head(self, head: int=10, show: bool=False) -> dict:
        """
        Fetch the first n samples from the dataset.

        Args:
            head (int): Number of samples to fetch.
            show (bool): Whether to print the data or return it.

        Returns:
            dict: A dictionary containing the first n samples.
        """
        import pprint
        def format_convert(data):
            df = {}
            length = len(data["sequence"])
            for i in range(length):
                df[i] = {}
                for key in data.keys():
                    df[i][key] = data[key][i]
            return df
        dataset = self.dataset
        if isinstance(dataset, DatasetDict):
            df = {}
            for dt in dataset.keys():
                data = dataset[dt][:head]
                if show:
                    print(f"Dataset: {dt}")
                    pprint.pp(format_convert(data))
                else:
                    df[dt] = data
            if not show:
                return df
        else:
            data = dataset[:head]
            if show:
                pprint.pp(format_convert(data))
            else:
                return data

    def show(self, head: int=10):
        """
        Display the dataset

        Args:
            head (int): Number of samples to display.
        """
        self.head(head=head, show=True)            

    def iter_batches(self, batch_size: int) -> Dataset:
        """
        Generator that yields batches of examples from the dataset.

        Args:
            batch_size (int): Size of each batch.

        Yields:
            A batch of examples.
        """
        if isinstance(self.dataset, DatasetDict):
            raise ValueError("Dataset is a DatasetDict Object, please use `DNADataset.dataset[datatype].iter_batches(batch_size)` instead.")
        else:
            for i in range(0, len(self.dataset), batch_size):
                yield self.dataset[i: i + batch_size]

    def __len__(self):
        if isinstance(self.dataset, DatasetDict):
            return {dt: len(self.dataset[dt]) for dt in self.dataset}
        else:
            return len(self.dataset)

    def __getitem__(self, idx):
        if isinstance(self.dataset, DatasetDict):
            raise ValueError("Dataset is a DatasetDict Object, please use `DNADataset.dataset[datatype].__getitem__(idx)` instead.")
        else:
            return self.dataset[idx]

__init__(ds, tokenizer=None, max_length=512)

Parameters:

- ds (Dataset or DatasetDict, required): A Hugging Face Dataset containing at least 'sequence' and 'labels' fields.
- tokenizer (PreTrainedTokenizerBase, default None): A Hugging Face tokenizer for encoding sequences.
- max_length (int, default 512): Maximum length for tokenization.

Source code in dnallm/datasets/data.py
def __init__(self, ds: Union[Dataset, DatasetDict], tokenizer: PreTrainedTokenizerBase = None, max_length: int = 512):
    """
    Args:
        ds (datasets.Dataset or DatasetDict): A Hugging Face Dataset containing at least 'sequence' and 'labels' fields.
        tokenizer (PreTrainedTokenizerBase, optional): A Hugging Face tokenizer for encoding sequences.
        max_length (int, optional): Maximum length for tokenization.
    """
    self.dataset = ds
    self.tokenizer = tokenizer
    self.max_length = max_length
    self.sep = None
    self.multi_label_sep = None
    self.stats = None
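
Example:

A minimal construction sketch; the toy records and the dnallm.datasets.data import path are assumptions for illustration.

    from datasets import Dataset
    from dnallm.datasets.data import DNADataset  # assumed import path

    toy = Dataset.from_dict({
        "sequence": ["ACGTACGTAC", "GGGTTTAAAC"],
        "labels": [0, 1],
    })
    dna_ds = DNADataset(toy, tokenizer=None, max_length=128)
    print(len(dna_ds))  # 2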

augment_reverse_complement(reverse=True, complement=True)

Augment the dataset by adding reverse complement sequences. This method doubles the dataset size.

Parameters:

- reverse (bool, default True): Whether to do reverse.
- complement (bool, default True): Whether to do complement.

Source code in dnallm/datasets/data.py
def augment_reverse_complement(self, reverse=True, complement=True):
    """
    Augment the dataset by adding reverse complement sequences.
    This method doubles the dataset size.

    Args:
        reverse (bool): Whether to do reverse.
        complement (bool): Whether to do complement.
    """
    def process(ds, reverse, complement):
        # Create a dataset with an extra field for the reverse complement.
        def add_rc(example):
            example["rc_sequence"] = reverse_complement(
                example["sequence"], reverse=reverse, complement=complement
            )
            return example
        ds_with_rc = ds.map(add_rc, desc="Reverse complementary")
        # Build a new dataset where the reverse complement becomes the 'sequence'
        rc_ds = ds_with_rc.map(lambda ex: {"sequence": ex["rc_sequence"], "labels": ex["labels"]}, desc="Data augment")
        # Drop the helper column so both datasets share the same features before concatenation
        rc_ds = rc_ds.remove_columns(["rc_sequence"])
        ds = concatenate_datasets([ds, rc_ds])
        return ds
    if isinstance(self.dataset, DatasetDict):
        for dt in self.dataset:
            self.dataset[dt] = process(self.dataset[dt], reverse, complement)
    else:
        self.dataset = process(self.dataset, reverse, complement)
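
Example:

A small sketch of the doubling behaviour on toy data, reusing the imports from the __init__ example above.

    toy = Dataset.from_dict({"sequence": ["AACCGG"], "labels": [0]})
    dna_ds = DNADataset(toy)
    dna_ds.augment_reverse_complement()
    print(len(dna_ds))            # 2: the original plus its reverse complement
    print(dna_ds[1]["sequence"])  # "CCGGTT"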

concat_reverse_complement(reverse=True, complement=True, sep='')

Augment each sample by concatenating the sequence with its reverse complement.

Parameters:

- reverse (bool, default True): Whether to do reverse.
- complement (bool, default True): Whether to do complement.
- sep (str, default ''): Separator between the original and reverse complement sequences.

Source code in dnallm/datasets/data.py
def concat_reverse_complement(self, reverse=True, complement=True, sep: str = ""):
    """
    Augment each sample by concatenating the sequence with its reverse complement.

    Args:
        reverse (bool): Whether to do reverse.
        complement (bool): Whether to do complement.
        sep (str): Separator between the original and reverse complement sequences.
    """
    def process(ds, reverse, complement, sep):
        def concat_fn(example):
            rc = reverse_complement(example["sequence"], reverse=reverse, complement=complement)
            example["sequence"] = example["sequence"] + sep + rc
            return example
        ds = ds.map(concat_fn, desc="Data augment")
        return ds
    if isinstance(self.dataset, DatasetDict):
        for dt in self.dataset:
            self.dataset[dt] = process(self.dataset[dt], reverse, complement, sep)
    else:
        self.dataset = process(self.dataset, reverse, complement, sep)
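
Example:

A toy sketch showing the concatenation with a custom separator, with imports as in the __init__ example.

    toy = Dataset.from_dict({"sequence": ["AACC"], "labels": [0]})
    dna_ds = DNADataset(toy)
    dna_ds.concat_reverse_complement(sep="NNN")
    print(dna_ds[0]["sequence"])  # "AACCNNNGGTT"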

encode_sequences(padding='max_length', return_tensors='pt', remove_unused_columns=False, uppercase=False, lowercase=False, task='SequenceClassification')

Encode all sequences using the provided tokenizer. The dataset is mapped to include tokenized fields along with the label, making it directly usable with Hugging Face Trainer.

Parameters:

- padding (str, default 'max_length'): Padding strategy for sequences; 'max_length' or 'longest'. Use 'longest' to pad to the longest sequence in the batch, which helps reduce memory usage.
- return_tensors (str | TensorType, default 'pt'): Returned tensor type; 'pt', 'tf', 'np', or 'jax'.
- remove_unused_columns (bool, default False): Whether to remove the original 'sequence' and other unused columns.
- uppercase (bool, default False): Whether to convert sequences to uppercase.
- lowercase (bool, default False): Whether to convert sequences to lowercase.
- task (str, default 'SequenceClassification'): Task type for the tokenizer.

Source code in dnallm/datasets/data.py
def encode_sequences(self, padding: str = "max_length", return_tensors: str = "pt",
                     remove_unused_columns: bool = False,
                     uppercase: bool=False, lowercase: bool=False,
                     task: Optional[str] = 'SequenceClassification'):
    """
    Encode all sequences using the provided tokenizer.
    The dataset is mapped to include tokenized fields along with the label,
    making it directly usable with Hugging Face Trainer.

    Args:
        padding (str): Padding strategy for sequences. This can be 'max_length' or 'longest'.
                       Use 'longest' to pad to the longest sequence in the batch, which helps reduce memory usage.
        return_tensors (str | TensorType): Returned tensor type, can be 'pt', 'tf', 'np', or 'jax'.
        remove_unused_columns (bool): Whether to remove the original 'sequence' and other unused columns.
        uppercase (bool): Whether to convert sequences to uppercase.
        lowercase (bool): Whether to convert sequences to lowercase.
        task (str, optional): Task type for the tokenizer. If not provided, defaults to 'SequenceClassification'.
    """
    if self.tokenizer:
        sp_token_map = self.tokenizer.special_tokens_map
        pad_token = sp_token_map['pad_token'] if 'pad_token' in sp_token_map else None
        pad_id = self.tokenizer.encode(pad_token)[-1] if pad_token else None
        cls_token = sp_token_map['cls_token'] if 'cls_token' in sp_token_map else None
        sep_token = sp_token_map['sep_token'] if 'sep_token' in sp_token_map else None
        max_length = self.max_length
    else:
        raise ValueError("Tokenizer not provided.")
    def tokenize_for_sequence_classification(example):
        sequences = example["sequence"]
        if uppercase:
            sequences = [x.upper() for x in sequences]
        if lowercase:
            sequences = [x.lower() for x in sequences]
        tokenized = self.tokenizer(
            sequences,
            truncation=True,
            padding=padding,
            max_length=max_length
        )
        return tokenized
    def tokenize_for_token_classification(examples):

        tokenized_examples = {'sequence': [],
                              'input_ids': [],
                              # 'token_type_ids': [],
                              'attention_mask': []}
        if 'labels' in examples:
            tokenized_examples['labels'] = []
        input_seqs = examples['sequence']
        if isinstance(input_seqs, str):
            input_seqs = input_seqs.split(self.multi_label_sep)
        for i, example_tokens in enumerate(input_seqs):
            all_ids = [x for x in self.tokenizer.encode(example_tokens, is_split_into_words=True) if x>=0]
            if 'labels' in examples:
                example_ner_tags = examples['labels'][i]
            else:
                example_ner_tags = [0] * len(example_tokens)
            pad_len = max_length - len(all_ids)
            if pad_len >= 0:
                all_masks = [1] * len(all_ids) + [0] * pad_len
                all_ids = all_ids + [pad_id] * pad_len
                if cls_token:
                    if sep_token:
                        example_tokens = [cls_token] + example_tokens + [sep_token] + [pad_token] * pad_len
                        example_ner_tags = [-100] + example_ner_tags + [-100] * (pad_len + 1)
                    else:
                        example_tokens = [cls_token] + example_tokens + [pad_token] * pad_len
                        example_ner_tags = [-100] + example_ner_tags + [-100] * pad_len
                else:
                    example_tokens = example_tokens + [pad_token] * pad_len
                    example_ner_tags = example_ner_tags + [-100] * pad_len
            elif pad_len < 0:
                all_ids = all_ids[:max_length]
                all_masks = [1] * (max_length)
                if cls_token:
                    if sep_token:
                        example_tokens = [cls_token] + example_tokens[:max_length - 2] + [sep_token]
                        example_ner_tags = [-100] + example_ner_tags[:max_length - 2] + [-100]
                    else:
                        example_tokens = [cls_token] + example_tokens[:max_length - 1]
                        example_ner_tags = [-100] + example_ner_tags[:max_length - 1]
                else:
                    example_tokens = example_tokens[:max_length]
                    example_ner_tags = example_ner_tags[:max_length]
            tokenized_examples['sequence'].append(example_tokens)
            tokenized_examples['input_ids'].append(all_ids)
            # tokenized_examples['token_type_ids'].append([0] * max_length)
            tokenized_examples['attention_mask'].append(all_masks)
            if 'labels' in examples:
                tokenized_examples['labels'].append(example_ner_tags)
        return BatchEncoding(tokenized_examples)
    # Judge the task type
    task = task.lower()
    if task in ['sequenceclassification', 'binary', 'multiclass', 'multilabel', 'regression']:
        self.dataset = self.dataset.map(tokenize_for_sequence_classification, batched=True, desc="Encoding inputs")
    elif task in ['tokenclassification', 'token', 'ner']:
        from transformers.tokenization_utils_base import BatchEncoding
        self.dataset = self.dataset.map(tokenize_for_token_classification, batched=True, desc="Encoding inputs")
    elif task in ['maskedlm', 'mlm', 'mask', 'embedding']:
        self.dataset = self.dataset.map(tokenize_for_sequence_classification, batched=True, desc="Encoding inputs")
    elif task in ['causallm', 'clm', 'causal', 'generation', 'embedding']:
        self.dataset = self.dataset.map(tokenize_for_sequence_classification, batched=True)
    else:
        self.dataset = self.dataset.map(tokenize_for_sequence_classification, batched=True, desc="Encoding inputs")
    if remove_unused_columns:
        used_cols = ['labels', 'input_ids', 'attention_mask']
        if isinstance(self.dataset, DatasetDict):
            for dt in self.dataset:
                unused_cols = [f for f in self.dataset[dt].features if f not in used_cols]
                self.dataset[dt] = self.dataset[dt].remove_columns(unused_cols)
        else:
            unused_cols = [f for f in self.dataset.features if f not in used_cols]
            self.dataset = self.dataset.remove_columns(unused_cols)
    if return_tensors == "tf":
        self.dataset.set_format(type="tensorflow")
    elif return_tensors == "jax":
        self.dataset.set_format(type="jax")
    elif return_tensors == "np":
        self.dataset.set_format(type="numpy")
    else:
        self.dataset.set_format(type="torch")
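
Example:

A usage sketch; the model id is a placeholder and the toy records are illustrative assumptions.

    from datasets import Dataset
    from transformers import AutoTokenizer
    from dnallm.datasets.data import DNADataset  # assumed import path

    tokenizer = AutoTokenizer.from_pretrained("your-dna-model")  # placeholder model id
    toy = Dataset.from_dict({"sequence": ["ACGTACGTAC", "GGGTTTAAAC"], "labels": [0, 1]})
    dna_ds = DNADataset(toy, tokenizer=tokenizer, max_length=128)
    dna_ds.encode_sequences(task="SequenceClassification", remove_unused_columns=True)
    print(dna_ds.dataset.column_names)  # e.g. ['labels', 'input_ids', 'attention_mask']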

from_huggingface(dataset_name, seq_col='sequence', label_col='labels', data_dir=None, tokenizer=None, max_length=512) classmethod

Load a dataset from the Hugging Face Hub.

Parameters:

- dataset_name (str, required): Name of the dataset.
- seq_col (str, default 'sequence'): Column name for the DNA sequence.
- label_col (str, default 'labels'): Column name for the label.
- data_dir (str, default None): Data directory in a dataset.
- tokenizer (PreTrainedTokenizerBase, default None): Tokenizer.
- max_length (int, default 512): Max token length.

Returns:

- DNADataset: An instance wrapping a datasets.Dataset.

Source code in dnallm/datasets/data.py
@classmethod
def from_huggingface(cls, dataset_name: str,
                     seq_col: str = "sequence", label_col: str = "labels",
                     data_dir: Union[str, None]=None,
                     tokenizer: PreTrainedTokenizerBase = None, max_length: int = 512) -> any:
    """
    Load a dataset from the Hugging Face Hub.

    Args:
        dataset_name (str): Name of the dataset.
        seq_col (str): Column name for the DNA sequence.
        label_col (str): Column name for the label.
        data_dir (str): Data directory in a dataset.
        tokenizer (PreTrainedTokenizerBase): Tokenizer.
        max_length (int): Max token length.

    Returns:
        DNADataset: An instance wrapping a datasets.Dataset.
    """
    if data_dir:
        ds = load_dataset(dataset_name, data_dir=data_dir)
    else:
        ds = load_dataset(dataset_name)
    # Rename columns if necessary
    if seq_col != "sequence":
        ds = ds.rename_column(seq_col, "sequence")
    if label_col != "labels":
        ds = ds.rename_column(label_col, "labels")
    return cls(ds, tokenizer=tokenizer, max_length=max_length)
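
Example:

A usage sketch; the dataset id and column names are placeholders.

    dna_ds = DNADataset.from_huggingface(
        "username/dna-promoter-dataset",  # placeholder Hub dataset id
        seq_col="sequence",
        label_col="label",                # renamed to 'labels' internally
        max_length=512,
    )
    print(dna_ds.dataset)                 # typically a DatasetDict with the Hub's splits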

from_modelscope(dataset_name, seq_col='sequence', label_col='labels', data_dir=None, tokenizer=None, max_length=512) classmethod

Load a dataset from the ModelScope.

Parameters:

- dataset_name (str, required): Name of the dataset.
- seq_col (str, default 'sequence'): Column name for the DNA sequence.
- label_col (str, default 'labels'): Column name for the label.
- data_dir (str, default None): Data directory in a dataset.
- tokenizer (PreTrainedTokenizerBase, default None): Tokenizer.
- max_length (int, default 512): Max token length.

Returns:

- DNADataset: An instance wrapping a datasets.Dataset.

Source code in dnallm/datasets/data.py
@classmethod
def from_modelscope(cls, dataset_name: str,
                    seq_col: str = "sequence", label_col: str = "labels",
                    data_dir: Union[str, None]=None,
                    tokenizer: PreTrainedTokenizerBase = None, max_length: int = 512) -> any:
    """
    Load a dataset from the ModelScope.

    Args:
        dataset_name (str): Name of the dataset.
        seq_col (str): Column name for the DNA sequence.
        label_col (str): Column name for the label.
        data_dir (str): Data directory in a dataset.
        tokenizer: Tokenizer.
        max_length: Max token length.

    Returns:
        DNADataset: An instance wrapping a datasets.Dataset.
    """
    from modelscope import MsDataset

    if data_dir:
        ds = MsDataset.load(dataset_name, data_dir=data_dir)
    else:
        ds = MsDataset.load(dataset_name)
    # Rename columns if necessary
    if seq_col != "sequence":
        ds = ds.rename_column(seq_col, "sequence")
    if label_col != "labels":
        ds = ds.rename_column(label_col, "labels")
    return cls(ds, tokenizer=tokenizer, max_length=max_length)
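
Example:

A usage sketch; the dataset id is a placeholder, and the modelscope package must be installed.

    dna_ds = DNADataset.from_modelscope(
        "namespace/dna-dataset",  # placeholder ModelScope dataset id
        seq_col="sequence",
        label_col="labels",
    )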

head(head=10, show=False)

Fetch the first n samples from the dataset.

Parameters:

- head (int, default 10): Number of samples to fetch.
- show (bool, default False): Whether to print the data or return it.

Returns:

- dict: A dictionary containing the first n samples.

Source code in dnallm/datasets/data.py
def head(self, head: int=10, show: bool=False) -> dict:
    """
    Fetch the first n samples from the dataset.

    Args:
        head (int): Number of samples to fetch.
        show (bool): Whether to print the data or return it.

    Returns:
        dict: A dictionary containing the first n samples.
    """
    import pprint
    def format_convert(data):
        df = {}
        length = len(data["sequence"])
        for i in range(length):
            df[i] = {}
            for key in data.keys():
                df[i][key] = data[key][i]
        return df
    dataset = self.dataset
    if isinstance(dataset, DatasetDict):
        df = {}
        for dt in dataset.keys():
            data = dataset[dt][:head]
            if show:
                print(f"Dataset: {dt}")
                pprint.pp(format_convert(data))
            else:
                df[dt] = data
        if not show:
            return df
    else:
        data = dataset[:head]
        if show:
            pprint.pp(format_convert(data))
        else:
            return data
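
Example:

A toy sketch of the two output modes, with imports as in the __init__ example.

    toy = Dataset.from_dict({"sequence": ["ACGT", "TTAA", "GGCC"], "labels": [0, 1, 0]})
    dna_ds = DNADataset(toy)
    first_two = dna_ds.head(head=2)  # {'sequence': ['ACGT', 'TTAA'], 'labels': [0, 1]}
    dna_ds.show(head=2)              # pretty-prints the same records, keyed by row index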

iter_batches(batch_size)

Generator that yields batches of examples from the dataset.

Parameters:

- batch_size (int, required): Size of each batch.

Yields:

- Dataset: A batch of examples.

Source code in dnallm/datasets/data.py
def iter_batches(self, batch_size: int) -> Dataset:
    """
    Generator that yields batches of examples from the dataset.

    Args:
        batch_size (int): Size of each batch.

    Yields:
        A batch of examples.
    """
    if isinstance(self.dataset, DatasetDict):
        raise ValueError("Dataset is a DatasetDict Object, please use `DNADataset.dataset[datatype].iter_batches(batch_size)` instead.")
    else:
        for i in range(0, len(self.dataset), batch_size):
            yield self.dataset[i: i + batch_size]
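
Example:

A toy sketch; the last batch may be smaller than batch_size.

    toy = Dataset.from_dict({"sequence": ["ACGT"] * 10, "labels": [0] * 10})
    dna_ds = DNADataset(toy)
    for batch in dna_ds.iter_batches(batch_size=4):
        print(len(batch["sequence"]))  # 4, 4, 2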

load_local_data(file_paths, seq_col='sequence', label_col='labels', sep=None, fasta_sep='|', multi_label_sep=None, tokenizer=None, max_length=512) classmethod

Load DNA sequence datasets from one or multiple local files.

Supports input formats: csv, tsv, json, parquet, arrow, dict, fasta, txt.

Parameters:

- file_paths (str, list, or dict, required): For a single dataset, provide one file path (e.g., "data.csv"); for pre-split datasets, provide a dict like {"train": "train.csv", "test": "test.csv"}.
- seq_col (str, default 'sequence'): Column name for DNA sequences.
- label_col (str, default 'labels'): Column name for labels.
- sep (str, default None): Delimiter for CSV, TSV, or TXT.
- fasta_sep (str, default '|'): Delimiter for FASTA files.
- multi_label_sep (str, default None): Delimiter for multi-label sequences.
- tokenizer (PreTrainedTokenizerBase, default None): A tokenizer.
- max_length (int, default 512): Max token length.

Returns:

- DNADataset: An instance wrapping a Dataset or DatasetDict.

Source code in dnallm/datasets/data.py
@classmethod
def load_local_data(cls, file_paths, seq_col: str = "sequence", label_col: str = "labels",
                    sep: str = None, fasta_sep: str = "|",
                    multi_label_sep: Union[str, None] = None,
                    tokenizer: PreTrainedTokenizerBase = None, max_length: int = 512) -> any:
    """
    Load DNA sequence datasets from one or multiple local files.

    Supports input formats: csv, tsv, json, parquet, arrow, dict, fasta, txt.

    Args:
        file_paths (str, list, or dict):  
            - Single dataset: Provide one file path (e.g., "data.csv").  
            - Pre-split datasets: Provide a dict like `{"train": "train.csv", "test": "test.csv"}`.
        seq_col (str): Column name for DNA sequences.
        label_col (str): Column name for labels.
        sep (str, optional): Delimiter for CSV, TSV, or TXT.
        fasta_sep (str, optional): Delimiter for FASTA files.
        multi_label_sep (str, optional): Delimiter for multi-label sequences.
        tokenizer (PreTrainedTokenizerBase, optional): A tokenizer.
        max_length (int, optional): Max token length.

    Returns:
        DNADataset: An instance wrapping a Dataset or DatasetDict.
    """
    # Set separators
    cls.sep = sep
    cls.multi_label_sep = multi_label_sep
    # Check if input is a list or dict
    if isinstance(file_paths, dict):  # Handling multiple files (pre-split datasets)
        ds_dict = {}
        for split, path in file_paths.items():
            ds_dict[split] = cls._load_single_data(path, seq_col, label_col, sep, fasta_sep, multi_label_sep)
        dataset = DatasetDict(ds_dict)
    else:  # Handling a single file
        dataset = cls._load_single_data(file_paths, seq_col, label_col, sep, fasta_sep, multi_label_sep)

    return cls(dataset, tokenizer=tokenizer, max_length=max_length)
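
Example:

A usage sketch; the file paths are placeholders and the files are assumed to contain 'sequence' and 'labels' columns.

    # Single file
    dna_ds = DNADataset.load_local_data("data/train.csv")  # placeholder path

    # Pre-split files become a DatasetDict
    dna_ds = DNADataset.load_local_data(
        {"train": "data/train.csv", "test": "data/test.csv"},  # placeholder paths
        seq_col="sequence",
        label_col="labels",
    )
    print(len(dna_ds))  # e.g. {'train': 8000, 'test': 2000}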

process_missing_data()

Filter out samples with missing or empty sequences or labels.

Source code in dnallm/datasets/data.py
def process_missing_data(self):
    """
    Filter out samples with missing or empty sequences or labels.
    """
    def non_missing(example):
        return example["sequence"] and example["labels"] is not None and example["sequence"].strip() != ""
    self.dataset = self.dataset.filter(non_missing)
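
Example:

A toy sketch; note that a label of 0 is kept because the check is against None, not truthiness.

    toy = Dataset.from_dict({"sequence": ["ACGT", ""], "labels": [0, None]})
    dna_ds = DNADataset(toy)
    dna_ds.process_missing_data()
    print(len(dna_ds))  # 1: the row with the empty sequence is dropped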

random_generate(minl, maxl=0, samples=1, gc=(0, 1), N_ratio=0.0, padding_size=0, seed=None, label_func=None, append=False)

Replace the current dataset with randomly generated DNA sequences.

Parameters:

- minl (int, required): Minimum length of the sequences.
- maxl (int, default 0): Maximum length of the sequences; defaults to the same as minl.
- samples (int, default 1): Number of sequences to generate.
- gc (tuple, default (0, 1)): GC content range.
- N_ratio (float, default 0.0): Ratio of N bases to include in the generated sequences.
- padding_size (int, default 0): Padding size for sequence length.
- seed (int, default None): Random seed.
- label_func (callable, default None): A function that generates a label from a sequence.
- append (bool, default False): Whether to append the randomly generated data to the existing dataset or use the generated data as the dataset.

Source code in dnallm/datasets/data.py
def random_generate(self, minl: int, maxl: int = 0, samples: int = 1,
                          gc: tuple = (0,1), N_ratio: float = 0.0,
                          padding_size: int = 0, seed: int = None,
                          label_func = None, append: bool = False):
    """
    Replace the current dataset with randomly generated DNA sequences.

    Args:
        minl: int, minimum length of the sequences
        maxl: int, maximum length of the sequences, default is the same as minl
        samples: int, number of sequences to generate, default 1
        gc: tuple, GC content range, default (0,1)
        N_ratio: float, include N base in the generated sequence, default 0.0
        padding_size: int, padding size for sequence length, default 0
        seed: int, random seed, default None
        label_func (callable, optional): A function that generates a label from a sequence.
        append: bool, whether to append the randomly generated data to the existing dataset or replace the dataset with the generated data
    """
    def process(minl, maxl, number, gc, N_ratio, padding_size, seed, label_func):
        sequences = random_generate_sequences(minl=minl, maxl=maxl, samples=number,
                                            gc=gc, N_ratio=N_ratio,
                                            padding_size=padding_size, seed=seed)
        labels = []
        for seq in sequences:
            labels.append(label_func(seq) if label_func else 0)
        random_ds = Dataset.from_dict({"sequence": sequences, "labels": labels})
        return random_ds
    if append:
        if isinstance(self.dataset, DatasetDict):
            for dt in self.dataset:
                number = round(samples * len(self.dataset[dt]) / sum(self.__len__().values()))
                random_ds = process(minl, maxl, number, gc, N_ratio, padding_size, seed, label_func)
                self.dataset[dt] = concatenate_datasets([self.dataset[dt], random_ds])
        else:
            random_ds = process(minl, maxl, samples, gc, N_ratio, padding_size, seed, label_func)
            self.dataset = concatenate_datasets([self.dataset, random_ds])
    else:
        self.dataset = process(minl, maxl, samples, gc, N_ratio, padding_size, seed, label_func)
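
A minimal sketch of generating a synthetic dataset (the GC-based label function is purely illustrative):

```python
from datasets import Dataset
from dnallm.datasets.data import DNADataset  # import path assumed

# Start from an empty placeholder; with append=False (the default) it is replaced anyway.
dna_ds = DNADataset(Dataset.from_dict({"sequence": [], "labels": []}))

# 100 random sequences of 150-300 bp with 40-60% GC content, each labelled
# by whether its GC content exceeds 50% (illustrative label function).
dna_ds.random_generate(
    minl=150, maxl=300, samples=100, gc=(0.4, 0.6), seed=42,
    label_func=lambda seq: int((seq.count("G") + seq.count("C")) / len(seq) > 0.5),
)
print(dna_ds.dataset)
```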

raw_reverse_complement(ratio=0.5, seed=None)

Reverse complement a randomly selected fraction of the sequences in the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| ratio | float | Ratio of sequences to reverse complement. | 0.5 |
| seed | int | Random seed for reproducibility. | None |
Source code in dnallm/datasets/data.py
def raw_reverse_complement(self, ratio: float = 0.5, seed: int = None):
    """
    Reverse complement a randomly selected fraction of the sequences in the dataset.

    Args:
        ratio (float): Ratio of sequences to reverse complement.
        seed (int): Random seed for reproducibility.
    """
    def process(ds, ratio, seed):
        random.seed(seed)
        number = len(ds["sequence"])
        idxlist = set(random.sample(range(number), int(number * ratio)))
        def concat_fn(example, idx):
            # Only reverse complement the randomly selected indices.
            if idx in idxlist:
                example["sequence"] = reverse_complement(example["sequence"])
            return example
        # Dataset.map returns a new dataset, so reassign it to keep the changes.
        ds = ds.map(concat_fn, with_indices=True, desc="Reverse complement")
        return ds
    if isinstance(self.dataset, DatasetDict):
        for dt in self.dataset:
            self.dataset[dt] = process(self.dataset[dt], ratio, seed)
    else:
        self.dataset = process(self.dataset, ratio, seed)
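
For example, with a DNADataset instance dna_ds as in the earlier sketches, roughly half of the sequences can be replaced by their reverse complement, which can serve as a simple strand-agnostic augmentation:

```python
# Reverse complement a random ~50% of the sequences, reproducibly.
dna_ds.raw_reverse_complement(ratio=0.5, seed=42)
```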

sampling(ratio=1.0, seed=None, overwrite=False)

Randomly sample a fraction of the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| ratio | float | Fraction of the dataset to sample. Default is 1.0 (no sampling). | 1.0 |
| seed | int | Random seed for reproducibility. | None |
| overwrite | bool | Whether to overwrite the original dataset with the sampled one. | False |

Returns:

| Type | Description |
| --- | --- |
| any | A sampled dataset. |

Source code in dnallm/datasets/data.py
def sampling(self, ratio: float=1.0, seed: int = None, overwrite: bool=False) -> any:
    """
    Randomly sample a fraction of the dataset.

    Args:
        ratio (float): Fraction of the dataset to sample. Default is 1.0 (no sampling).
        seed (int): Random seed for reproducibility.
        overwrite (bool): Whether to overwrite the original dataset with the sampled one.

    Returns:
        A sampled dataset.
    """
    dataset = self.dataset
    if isinstance(dataset, DatasetDict):
        for dt in dataset.keys():
            random.seed(seed)
            random_idx = random.sample(range(len(dataset[dt])), int(len(dataset[dt]) * ratio))
            dataset[dt] = dataset[dt].select(random_idx)
    else:
        # Seed here as well so single-Dataset sampling is reproducible.
        random.seed(seed)
        random_idx = random.sample(range(len(dataset)), int(len(dataset) * ratio))
        dataset = dataset.select(random_idx)
    if overwrite:
        self.dataset = dataset
    else:
        return dataset
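
A quick sketch of both modes, with a dna_ds instance holding a plain Dataset: with overwrite=False the sampled dataset is returned and the stored one is untouched, while overwrite=True replaces the stored dataset. (For a DatasetDict, the splits are reselected on the stored object itself, as the code above shows.)

```python
# Take a reproducible 10% subset without modifying the stored dataset.
subset = dna_ds.sampling(ratio=0.1, seed=0)
print(len(subset))

# Or shrink the stored dataset in place.
dna_ds.sampling(ratio=0.1, seed=0, overwrite=True)
```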

show(head=10)

Display the first few samples of the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| head | int | Number of samples to display. | 10 |
Source code in dnallm/datasets/data.py
def show(self, head: int=10):
    """
    Display the first few samples of the dataset.

    Args:
        head (int): Number of samples to display.
    """
    self.head(head=head, show=True)            

shuffle(seed=None)

Shuffle the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| seed | int | Random seed for reproducibility. | None |
Source code in dnallm/datasets/data.py
def shuffle(self, seed: int = None):
    """
    Shuffle the dataset.

    Args:
        seed (int): Random seed for reproducibility.
    """
    # Dataset.shuffle returns a new dataset; reassign it so the shuffle takes effect.
    self.dataset = self.dataset.shuffle(seed=seed)

split_data(test_size=0.2, val_size=0.1, seed=None)

Split the dataset into train, test, and validation sets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| test_size | float | Proportion of the dataset to include in the test split. | 0.2 |
| val_size | float | Proportion of the dataset to include in the validation split. | 0.1 |
| seed | int | Random seed for reproducibility. | None |
Source code in dnallm/datasets/data.py
def split_data(self, test_size: float = 0.2, val_size: float = 0.1, seed: int = None):
    """
    Split the dataset into train, test, and validation sets.

    Args:
        test_size (float): Proportion of the dataset to include in the test split.
        val_size (float): Proportion of the dataset to include in the validation split.
        seed (int): Random seed for reproducibility.
    """
    # First, split off test+validation from training data
    split_result = self.dataset.train_test_split(test_size=test_size + val_size, seed=seed)
    train_ds = split_result['train']
    temp_ds = split_result['test']
    # Further split temp_ds into test and validation sets
    if val_size > 0:
        rel_val_size = val_size / (test_size + val_size)
        temp_split = temp_ds.train_test_split(test_size=rel_val_size, seed=seed)
        test_ds = temp_split['train']
        val_ds = temp_split['test']
        self.dataset = DatasetDict({'train': train_ds, 'test': test_ds, 'val': val_ds})
    else:
        self.dataset = DatasetDict({'train': train_ds, 'test': test_ds})
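
A worked example of the two-stage split (sizes are illustrative): with test_size=0.2 and val_size=0.1, 30% of the data is held out first and then split again with a relative validation fraction of 0.1 / 0.3 ≈ 0.33, giving roughly 70/20/10 train/test/val.

```python
# Split a dna_ds instance (holding a single Dataset) into train/test/val.
dna_ds.split_data(test_size=0.2, val_size=0.1, seed=42)
print(dna_ds.dataset)  # DatasetDict with 'train', 'test' and 'val' splits
```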

validate_sequences(minl=20, maxl=6000, gc=(0, 1), valid_chars='ACGTN')

Filter the dataset, keeping only sequences that contain valid DNA bases and fall within the allowed length and GC content range.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| minl | int | Minimum length of the sequences. | 20 |
| maxl | int | Maximum length of the sequences. | 6000 |
| gc | tuple | GC content range between 0 and 1. | (0, 1) |
| valid_chars | str | Allowed characters in the sequences. | 'ACGTN' |
Source code in dnallm/datasets/data.py
def validate_sequences(self, minl: int = 20, maxl: int = 6000, gc: tuple = (0,1), valid_chars: str = "ACGTN"):
    """
    Filter the dataset, keeping only sequences that contain valid DNA bases and fall within the allowed length and GC content range.

    Args:
        minl (int): Minimum length of the sequences.
        maxl (int): Maximum length of the sequences.
        gc (tuple): GC content range between 0 and 1.
        valid_chars (str): Allowed characters in the sequences.
    """
    self.dataset = self.dataset.filter(
        lambda example: check_sequence(example["sequence"], minl, maxl, gc, valid_chars)
    )
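
For example, keeping only sequences between 50 and 1000 bp whose GC content lies between 30% and 70% (thresholds are illustrative):

```python
dna_ds.validate_sequences(minl=50, maxl=1000, gc=(0.3, 0.7), valid_chars="ACGTN")
```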