
EVO

dnallm.models.special.evo

Classes

EvoTokenizerWrapper

EvoTokenizerWrapper(
    raw_tokenizer, model_max_length=8192, **kwargs
)

raw_tokenizer: Raw CharLevelTokenizer instance from the EVO2 package
pad_token_id: Token ID used for padding (usually 1 for EVO2)
model_max_length: Maximum context length of the model

Source code in dnallm/models/special/evo.py
def __init__(self, raw_tokenizer, model_max_length=8192, **kwargs):
    """
    raw_tokenizer: Raw CharLevelTokenizer instance from EVO2 package
    pad_token_id: Token ID used for padding (usually 1 for EVO2)
    model_max_length: Maximum context length of the model
    """

    self.raw_tokenizer = raw_tokenizer
    self.model_max_length = model_max_length
    for attr in [
        "vocab_size",
        "bos_token_id",
        "eos_token_id",
        "unk_token_id",
        "pad_token_id",
        "pad_id",
        "eos_id",
        "eod_id",
    ]:
        if hasattr(raw_tokenizer, attr):
            setattr(self, attr, getattr(raw_tokenizer, attr))
    if not hasattr(self, "pad_token_id"):
        self.pad_token_id = self.raw_tokenizer.pad_id
    self.pad_token = raw_tokenizer.decode_token(self.pad_token_id)
    self.padding_side = "right"
    self.init_kwargs = kwargs
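
A minimal usage sketch. The toy tokenizer below is a hypothetical stand-in for the EVO2 CharLevelTokenizer; it only mimics the attributes the wrapper actually reads (pad_id, vocab_size, tokenize, decode_token), so the import path of the real tokenizer is not assumed here.

```python
# Sketch: wrap a raw character-level tokenizer so it exposes a Hugging
# Face-style interface. In practice, pass the CharLevelTokenizer shipped
# with the EVO2 package; _ToyCharTokenizer is a hypothetical stand-in.
from dnallm.models.special.evo import EvoTokenizerWrapper


class _ToyCharTokenizer:
    vocab_size = 512
    pad_id = 1   # EVO2 conventionally pads with token ID 1
    eos_id = 0

    def tokenize(self, seq):
        # Character-level tokenization: one token per base
        return [ord(c) for c in seq]

    def decode_token(self, token_id):
        return chr(token_id)


tokenizer = EvoTokenizerWrapper(_ToyCharTokenizer(), model_max_length=8192)
print(tokenizer.pad_token_id)   # 1, taken from raw_tokenizer.pad_id
print(tokenizer.padding_side)   # "right"
```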
Functions
__call__
__call__(
    text,
    padding=False,
    truncation=False,
    max_length=None,
    return_tensors=None,
    **kwargs,
)

__call__ method to tokenize inputs with padding and truncation.

Source code in dnallm/models/special/evo.py
def __call__(
    self,
    text: str | list[str],
    padding: bool | str = False,
    truncation: bool = False,
    max_length: int | None = None,
    return_tensors: str | None = None,
    **kwargs,
):
    """
    __call__ method to tokenize inputs with padding and truncation.
    """
    if isinstance(text, str):
        text = [text]
        is_batched = False
    else:
        is_batched = True

    input_ids_list = [self.raw_tokenizer.tokenize(seq) for seq in text]

    if truncation:
        limit = (
            max_length if max_length is not None else self.model_max_length
        )
        input_ids_list = [ids[:limit] for ids in input_ids_list]

    if padding:
        if padding == "max_length":
            target_len = (
                max_length
                if max_length is not None
                else self.model_max_length
            )
        elif padding is True or padding == "longest":
            target_len = max(len(ids) for ids in input_ids_list)
        else:
            target_len = max(len(ids) for ids in input_ids_list)

        padded_input_ids = []
        attention_masks = []

        for ids in input_ids_list:
            current_len = len(ids)
            pad_len = target_len - current_len

            if pad_len < 0:
                ids = ids[:target_len]
                pad_len = 0
                current_len = target_len

            new_ids = ids + [self.pad_token_id] * pad_len
            padded_input_ids.append(new_ids)

            mask = [1] * current_len + [0] * pad_len
            attention_masks.append(mask)
    else:
        padded_input_ids = input_ids_list
        attention_masks = [[1] * len(ids) for ids in input_ids_list]

    if return_tensors == "pt":
        return BatchEncoding({
            "input_ids": torch.tensor(padded_input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(
                attention_masks, dtype=torch.long
            ),
        })

    result = {
        "input_ids": padded_input_ids,
        "attention_mask": attention_masks,
    }

    if not is_batched and return_tensors is None:
        return {k: v[0] for k, v in result.items()}

    return BatchEncoding(result)
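
Continuing the sketch above, the wrapper is called like a Hugging Face tokenizer; the two DNA sequences below are illustrative only.

```python
# Batch-tokenize two sequences of different lengths, pad to the longest
# one, and return PyTorch tensors (requires return_tensors="pt").
batch = tokenizer(
    ["ACGTACGT", "ACGT"],
    padding=True,        # pad to the longest sequence in the batch
    truncation=True,
    max_length=8192,
    return_tensors="pt",
)
print(batch["input_ids"].shape)    # torch.Size([2, 8])
print(batch["attention_mask"][1])  # tensor([1, 1, 1, 1, 0, 0, 0, 0])
```

With padding=False and return_tensors=None, a single string input returns plain Python lists instead of tensors, matching the unbatched branch at the end of the method.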

Functions