Inference Demo
In [ ]:
Copied!
import sys
# from os import path
# sys.path.append(path.abspath(path.join(path.dirname(__file__),
# '../../..')))
import marimo as mo
import pandas as pd
from dnallm import load_config, load_model_and_tokenizer, DNAInference
import sys
# from os import path
# sys.path.append(path.abspath(path.join(path.dirname(__file__),
# '../../..')))
import marimo as mo
import pandas as pd
from dnallm import load_config, load_model_and_tokenizer, DNAInference
In [ ]:
Copied!
# Load the catalogue of fine-tuned models; later cells read its
# Task, Model, Tokenzier and Name columns.
model_df = pd.read_excel(io="./plant_DNA_LLMs_finetune_list.xlsx")
# Load the catalogue of fine-tuned models; later cells read its
# Task, Model, Tokenzier and Name columns.
model_df = pd.read_excel(io="./plant_DNA_LLMs_finetune_list.xlsx")
In [ ]:
Copied!
# Distinct prediction tasks present in the model catalogue.
tasks = model_df["Task"].unique()
print("Available tasks:")
print(tasks)
# Distinct prediction tasks present in the model catalogue.
tasks = model_df["Task"].unique()
print("Available tasks:")
print(tasks)
In [ ]:
Copied!
# Selector for the prediction task; 'open chromatin' is preselected.
_default_task = 'open chromatin'
task_dropdown = mo.ui.dropdown(tasks, label='Predict Task', value=_default_task)
# Selector for the prediction task; 'open chromatin' is preselected.
_default_task = 'open chromatin'
task_dropdown = mo.ui.dropdown(tasks, label='Predict Task', value=_default_task)
In [ ]:
Copied!
# Distinct model families present in the model catalogue.
models = model_df["Model"].unique()
print("Available models:")
print(models)
# Distinct model families present in the model catalogue.
models = model_df["Model"].unique()
print("Available models:")
print(models)
In [ ]:
Copied!
# Selector for the model family; 'Plant DNABERT' is preselected.
_default_model = 'Plant DNABERT'
model_dropdown = mo.ui.dropdown(models, label='Model', value=_default_model)
# Selector for the model family; 'Plant DNABERT' is preselected.
_default_model = 'Plant DNABERT'
model_dropdown = mo.ui.dropdown(models, label='Model', value=_default_model)
In [ ]:
Copied!
# Distinct tokenizers present in the model catalogue.
# NOTE: the column is spelled "Tokenzier" here, presumably matching the
# spreadsheet header -- verify before renaming.
tokenizers = model_df.Tokenzier.unique()
# Header previously read "Available models:" although this lists tokenizers.
print("Available tokenizers:", tokenizers, sep="\n")
# Distinct tokenizers present in the model catalogue.
# NOTE: the column is spelled "Tokenzier" here, presumably matching the
# spreadsheet header -- verify before renaming.
tokenizers = model_df.Tokenzier.unique()
# Header previously read "Available models:" although this lists tokenizers.
print("Available tokenizers:", tokenizers, sep="\n")
In [ ]:
Copied!
# Selector for the tokenizer; 'BPE' is preselected.
_default_tokenizer = 'BPE'
tokenizer_dropdown = mo.ui.dropdown(tokenizers, label='Tokenizer', value=_default_tokenizer)
# Selector for the tokenizer; 'BPE' is preselected.
_default_tokenizer = 'BPE'
tokenizer_dropdown = mo.ui.dropdown(tokenizers, label='Tokenizer', value=_default_tokenizer)
In [ ]:
Copied!
# Selector for where checkpoints are loaded from; each option maps to its own name.
source_dropdown = mo.ui.dropdown(
    {_hub: _hub for _hub in ('modelscope', 'huggingface')},
    label='Model Source',
    value='modelscope',
)
# Selector for where checkpoints are loaded from; each option maps to its own name.
source_dropdown = mo.ui.dropdown(
    {_hub: _hub for _hub in ('modelscope', 'huggingface')},
    label='Model Source',
    value='modelscope',
)
In [ ]:
Copied!
# Example sequence shown in the text box (and used when the box is left empty).
# Adjacent string literals are concatenated into one unbroken sequence.
placeholder = (
    'GGGCAGCGGTTACACCTTAATCGACACGACTCTCGGCAACGGATATCTCG'
    'GCTCTTGCATCGATGAAGAACGTAGCAAAATGCGATACCTGGTGTGAATTGCAGAAT'
    'CCCGCGAACCATCGAGTTTTTGAACGCAAGTTGCGCCCGAAGCCTTCTGACGGA'
    'GGGCACGTCTGCCTGGGCGTCACGCCAAAAGACACTCCCAACACCCCCCCGCGGGGC'
    'GAGGGACGTGGCGTCTGGCCCCCCGCGCTGCAGGGCGAGGTGGGCCGAAGCAGGGGCTGCC'
    'GGCGAACCGCGTCGGACGCAACACGTGGTGGGCGACATCAAGTTGTTCTCGGTGCAGCGT'
    'CCCGGCGCGCGGCCGGCCATTCGGCCCTAAGGACCCATCGAGCGACCGAGCTTGCCCTCG'
    'GACCACGACCCCAGGTCAGTCGGGACTACCCGCTGAGTTTAAGCATATAAATAAGCGGAGGAG'
    'AAGAAACTTACGAGGATTCCCCTAGTAACGGCGAGCGAACCGGGAGCAGCCCAGCTTGA'
    'GAATCGGGCGGCCTCGCCGCCCGAATTGTAGTCTGGAGAGGCGT'
)
dnaseq_entry_box = mo.ui.text_area(
    label='DNA Sequence:',
    placeholder=placeholder,
    rows=5,
    full_width=True,
)
# Example sequence shown in the text box (and used when the box is left empty).
# Adjacent string literals are concatenated into one unbroken sequence.
placeholder = (
    'GGGCAGCGGTTACACCTTAATCGACACGACTCTCGGCAACGGATATCTCG'
    'GCTCTTGCATCGATGAAGAACGTAGCAAAATGCGATACCTGGTGTGAATTGCAGAAT'
    'CCCGCGAACCATCGAGTTTTTGAACGCAAGTTGCGCCCGAAGCCTTCTGACGGA'
    'GGGCACGTCTGCCTGGGCGTCACGCCAAAAGACACTCCCAACACCCCCCCGCGGGGC'
    'GAGGGACGTGGCGTCTGGCCCCCCGCGCTGCAGGGCGAGGTGGGCCGAAGCAGGGGCTGCC'
    'GGCGAACCGCGTCGGACGCAACACGTGGTGGGCGACATCAAGTTGTTCTCGGTGCAGCGT'
    'CCCGGCGCGCGGCCGGCCATTCGGCCCTAAGGACCCATCGAGCGACCGAGCTTGCCCTCG'
    'GACCACGACCCCAGGTCAGTCGGGACTACCCGCTGAGTTTAAGCATATAAATAAGCGGAGGAG'
    'AAGAAACTTACGAGGATTCCCCTAGTAACGGCGAGCGAACCGGGAGCAGCCCAGCTTGA'
    'GAATCGGGCGGCCTCGCCGCCCGAATTGTAGTCTGGAGAGGCGT'
)
dnaseq_entry_box = mo.ui.text_area(
    label='DNA Sequence:',
    placeholder=placeholder,
    rows=5,
    full_width=True,
)
In [ ]:
Copied!
# Input panel: a heading, the sequence box, then one row of selectors.
title = mo.md("<center><h2>Model inference</h2></center>")
_selectors = [task_dropdown, model_dropdown, tokenizer_dropdown, source_dropdown]
hstack = mo.hstack(_selectors, justify='center', align='center')
mo.vstack([title, dnaseq_entry_box, hstack])
# Input panel: a heading, the sequence box, then one row of selectors.
title = mo.md("<center><h2>Model inference</h2></center>")
_selectors = [task_dropdown, model_dropdown, tokenizer_dropdown, source_dropdown]
hstack = mo.hstack(_selectors, justify='center', align='center')
mo.vstack([title, dnaseq_entry_box, hstack])
In [ ]:
Copied!
# Resolve the checkpoint name for the selected (task, model, tokenizer) triple.
# Only an empty match is treated as "not found"; other errors (e.g. a missing
# column) now surface instead of being swallowed by a bare `except:`.
_matches = model_df[
    (model_df.Task == task_dropdown.value)
    & (model_df.Model == model_dropdown.value)
    & (model_df.Tokenzier == tokenizer_dropdown.value)
].Name.tolist()
if _matches:
    # The first matching row wins when several rows share the same triple.
    model_name = _matches[0]
    print("Current model:", model_name, sep="\n")
    callout = ""
