import streamlit as st
import mlui.classes.data as data
import mlui.classes.errors as errors
import mlui.classes.model as model
import mlui.enums as enums
[docs]
def set_features_ui(data: data.Data, model: model.Model) -> None:
"""Generate the UI for setting the data columns as features for the input and
output layers of the model.
Parameters
----------
data : Data
Data object.
model : Model
Model object.
"""
st.header("Set Input/Output Features")
st.markdown(
"Choose input and output features for each respective layer from the data "
"columns. Please be aware that the order in which you add the columns is "
"important for evaluating the model or making predictions, as the data needs "
"to be consistent with the data on which the model was trained. Additionally, "
"note that for multiclass classification problems, the output columns should "
"be one-hot encoded for the model to work correctly."
)
task = st.session_state.get("task")
if task != "Predict":
side = st.selectbox("Select layer's type:", ("Input", "Output"))
else:
side = st.selectbox("Select layer's type:", ("Input",))
if side == "Input":
at = "input"
layers = model.inputs
shapes = model.input_shape
else:
at = "output"
layers = model.outputs
shapes = model.output_shape
layer = str(st.selectbox("Select layer:", layers))
available = data.get_unused_columns()
default = model.get_features(layer, at)
available.extend(default)
columns = st.multiselect(
"Select columns (order is important):",
available,
default,
max_selections=shapes[layer],
)
def set_features() -> None:
"""Supporting function for the accurate representation of widgets."""
try:
model.set_features(layer, columns, at)
data.set_unused_columns(available, columns)
st.toast("Features are set!", icon="✅")
except errors.SetError as error:
st.toast(error, icon="❌")
st.button("Set Features", on_click=set_features)
[docs]
def set_callbacks_ui(model: model.Model) -> None:
"""Generate the UI for setting the callbacks for the model.
Parameters
----------
model : Model
Model object.
"""
st.header("Set Callbacks")
st.markdown(
"Optionally choose callbacks for the model to use during evaluation, training, "
"or making predictions. Some callbacks have adjustable parameters. Once you "
"add a callback, you may delete it if you no longer need it or want to "
"readjust its parameters."
)
callbacks = enums.callbacks.classes
entity = str(st.selectbox("Select callback's class:", callbacks))
with st.expander("Callback's Parameters"):
prototype = enums.callbacks.widgets[entity]
widget = prototype()
is_set = model.get_callback(entity)
label = "Set Callback" if not is_set else "Delete Callback"
def manage_callback() -> None:
"""Supporting function for the accurate representation of widgets."""
if not is_set:
try:
params = widget.params
model.set_callback(entity, params)
st.toast("Callback is set!", icon="✅")
except errors.SetError as error:
st.toast(error, icon="❌")
elif is_set:
model.delete_callback(entity)
st.toast("Callback is deleted!", icon="✅")
st.button(label, on_click=manage_callback)