def __init__(
self,
max_length: int | None = 196_608,
padding_side: str = "right",
return_embeds: bool = False,
embeds_transpose: bool = False,
):
self.max_length = max_length
self.padding_side = padding_side
self.return_embeds = return_embeds
self.embeds_transpose = embeds_transpose
self.token_to_id = {
"A": 0,
"a": 0,
"C": 1,
"c": 1,
"G": 2,
"g": 2,
"T": 3,
"t": 3,
"U": 3,
"u": 3, # RNA
"N": 4,
"n": 4,
"X": 4,
"x": 4,
"-": -1,
".": -1, # Padding
}
self.id_to_token = {0: "A", 1: "C", 2: "G", 3: "T", 4: "N", -1: "-"}
self.pad_token_id = -1
self.pad_token = "-" # noqa: S105
self.unk_token_id = 4 # N
self.vocab_vectors = torch.tensor(
[
[1.0, 0.0, 0.0, 0.0], # 0: A
[0.0, 1.0, 0.0, 0.0], # 1: C
[0.0, 0.0, 1.0, 0.0], # 2: G
[0.0, 0.0, 0.0, 1.0], # 3: T / U
[0.25, 0.25, 0.25, 0.25], # 4: N / X
[0.0, 0.0, 0.0, 0.0], # 5: Padding, index=-1
],
dtype=torch.float32,
)