Multi-Label Classification
In [1]:
Copied!
from dnallm import load_config, load_model_and_tokenizer, DNADataset, DNATrainer
from dnallm import load_config, load_model_and_tokenizer, DNADataset, DNATrainer
In [2]:
Copied!
# Load the experiment configuration from YAML (executed once).
# `configs['task']` drives model/data setup below, and the whole object is
# handed to DNATrainer.
configs = load_config("./multi_labels_config.yaml")
In [3]:
Copied!
# Load the pretrained model and its tokenizer, configured for the task
# described in `configs['task']`. Loaded once: a second call would re-resolve
# the checkpoint and rebuild the model for no benefit.
model_name = "zhangtaolab/plant-dnagpt-BPE"
# Hub to fetch the checkpoint from: "modelscope" or "huggingface".
model_source = "modelscope"
model, tokenizer = load_model_and_tokenizer(
    model_name, task_config=configs['task'], source=model_source
)
# The classification head ('score.weight') is newly initialized (see the
# warning in the output), so predictions are meaningless until fine-tuned.
Downloading Model from https://www.modelscope.cn to directory: /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-BPE 14:37:17 - dnallm.utils.support - INFO - Model files are stored in /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-BPE
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at /Users/forrest/.cache/modelscope/hub/models/zhangtaolab/plant-dnagpt-BPE and are newly initialized: ['score.weight'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
In [4]:
Copied!
# Load the local TSV dataset. The `labels` column holds several labels per
# sequence, separated by ',' (passed as `multi_label_sep`).
datasets = DNADataset.load_local_data(
    "./maize_test.tsv",
    seq_col="sequence",
    label_col="labels",
    multi_label_sep=",",
    tokenizer=tokenizer,
    max_length=512,
)
# Encode (tokenize) the sequences for the configured task type;
# remove_unused_columns=True presumably drops raw columns the model does not consume.
datasets.encode_sequences(task=configs['task'].task_type, remove_unused_columns=True)
# Split the dataset into train, test, and validation sets.
# NOTE(review): no seed is passed, so split reproducibility depends on
# split_data's defaults — confirm before comparing runs.
datasets.split_data()
Generating train split: 0 examples [00:00, ? examples/s]
Format labels: 0%| | 0/10000 [00:00<?, ? examples/s]
Encoding inputs: 0%| | 0/10000 [00:00<?, ? examples/s]
In [5]:
Copied!
# Build the trainer from the model, the full configuration, and the split
# datasets. Constructed once; rebuilding it is redundant.
trainer = DNATrainer(
    model=model,
    config=configs,
    datasets=datasets,
)
In [6]:
Copied!
# Fine-tune the model (about 77 min for 3 epochs in the recorded run).
# Call train() exactly once here: a second call on the same trainer would
# presumably continue from the already-updated weights, doubling the runtime
# and reporting metrics for a different (over-trained) model.
metrics = trainer.train()
print(metrics)
[2625/2625 1:16:47, Epoch 3/3]
| Step | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 | Precision Micro | Recall Micro | F1 Micro | Precision Weighted | Recall Weighted | F1 Weighted | Precision Samples | Recall Samples | F1 Samples | Mcc | Auroc | Auprc | Tpr | Tnr | Fpr | Fnr |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 500 | 0.280200 | 0.155547 | 0.762238 | 0.211572 | 0.158561 | 0.158172 | 0.473950 | 0.226870 | 0.306855 | 0.280166 | 0.226870 | 0.220705 | 0.081927 | 0.048246 | 0.047847 | 0.149856 | 0.918668 | 0.361795 | 0.158561 | 0.981821 | 0.018179 | 0.841439 |
| 1000 | 0.146100 | 0.170772 | 0.800200 | 0.441545 | 0.164471 | 0.227142 | 0.585973 | 0.208367 | 0.307418 | 0.514883 | 0.208367 | 0.282476 | 0.048261 | 0.035697 | 0.031921 | 0.236231 | 0.909454 | 0.399514 | 0.164471 | 0.989507 | 0.010493 | 0.835529 |
| 1500 | 0.138600 | 0.151820 | 0.800200 | 0.638011 | 0.133174 | 0.204163 | 0.641509 | 0.164119 | 0.261371 | 0.674379 | 0.164119 | 0.245527 | 0.038927 | 0.023819 | 0.023730 | 0.251443 | 0.929347 | 0.463728 | 0.133174 | 0.993440 | 0.006560 | 0.866826 |
| 2000 | 0.121000 | 0.145796 | 0.790210 | 0.576422 | 0.286482 | 0.362327 | 0.539597 | 0.323411 | 0.404427 | 0.546884 | 0.323411 | 0.393154 | 0.053367 | 0.047860 | 0.039968 | 0.362768 | 0.932501 | 0.474062 | 0.286482 | 0.980389 | 0.019611 | 0.713518 |
| 2500 | 0.111900 | 0.145860 | 0.796204 | 0.643390 | 0.252621 | 0.347409 | 0.597938 | 0.279968 | 0.381370 | 0.612762 | 0.279968 | 0.371620 | 0.046154 | 0.037964 | 0.033512 | 0.366503 | 0.935172 | 0.497237 | 0.252621 | 0.986607 | 0.013393 | 0.747379 |
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in samples with no true labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in samples with no true nor predicted labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in samples with no true labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in samples with no true nor predicted labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in samples with no true labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in samples with no true nor predicted labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in samples with no true labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in samples with no true nor predicted labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in samples with no true labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in samples with no true nor predicted labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
{'train_runtime': 4613.7613, 'train_samples_per_second': 4.551, 'train_steps_per_second': 0.569, 'total_flos': 5487290020528128.0, 'train_loss': 0.1565389651343936, 'epoch': 3.0}
In [7]:
Copied!
# Predict on the test split once and keep the PredictionOutput
# (predictions, label_ids, metrics) for later analysis instead of relying on `_`.
test_results = trainer.infer()
test_results
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: Recall is ill-defined and being set to 0.0 in samples with no true labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
/Users/forrest/GitHub/DNALLM/.venv/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1833: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in samples with no true nor predicted labels. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
Out[7]:
PredictionOutput(predictions=array([[-5.443179 , -5.6372533 , -5.552704 , ..., -6.5272865 ,
-7.5997458 , -5.74388 ],
[-1.7592301 , -3.8711321 , -4.3161707 , ..., -2.0341883 ,
-5.250122 , -1.8616908 ],
[-6.446937 , -6.779642 , -6.8808575 , ..., -6.401059 ,
-7.6829267 , -7.433517 ],
...,
[-0.9382036 , -1.0705391 , -0.59772235, ..., -1.354423 ,
-2.1761355 , -0.2642503 ],
[-5.504851 , -6.3220463 , -5.644334 , ..., -6.8325057 ,
-7.718524 , -6.2338963 ],
[-5.6954155 , -5.9499183 , -6.267033 , ..., -5.7308164 ,
-7.274815 , -6.108817 ]], dtype=float32), label_ids=array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[1., 1., 1., ..., 0., 0., 1.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], dtype=float32), metrics={'test_loss': 0.13551659882068634, 'test_accuracy': 0.798, 'test_precision': 0.5518252754926944, 'test_recall': 0.311220457107487, 'test_f1': 0.3903308188112585, 'test_precision_micro': 0.5606896551724138, 'test_recall_micro': 0.3433277027027027, 'test_f1_micro': 0.4258774227344159, 'test_precision_weighted': 0.5585796433599094, 'test_recall_weighted': 0.3433277027027027, 'test_f1_weighted': 0.4181393571382143, 'test_precision_samples': 0.055576365449565915, 'test_recall_samples': 0.04742432422668491, 'test_f1_samples': 0.04118593701834425, 'test_mcc': 0.3822012575166047, 'test_AUROC': 0.9374013561554215, 'test_AUPRC': 0.47993831500493955, 'test_TPR': 0.311220457107487, 'test_TNR': 0.9818909959471704, 'test_FPR': 0.018109004052829662, 'test_FNR': 0.688779542892513, 'test_runtime': 34.2167, 'test_samples_per_second': 58.451, 'test_steps_per_second': 3.653})