else:
    callout = mo.callout("Cannot find the model", kind="warn")
    model_name = None
mo.vstack([callout], align="stretch")
# Resolve the checkpoint name for the selected (task, model, tokenizer) triple.
# Only an empty match is treated as "not found"; other errors (e.g. a missing
# column) now surface instead of being swallowed by a bare `except:`.
_matches = model_df[
    (model_df.Task == task_dropdown.value)
    & (model_df.Model == model_dropdown.value)
    & (model_df.Tokenzier == tokenizer_dropdown.value)
].Name.tolist()
if _matches:
    # The first matching row wins when several rows share the same triple.
    model_name = _matches[0]
    print("Current model:", model_name, sep="\n")
    callout = ""
else:
    callout = mo.callout("Cannot find the model", kind="warn")
    model_name = None
mo.vstack([callout], align="stretch")
In [ ]:
Copied!
# Sequence to run inference on: the user's input, else the example placeholder.
if not dnaseq_entry_box.value:
    dnaseq = placeholder
    print("No sequence found, use default sequence.")
else:
    dnaseq = dnaseq_entry_box.value
# Sequence to run inference on: the user's input, else the example placeholder.
if not dnaseq_entry_box.value:
    dnaseq = placeholder
    print("No sequence found, use default sequence.")
else:
    dnaseq = dnaseq_entry_box.value
