Source code for mlui.tools.model
import mlui.types.classes as t
from mlui.classes import errors
[docs]
def validate_shapes(shapes: t.Shapes) -> None:
    """
    Validate the shapes of a model.

    Parameters
    ----------
    shapes : dict of tuples, list of tuples or tuple
        Shapes to be validated. Single shape is a `tuple` of `(None, int)`.

    Raises
    ------
    ValidateModelError
        If the shapes are empty. If any individual shape is empty. If any individual
        shape contains more than 2 dimensions.
    """
    if not shapes:
        raise errors.ValidateModelError("The model's shapes are empty!")

    # Normalise the accepted forms into one iterable of individual shapes,
    # without rebinding the caller-facing parameter.
    if isinstance(shapes, dict):
        shape_list = list(shapes.values())
    elif isinstance(shapes, tuple):
        # A bare tuple is a single shape, not a collection of shapes.
        shape_list = [shapes]
    else:
        shape_list = shapes

    for shape in shape_list:
        # Test the individual shape, not the container: the container is
        # already known to be non-empty here. Checking emptiness first also
        # keeps `len()` from being called on a `None` shape.
        if not shape:
            raise errors.ValidateModelError(
                "At least one of the model's shapes is empty!"
            )
        if len(shape) > 2:
            raise errors.ValidateModelError(
                "At least one of the model's shapes contains more than 2 dimensions!"
            )