LoRA Fine-tuning
In [ ]:
Copied!
from dnallm import load_config, load_model_and_tokenizer, DNADataset, DNATrainer
from dnallm import load_config, load_model_and_tokenizer, DNADataset, DNATrainer
In [ ]:
Copied!
# Load the fine-tuning config once. Later cells read its 'task' section when
# loading the model, and the trainer takes its LoRA settings from it.
configs = load_config("./finetune_config.yaml")
In [ ]:
Copied!
# Load the pretrained model and its tokenizer from the Hugging Face Hub
# (source="huggingface"). The 'task' section of the config is passed so the
# model is prepared for the fine-tuning task (presumably this configures the
# task-specific head -- confirm against load_model_and_tokenizer).
model_name = "kuleshov-group/PlantCAD2-Small-l24-d0768"
model, tokenizer = load_model_and_tokenizer(
    model_name,
    task_config=configs["task"],
    source="huggingface",
)
In [ ]:
Copied!
# Load the multi-species plant core-promoter dataset from ModelScope, reading
# sequences from the "sequence" column and labels from the "label" column.
# max_length=512 presumably caps the tokenized sequence length -- confirm.
data_name = "zhangtaolab/plant-multi-species-core-promoters"
datasets = DNADataset.from_modelscope(
    data_name,
    seq_col="sequence",
    label_col="label",
    tokenizer=tokenizer,
    max_length=512,
)
# Downsample to 5% of the data (ratio passed to `sampling`). No seed is given,
# so the subset may differ between runs.
# NOTE(review): overwrite=True presumably also replaces the data held by
# `datasets`; confirm that `sampling` still returns the dataset in this mode,
# since the next line relies on the return value.
sampled_datasets = datasets.sampling(0.05, overwrite=True)
# Encode the sampled sequences and drop columns not used for training.
sampled_datasets.encode_sequences(remove_unused_columns=True)
In [ ]:
Copied!
# Build the trainer once. With use_lora=True the trainer takes its LoRA
# settings from `configs` and fine-tunes the model with LoRA.
trainer = DNATrainer(
    model=model,
    config=configs,
    datasets=sampled_datasets,
    use_lora=True,  # load lora config from the configs
)
In [ ]:
Copied!
# Run fine-tuning a single time and print the metrics returned by train().
metrics = trainer.train()
print(metrics)
In [ ]:
Copied!