In [ ]:
Copied!
# Inference settings; the 'task' section is reconfigured from the selected
# task before the model is loaded.
_config_path = "./inference_config.yaml"
configs = load_config(_config_path)
# Inference settings; the 'task' section is reconfigured from the selected
# task before the model is loaded.
_config_path = "./inference_config.yaml"
configs = load_config(_config_path)
In [ ]:
Copied!
# Set task type, number of labels and label names to match the selected task.
task = task_dropdown.value
if task in ['core promoter', 'sequence conservation', 'enhancer',
            'H3K27ac', 'H3K27me3', 'H3K4me3', 'lncRNAs']:
    # Binary tasks are labelled by the last word of the task name,
    # e.g. 'core promoter' -> ['Not promoter', 'Promoter'].
    data = task.split()[-1]
    configs['task'].task_type = 'binary'
    configs['task'].num_labels = 2
    # Upper-case only the first character: str.capitalize() would lower-case
    # the rest, turning 'H3K27ac' into 'H3k27ac' and 'lncRNAs' into 'Lncrnas'.
    configs['task'].label_names = ['Not ' + data, data[:1].upper() + data[1:]]
elif task == 'open chromatin':
    configs['task'].task_type = 'multiclass'
    configs['task'].num_labels = 3
    configs['task'].label_names = ['Not ' + task, 'Partial ' + task, 'Full ' + task]
elif task in ['promoter strength leaf', 'promoter strength protoplast']:
    configs['task'].task_type = 'regression'
    configs['task'].num_labels = 1
    configs['task'].label_names = [task]
else:
    # Unrecognised task: the task settings from the loaded config are kept
    # unchanged, so say so rather than continuing silently.
    print(f"Unknown task {task!r}; using task settings from the config file.")
