LoRA Inference
In [2]:
Copied!
# DNALLM helpers used in this notebook:
# - load_config: reads the YAML inference configuration
# - load_model_and_tokenizer: downloads/initializes the base model + tokenizer
# - DNAInference: the inference engine (accepts an optional LoRA adapter)
from dnallm import load_config, load_model_and_tokenizer
from dnallm import DNAInference
# DNALLM helpers used in this notebook:
# - load_config: reads the YAML inference configuration
# - load_model_and_tokenizer: downloads/initializes the base model + tokenizer
# - DNAInference: the inference engine (accepts an optional LoRA adapter)
from dnallm import load_config, load_model_and_tokenizer
from dnallm import DNAInference
In [3]:
Copied!
# Load the config file
# The returned mapping carries the task settings (configs['task']) that are
# passed to load_model_and_tokenizer in the next cell.
configs = load_config("inference_config.yaml")
# Load the config file
# The returned mapping carries the task settings (configs['task']) that are
# passed to load_model_and_tokenizer in the next cell.
configs = load_config("inference_config.yaml")
In [4]:
Copied!
# Load the model and tokenizer
# NOTE: as the warning below shows, the sequence-classification head
# ('score.weight') is newly initialized — the base checkpoint alone is not
# usable for prediction; the fine-tuned LoRA adapter attached in the next
# cell supplies the trained task weights.
model_name = "kuleshov-group/PlantCAD2-Small-l24-d0768"
# from Hugging Face
model, tokenizer = load_model_and_tokenizer(model_name, task_config=configs['task'], source="huggingface")
# from ModelScope
# model, tokenizer = load_model_and_tokenizer(model_name, task_config=configs['task'], source="modelscope")
# Load the model and tokenizer
# NOTE: as the warning below shows, the sequence-classification head
# ('score.weight') is newly initialized — the base checkpoint alone is not
# usable for prediction; the fine-tuned LoRA adapter attached in the next
# cell supplies the trained task weights.
model_name = "kuleshov-group/PlantCAD2-Small-l24-d0768"
# from Hugging Face
model, tokenizer = load_model_and_tokenizer(model_name, task_config=configs['task'], source="huggingface")
# from ModelScope
# model, tokenizer = load_model_and_tokenizer(model_name, task_config=configs['task'], source="modelscope")
Fetching 10 files: 0%| | 0/10 [00:00<?, ?it/s]
15:24:53 - dnallm.utils.support - INFO - Model files are stored in /home/liuguanqing/.cache/huggingface/hub/models--kuleshov-group--PlantCAD2-Small-l24-d0768/snapshots/f756c255cb76e9f538c3acec04acf4214ed03fb3
Some weights of the model checkpoint at /home/liuguanqing/.cache/huggingface/hub/models--kuleshov-group--PlantCAD2-Small-l24-d0768/snapshots/f756c255cb76e9f538c3acec04acf4214ed03fb3 were not used when initializing CaduceusForSequenceClassification: ['lm_head.complement_map', 'lm_head.lm_head.weight'] - This IS expected if you are initializing CaduceusForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing CaduceusForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of CaduceusForSequenceClassification were not initialized from the model checkpoint at /home/liuguanqing/.cache/huggingface/hub/models--kuleshov-group--PlantCAD2-Small-l24-d0768/snapshots/f756c255cb76e9f538c3acec04acf4214ed03fb3 and are newly initialized: ['score.weight'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
In [5]:
Copied!
# Create inference engine
# The adapter path is a Hugging Face Hub repo id; its files are fetched and
# the LoRA weights are merged onto the base model (see the log below).
lora_adapter_path = "plantcad/cross_species_acr_train_on_arabidopsis_plantcad2_small"
inference_engine = DNAInference(
model=model,
tokenizer=tokenizer,
config=configs,
lora_adapter=lora_adapter_path
)
# Create inference engine
# The adapter path is a Hugging Face Hub repo id; its files are fetched and
# the LoRA weights are merged onto the base model (see the log below).
lora_adapter_path = "plantcad/cross_species_acr_train_on_arabidopsis_plantcad2_small"
inference_engine = DNAInference(
model=model,
tokenizer=tokenizer,
config=configs,
lora_adapter=lora_adapter_path
)
Fetching 9 files: 0%| | 0/9 [00:00<?, ?it/s]
15:25:33 - dnallm.utils.support - INFO - Model files are stored in /home/liuguanqing/.cache/huggingface/hub/models--plantcad--cross_species_acr_train_on_arabidopsis_plantcad2_small/snapshots/a76b001f89a3be2e12990158bb1903ccddd17625 15:25:46 - dnallm.utils.support - INFO - Loaded LoRA adapter from plantcad/cross_species_acr_train_on_arabidopsis_plantcad2_small 15:25:46 - dnallm.utils.support - INFO - Using device: cuda
In [6]:
Copied!
# Single-sequence inference.
# The adjacent string literals are concatenated into one DNA sequence, so
# seqs is a list containing a single string; infer_seqs returns a dict keyed
# by input index (see the printed result in the next cell).
seqs = [(
"AAAAATTTAAATATCGTCTGTAGATATTTTATGGGATGCTTTGAGAATGGGCTTCGTTTTAATGGGCCTC"
"CTCTGCAATCATTGTCCAGAGTCGAGAAACCACCTCTTCTTCTCTTGTTCTTTCTCCAAATCGATTTGGT"
"CCCAACTCTCTTCAAGCAAAGGAGAGATATGAAAATGAAAGCTCTTACGGCGAACAAGTTTTTCCGATTG"
"AAGAAGAGAAGAATCTAGAAGATGAAGACAACACTAGTGCACCAAACAGTTTTGCGCGTCTTGAGAGGAA"
"ACAAAAAACTATTCAGAGTTCAGAGAGAGTCAACCCCCAAACGAGACTTAAACGATGAGCCCACTATAAT"
"TTTATAATTTATGGGCCATCAGGCCCAAATGATCAGTAGTAGTTATTATTTGACTTTTGACATGGTGGAT"
"TTGGTTTAACCACCAAACCGAACGAGTAAAACACTATTGGATTGGGTGATGATATCCCGGTTTTATTTGG"
"TTAAAATCACAAAATCCTGATTTTGGTTCGCGGCTTGATTCTGCCGCTCTCTCGTCTTTAACCTAACTAA"
"AGACGTAGAATGATTCTGGTTATTGAATTAGTTTGATACA"
)]
results = inference_engine.infer_seqs(seqs)
# Single-sequence inference.
# The adjacent string literals are concatenated into one DNA sequence, so
# seqs is a list containing a single string; infer_seqs returns a dict keyed
# by input index (see the printed result in the next cell).
seqs = [(
"AAAAATTTAAATATCGTCTGTAGATATTTTATGGGATGCTTTGAGAATGGGCTTCGTTTTAATGGGCCTC"
"CTCTGCAATCATTGTCCAGAGTCGAGAAACCACCTCTTCTTCTCTTGTTCTTTCTCCAAATCGATTTGGT"
"CCCAACTCTCTTCAAGCAAAGGAGAGATATGAAAATGAAAGCTCTTACGGCGAACAAGTTTTTCCGATTG"
"AAGAAGAGAAGAATCTAGAAGATGAAGACAACACTAGTGCACCAAACAGTTTTGCGCGTCTTGAGAGGAA"
"ACAAAAAACTATTCAGAGTTCAGAGAGAGTCAACCCCCAAACGAGACTTAAACGATGAGCCCACTATAAT"
"TTTATAATTTATGGGCCATCAGGCCCAAATGATCAGTAGTAGTTATTATTTGACTTTTGACATGGTGGAT"
"TTGGTTTAACCACCAAACCGAACGAGTAAAACACTATTGGATTGGGTGATGATATCCCGGTTTTATTTGG"
"TTAAAATCACAAAATCCTGATTTTGGTTCGCGGCTTGATTCTGCCGCTCTCTCGTCTTTAACCTAACTAA"
"AGACGTAGAATGATTCTGGTTATTGAATTAGTTTGATACA"
)]
results = inference_engine.infer_seqs(seqs)
Encoding inputs: 0%| | 0/1 [00:00<?, ? examples/s]
Inferring: 100%|██████████| 1/1 [00:33<00:00, 33.73s/it]
In [7]:
Copied!
# Each entry maps input index -> {'sequence', predicted 'label',
# per-class probability 'scores'} (see the output below).
print(results)
# Each entry maps input index -> {'sequence', predicted 'label',
# per-class probability 'scores'} (see the output below).
print(results)
{0: {'sequence': 'AAAAATTTAAATATCGTCTGTAGATATTTTATGGGATGCTTTGAGAATGGGCTTCGTTTTAATGGGCCTCCTCTGCAATCATTGTCCAGAGTCGAGAAACCACCTCTTCTTCTCTTGTTCTTTCTCCAAATCGATTTGGTCCCAACTCTCTTCAAGCAAAGGAGAGATATGAAAATGAAAGCTCTTACGGCGAACAAGTTTTTCCGATTGAAGAAGAGAAGAATCTAGAAGATGAAGACAACACTAGTGCACCAAACAGTTTTGCGCGTCTTGAGAGGAAACAAAAAACTATTCAGAGTTCAGAGAGAGTCAACCCCCAAACGAGACTTAAACGATGAGCCCACTATAATTTTATAATTTATGGGCCATCAGGCCCAAATGATCAGTAGTAGTTATTATTTGACTTTTGACATGGTGGATTTGGTTTAACCACCAAACCGAACGAGTAAAACACTATTGGATTGGGTGATGATATCCCGGTTTTATTTGGTTAAAATCACAAAATCCTGATTTTGGTTCGCGGCTTGATTCTGCCGCTCTCTCGTCTTTAACCTAACTAAAGACGTAGAATGATTCTGGTTATTGAATTAGTTTGATACA', 'label': 'positive', 'scores': {'negative': 0.06456661224365234, 'positive': 0.9354333281517029}}}
In [8]:
Copied!
# Batch inference from a CSV file.
# seq_col/label_col name the input columns; with evaluate=True the ground-truth
# labels are also used to compute evaluation metrics, returned alongside the
# per-sequence results.
infer_file = "./test.csv"
results, metrics = inference_engine.infer_file(
infer_file, seq_col="sequence", label_col="label", evaluate=True
)
# Batch inference from a CSV file.
# seq_col/label_col name the input columns; with evaluate=True the ground-truth
# labels are also used to compute evaluation metrics, returned alongside the
# per-sequence results.
infer_file = "./test.csv"
results, metrics = inference_engine.infer_file(
infer_file, seq_col="sequence", label_col="label", evaluate=True
)
Encoding inputs: 0%| | 0/250 [00:00<?, ? examples/s]
Inferring: 100%|██████████| 16/16 [00:08<00:00, 1.86it/s]
In [9]:
Copied!
# Peek at the first prediction only (break after one item), then print the
# aggregate evaluation metrics (accuracy, precision, recall, F1, MCC, AUROC,
# AUPRC, TPR/TNR/FPR/FNR — see the output below).
for i, res in results.items():
print(res['label'], res['scores'], sep="\n")
break
print(metrics)
# Peek at the first prediction only (break after one item), then print the
# aggregate evaluation metrics (accuracy, precision, recall, F1, MCC, AUROC,
# AUPRC, TPR/TNR/FPR/FNR — see the output below).
for i, res in results.items():
print(res['label'], res['scores'], sep="\n")
break
print(metrics)
negative
{'negative': 0.9995535016059875, 'positive': 0.000446514313807711}
{'accuracy': 0.816, 'precision': 0.8888888888888888, 'recall': 0.26666666666666666, 'f1': 0.41025641025641024, 'mcc': 0.423204406266162, 'AUROC': 0.8621929824561403, 'AUPRC': 0.7100068048890062, 'TPR': 0.26666666666666666, 'TNR': 0.9894736842105263, 'FPR': 0.010526315789473684, 'FNR': 0.7333333333333333}
In [ ]:
Copied!