fix theme validation

This commit is contained in:
Sina Atalay 2024-02-15 18:09:48 +01:00
parent 238278764a
commit 0b2264792d
1 changed files with 15 additions and 10 deletions

View File

@ -936,6 +936,8 @@ RenderCVDesign = Annotated[
ClassicThemeOptions | ModerncvThemeOptions | McdowellThemeOptions, ClassicThemeOptions | ModerncvThemeOptions | McdowellThemeOptions,
pydantic.Field(discriminator="theme"), pydantic.Field(discriminator="theme"),
] ]
rendercv_design_validator = pydantic.TypeAdapter(RenderCVDesign)
available_themes = ["classic", "moderncv", "mcdowell"]
class RenderCVDataModel(RenderCVBaseModel): class RenderCVDataModel(RenderCVBaseModel):
@ -951,20 +953,16 @@ class RenderCVDataModel(RenderCVBaseModel):
description="The design information of the CV.", description="The design information of the CV.",
) )
@pydantic.field_validator("design") @pydantic.field_validator("design", mode="before")
@classmethod @classmethod
def initialize_if_custom_theme_is_used( def initialize_if_custom_theme_is_used(
cls, design: RenderCVDesign | Any cls, design: RenderCVDesign | Any
) -> RenderCVDesign | Any: ) -> RenderCVDesign | Any:
"""Initialize the custom theme if it is used and validate it. Otherwise, return """Initialize the custom theme if it is used and validate it. Otherwise, return
the built-in theme.""" the built-in theme."""
# `get_args` for an Annotated object returns the arguments when Annotated is if design["theme"] in available_themes: # type: ignore
# used. The first argument is actually the union of the types, so we need to # it is a built-in theme, validate and return it:
# access the first argument to use isinstance function. return rendercv_design_validator.validate_python(design)
theme_data_model_types = get_args(RenderCVDesign)[0]
if isinstance(design, theme_data_model_types):
# it is a built-in theme
return design
else: else:
theme_name: str = design["theme"] # type: ignore theme_name: str = design["theme"] # type: ignore
# check if the theme name is valid: # check if the theme name is valid:
@ -1061,14 +1059,21 @@ def read_input_file(
""" """
# check if the file exists: # check if the file exists:
if not file_path.exists(): if not file_path.exists():
raise FileNotFoundError(f"The input file {file_path} doesn't exist.") raise FileNotFoundError(
f"The input file [magenta]{file_path}[/magenta] doesn't exist!"
)
# check the file extension: # check the file extension:
accepted_extensions = [".yaml", ".yml", ".json", ".json5"] accepted_extensions = [".yaml", ".yml", ".json", ".json5"]
if file_path.suffix not in accepted_extensions: if file_path.suffix not in accepted_extensions:
user_friendly_accepted_extensions = [
f"[green]{ext}[/green]" for ext in accepted_extensions
]
user_friendly_accepted_extensions = ", ".join(user_friendly_accepted_extensions)
raise ValueError( raise ValueError(
"The input file should have one of the following extensions:" "The input file should have one of the following extensions:"
f" {accepted_extensions}. The input file is {file_path}." f" {user_friendly_accepted_extensions}. The input file is"
f" [magenta]{file_path}[/magenta]."
) )
file_content = file_path.read_text(encoding="utf-8") file_content = file_path.read_text(encoding="utf-8")