# Set task type, number of labels and label names to match the selected task.
task = task_dropdown.value
if task in ['core promoter', 'sequence conservation', 'enhancer',
            'H3K27ac', 'H3K27me3', 'H3K4me3', 'lncRNAs']:
    # Binary tasks are labelled by the last word of the task name,
    # e.g. 'core promoter' -> ['Not promoter', 'Promoter'].
    data = task.split()[-1]
    configs['task'].task_type = 'binary'
    configs['task'].num_labels = 2
    # Upper-case only the first character: str.capitalize() would lower-case
    # the rest, turning 'H3K27ac' into 'H3k27ac' and 'lncRNAs' into 'Lncrnas'.
    configs['task'].label_names = ['Not ' + data, data[:1].upper() + data[1:]]
elif task == 'open chromatin':
    configs['task'].task_type = 'multiclass'
    configs['task'].num_labels = 3
    configs['task'].label_names = ['Not ' + task, 'Partial ' + task, 'Full ' + task]
elif task in ['promoter strength leaf', 'promoter strength protoplast']:
    configs['task'].task_type = 'regression'
    configs['task'].num_labels = 1
    configs['task'].label_names = [task]
else:
    # Unrecognised task: the task settings from the loaded config are kept
    # unchanged, so say so rather than continuing silently.
    print(f"Unknown task {task!r}; using task settings from the config file.")
In [ ]:
Copied!
if not model_name:
    # No checkpoint matched the selection: show an inert button, no engine.
    predict_button = mo.ui.button(label="Predict")
    inference_engine = None
else:
    # Load model and tokenizer for `model_name` from the selected source,
    # configured with the task settings prepared earlier.
    model, tokenizer = load_model_and_tokenizer(
        model_name,
        task_config=configs['task'],
        source=source_dropdown.value,
    )
    inference_engine = DNAInference(config=configs, model=model, tokenizer=tokenizer)

    def _run_inference(value):
        # Whatever this returns is what later cells read as predict_button.value.
        return inference_engine.infer_seqs(dnaseq, output_attentions=True)

    predict_button = mo.ui.button(label="Predict", on_click=_run_inference)
mo.hstack([predict_button], justify='center', align='center')
if not model_name:
    # No checkpoint matched the selection: show an inert button, no engine.
    predict_button = mo.ui.button(label="Predict")
    inference_engine = None
else:
    # Load model and tokenizer for `model_name` from the selected source,
    # configured with the task settings prepared earlier.
    model, tokenizer = load_model_and_tokenizer(
        model_name,
        task_config=configs['task'],
        source=source_dropdown.value,
    )
    inference_engine = DNAInference(config=configs, model=model, tokenizer=tokenizer)

    def _run_inference(value):
        # Whatever this returns is what later cells read as predict_button.value.
        return inference_engine.infer_seqs(dnaseq, output_attentions=True)

    predict_button = mo.ui.button(label="Predict", on_click=_run_inference)
mo.hstack([predict_button], justify='center', align='center')
In [ ]:
Copied!
# Latest prediction output; None while the button holds no truthy value.
results = predict_button.value or None
results
# Latest prediction output; None while the button holds no truthy value.
results = predict_button.value or None
results
In [ ]:
Copied!
# Controls for choosing which attention map to draw. Ranges come from the last
# inference run when results exist, otherwise from fixed defaults.
if not results:
    seqs, layers, heads = 1, 12, 12
else:
    seqs = len(inference_engine.sequences)
    _attentions = inference_engine.embeddings['attentions']
    layers = len(_attentions)
    # Axis 1 of the first layer's tensor is read as the head count
    # (presumably (batch, heads, seq, seq) -- TODO confirm).
    heads = _attentions[0].shape[1]
# All indices presented to the user are 1-based.
seq_number = mo.ui.number(start=1, stop=max(seqs, 1), label="Sequence index")
layer_slider = mo.ui.slider(start=1, stop=layers, step=1, label='Layer index',
                            show_value=True)
head_slider = mo.ui.slider(start=1, stop=heads, step=1, label='Head index',
                           show_value=True)
figure_size = mo.ui.number(start=200, stop=5120, step=10, value=800,
                           label='Figure size')
# Controls for choosing which attention map to draw. Ranges come from the last
# inference run when results exist, otherwise from fixed defaults.
if not results:
    seqs, layers, heads = 1, 12, 12
else:
    seqs = len(inference_engine.sequences)
    _attentions = inference_engine.embeddings['attentions']
    layers = len(_attentions)
    # Axis 1 of the first layer's tensor is read as the head count
    # (presumably (batch, heads, seq, seq) -- TODO confirm).
    heads = _attentions[0].shape[1]
# All indices presented to the user are 1-based.
seq_number = mo.ui.number(start=1, stop=max(seqs, 1), label="Sequence index")
layer_slider = mo.ui.slider(start=1, stop=layers, step=1, label='Layer index',
                            show_value=True)
head_slider = mo.ui.slider(start=1, stop=heads, step=1, label='Head index',
                           show_value=True)
figure_size = mo.ui.number(start=200, stop=5120, step=10, value=800,
                           label='Figure size')
In [ ]:
Copied!
def _plot_attention(value):
    """Plot the selected attention map, or return None when no engine exists.

    `inference_engine` is None when no checkpoint matched the selection;
    previously a click then raised AttributeError. The 1-based UI indices are
    converted to 0-based before being passed to plot_attentions.
    """
    if inference_engine is None:
        return None
    return inference_engine.plot_attentions(
        seq_number.value - 1, layer_slider.value - 1, head_slider.value - 1
    )


plot_button = mo.ui.button(label="Plot attention map", on_click=_plot_attention)
plot_options = mo.hstack(
    [seq_number,
     layer_slider,
     head_slider,
     figure_size],
    align='center',
    justify='center'
)
mo.vstack([plot_options, plot_button], align='center', justify='center')
def _plot_attention(value):
    """Plot the selected attention map, or return None when no engine exists.

    `inference_engine` is None when no checkpoint matched the selection;
    previously a click then raised AttributeError. The 1-based UI indices are
    converted to 0-based before being passed to plot_attentions.
    """
    if inference_engine is None:
        return None
    return inference_engine.plot_attentions(
        seq_number.value - 1, layer_slider.value - 1, head_slider.value - 1
    )


plot_button = mo.ui.button(label="Plot attention map", on_click=_plot_attention)
plot_options = mo.hstack(
    [seq_number,
     layer_slider,
     head_slider,
     figure_size],
    align='center',
    justify='center'
)
mo.vstack([plot_options, plot_button], align='center', justify='center')
In [ ]:
Copied!
# Render the attention map produced by the plot button, if any.
plot_out = plot_button.value
if plot_out:
    # Size the Altair chart itself before wrapping it: `.properties` is an
    # Altair Chart method, not one of the marimo UI element returned by
    # mo.ui.altair_chart. Assumes plot_attentions returns an Altair chart
    # (it is handed to mo.ui.altair_chart) -- TODO confirm.
    chart = mo.ui.altair_chart(
        plot_out.properties(width=figure_size.value, height=figure_size.value)
    )
else:
    chart = None
mo.vstack([chart], align='center', justify='center')
# Render the attention map produced by the plot button, if any.
plot_out = plot_button.value
if plot_out:
    # Size the Altair chart itself before wrapping it: `.properties` is an
    # Altair Chart method, not one of the marimo UI element returned by
    # mo.ui.altair_chart. Assumes plot_attentions returns an Altair chart
    # (it is handed to mo.ui.altair_chart) -- TODO confirm.
    chart = mo.ui.altair_chart(
        plot_out.properties(width=figure_size.value, height=figure_size.value)
    )
else:
    chart = None
mo.vstack([chart], align='center', justify='center')