Source code for model_registry._client

"""Standard client for the model registry."""

from __future__ import annotations

import os
from collections.abc import Mapping
from pathlib import Path
from typing import Any, TypeVar, Union, get_args
from warnings import warn

from .core import ModelRegistryAPIClient
from .exceptions import StoreError
from .types import (
    ListOptions,
    ModelArtifact,
    ModelVersion,
    Pager,
    RegisteredModel,
    SupportedTypes,
)

# Union of the registry entity types that `ModelRegistry.update` accepts;
# `update` validates its argument against `get_args(ModelTypes)`.
ModelTypes = Union[RegisteredModel, ModelVersion, ModelArtifact]
# Type variable bounded by `ModelTypes`, used so that `update` is annotated as
# returning the same entity type it was given.
TModel = TypeVar("TModel", bound=ModelTypes)


class ModelRegistry:
    """Model registry client.

    Synchronous wrapper: the public methods run the coroutines of the
    underlying `ModelRegistryAPIClient` to completion via `async_runner`.
    """

    def __init__(
        self,
        server_address: str,
        port: int = 443,
        *,
        author: str,
        is_secure: bool = True,
        user_token: str | None = None,
        custom_ca: str | None = None,
    ):
        """Constructor.

        Args:
            server_address: Server address.
            port: Server port. Defaults to 443.

        Keyword Args:
            author: Name of the author.
            is_secure: Whether to use a secure connection. Defaults to True.
            user_token: The PEM-encoded user token as a string. Defaults to content of path on envvar
                KF_PIPELINES_SA_TOKEN_PATH.
            custom_ca: Path to the PEM-encoded root certificates as a string. Defaults to path on envvar CERT.

        Raises:
            StoreError: If a secure connection is requested but no user token is available, or if
                `custom_ca` is given while `is_secure` is False.
        """
        import nest_asyncio

        nest_asyncio.apply()

        # TODO: get remaining args from env
        self._author = author

        if not user_token:
            # /var/run/secrets/kubernetes.io/serviceaccount/token
            sa_token = os.environ.get("KF_PIPELINES_SA_TOKEN_PATH")
            if sa_token:
                user_token = Path(sa_token).read_text()
            else:
                warn("User access token is missing", stacklevel=2)

        if is_secure:
            # An explicit CA wins over the CERT envvar; with neither set, None is
            # passed through, as the client might have a default CA setup.
            root_ca = custom_ca or os.getenv("CERT") or None

            if not user_token:
                msg = "user token must be provided for secure connection"
                raise StoreError(msg)

            self._api = ModelRegistryAPIClient.secure_connection(
                server_address, port, user_token=user_token, custom_ca=root_ca
            )
        elif custom_ca:
            msg = "Custom CA provided without secure connection, conflicting options"
            raise StoreError(msg)
        else:
            self._api = ModelRegistryAPIClient.insecure_connection(
                server_address, port, user_token
            )

    def async_runner(self, coro: Any) -> Any:
        """Run a coroutine to completion and return its result.

        Uses the current event loop; when `asyncio.get_event_loop()` raises
        `RuntimeError`, a new loop is created and installed as the current one.

        Args:
            coro: Awaitable to run.

        Returns:
            Whatever the awaitable returns.
        """
        import asyncio

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(coro)

    async def _register_model(self, name: str, **kwargs) -> RegisteredModel:
        """Return the registered model called `name`, creating it if it does not exist.

        `kwargs` are only used when a new `RegisteredModel` has to be created.
        """
        if rm := await self._api.get_registered_model_by_params(name):
            return rm

        return await self._api.upsert_registered_model(
            RegisteredModel(name=name, **kwargs)
        )

    async def _register_new_version(
        self, rm: RegisteredModel, version: str, author: str, /, **kwargs
    ) -> ModelVersion:
        """Create a new version of `rm`.

        Raises:
            StoreError: If `rm` already has a version named `version`.
        """
        assert rm.id is not None, "Registered model must have an ID"
        if await self._api.get_model_version_by_params(rm.id, version):
            msg = f"Version {version} already exists"
            raise StoreError(msg)

        return await self._api.upsert_model_version(
            ModelVersion(name=version, author=author, **kwargs), rm.id
        )

    async def _register_model_artifact(
        self, mv: ModelVersion, name: str, uri: str, /, **kwargs
    ) -> ModelArtifact:
        """Upsert a model artifact named `name` at `uri` for the model version `mv`."""
        assert mv.id is not None, "Model version must have an ID"
        return await self._api.upsert_model_version_artifact(
            ModelArtifact(name=name, uri=uri, **kwargs), mv.id
        )

    def register_model(
        self,
        name: str,
        uri: str,
        *,
        model_format_name: str,
        model_format_version: str,
        version: str,
        storage_key: str | None = None,
        storage_path: str | None = None,
        service_account_name: str | None = None,
        author: str | None = None,
        owner: str | None = None,
        description: str | None = None,
        metadata: Mapping[str, SupportedTypes] | None = None,
    ) -> RegisteredModel:
        """Register a model.

        This registers a model in the model registry. The model is not downloaded, and has to be stored prior to
        registration.

        Most models can be registered using their URI, along with optional connection-specific parameters, `storage_key`
        and `storage_path` or, simply a `service_account_name`.
        URI builder utilities are recommended when referring to specialized storage; for example `utils.s3_uri_from`
        helper when using S3 object storage data connections.

        Args:
            name: Name of the model.
            uri: URI of the model.

        Keyword Args:
            version: Version of the model. Has to be unique.
            model_format_name: Name of the model format.
            model_format_version: Version of the model format.
            description: Description of the model.
            author: Author of the model. Defaults to the client author.
            owner: Owner of the model. Defaults to the client author.
            storage_key: Storage key.
            storage_path: Storage path.
            service_account_name: Service account name.
            metadata: Additional version metadata. Defaults to no additional metadata.

        Returns:
            Registered model.

        Raises:
            StoreError: If the model already has a version named `version`.
        """
        rm = self.async_runner(self._register_model(name, owner=owner or self._author))
        mv = self.async_runner(
            self._register_new_version(
                rm,
                version,
                author or self._author,
                description=description,
                custom_properties=metadata or {},
            )
        )
        self.async_runner(
            self._register_model_artifact(
                mv,
                name,
                uri,
                model_format_name=model_format_name,
                model_format_version=model_format_version,
                storage_key=storage_key,
                storage_path=storage_path,
                service_account_name=service_account_name,
            )
        )

        return rm

    def update(self, model: TModel) -> TModel:
        """Update a model.

        Args:
            model: Registered model, model version or model artifact to update. Must have an ID.

        Returns:
            The result of the matching upsert call on the API client.

        Raises:
            StoreError: If `model` is not one of the supported types, or if it has no ID.
        """
        # Check the type before touching `model.id`: an unsupported object may not
        # have an `id` attribute at all and would otherwise raise AttributeError.
        if not isinstance(model, get_args(ModelTypes)):
            msg = f"Model must be one of {get_args(ModelTypes)}"
            raise StoreError(msg)
        if not model.id:
            msg = "Model must have an ID"
            raise StoreError(msg)
        if isinstance(model, RegisteredModel):
            return self.async_runner(self._api.upsert_registered_model(model))
        if isinstance(model, ModelVersion):
            return self.async_runner(self._api.upsert_model_version(model, None))
        return self.async_runner(self._api.upsert_model_artifact(model))

    def register_hf_model(
        self,
        repo: str,
        path: str,
        *,
        version: str,
        model_format_name: str,
        model_format_version: str,
        author: str | None = None,
        owner: str | None = None,
        model_name: str | None = None,
        description: str | None = None,
        git_ref: str = "main",
    ) -> RegisteredModel:
        """Register a Hugging Face model.

        This imports a model from Hugging Face hub and registers it in the model registry.
        Note that the model is not downloaded.

        Args:
            repo: Name of the repository from Hugging Face hub.
            path: Path of the model file inside the repository. It is used to build the source URI and is
                recorded as the storage path.

        Keyword Args:
            version: Version of the model. Has to be unique.
            model_format_name: Name of the model format.
            model_format_version: Version of the model format.
            author: Author of the model. Defaults to the author reported by Hugging Face hub, or `unknown`
                when the hub reports none.
            owner: Owner of the model. Defaults to the client author.
            model_name: Name of the model. Defaults to the model ID reported by Hugging Face hub.
            description: Description of the model.
            git_ref: Git reference to use. Defaults to `main`.

        Returns:
            Registered model.

        Raises:
            StoreError: If `huggingface-hub` is not installed, the repository or the revision does not
                exist, or the model already has a version named `version`.
        """
        try:
            from huggingface_hub import HfApi, hf_hub_url, utils
        except ImportError as e:
            msg = """package `huggingface-hub` is not installed.
            To import models from Hugging Face Hub, start by installing the `huggingface-hub` package, either directly or as an
            extra (available as `model-registry[hf]`), e.g.:
            ```sh
            !pip install --pre model-registry[hf]
            ```
            or
            ```sh
            !pip install huggingface-hub
            ```
            """
            raise StoreError(msg) from e

        api = HfApi()
        try:
            model_info = api.model_info(repo, revision=git_ref)
        except utils.RepositoryNotFoundError as e:
            msg = f"Repository {repo} does not exist"
            raise StoreError(msg) from e
        except utils.RevisionNotFoundError as e:
            # TODO: as all hf-hub client calls default to using main, should we provide a tip?
            msg = f"Revision {git_ref} does not exist"
            raise StoreError(msg) from e

        if author:
            model_author = author
        elif model_info.author is None:
            # model author can be None if the repo is in a "global" namespace (i.e. no / in repo).
            model_author = "unknown"
            warn(
                "Model author is unknown. This is likely because the model is in a global namespace.",
                stacklevel=2,
            )
        else:
            model_author = model_info.author

        source_uri = hf_hub_url(repo, path, revision=git_ref)
        metadata = {
            "repo": repo,
            "source_uri": source_uri,
            "model_origin": "huggingface_hub",
            "model_author": model_author,
        }
        # card_data is the new field, but let's use the old one for backwards compatibility.
        if card_data := model_info.cardData:
            metadata.update(
                {
                    k: v
                    for k, v in card_data.to_dict().items()
                    # TODO: (#151) preserve tags, possibly other complex metadata
                    if isinstance(v, get_args(SupportedTypes))
                }
            )

        return self.register_model(
            model_name or model_info.id,
            source_uri,
            # `model_author` already equals `author` whenever one was given.
            author=model_author,
            owner=owner or self._author,
            version=version,
            model_format_name=model_format_name,
            model_format_version=model_format_version,
            description=description,
            storage_path=path,
            metadata=metadata,
        )

    def get_registered_model(self, name: str) -> RegisteredModel | None:
        """Get a registered model.

        Args:
            name: Name of the model.

        Returns:
            Registered model.
        """
        return self.async_runner(self._api.get_registered_model_by_params(name))

    def get_model_version(self, name: str, version: str) -> ModelVersion | None:
        """Get a model version.

        Args:
            name: Name of the model.
            version: Version of the model.

        Returns:
            Model version.

        Raises:
            StoreError: If the model does not exist.
        """
        if not (rm := self.get_registered_model(name)):
            msg = f"Model {name} does not exist"
            raise StoreError(msg)
        assert rm.id
        return self.async_runner(self._api.get_model_version_by_params(rm.id, version))

    def get_model_artifact(self, name: str, version: str) -> ModelArtifact | None:
        """Get a model artifact.

        Args:
            name: Name of the model.
            version: Version of the model.

        Returns:
            Model artifact.

        Raises:
            StoreError: If either the model or the version don't exist.
        """
        if not (mv := self.get_model_version(name, version)):
            msg = f"Version {version} does not exist"
            raise StoreError(msg)
        assert mv.id
        return self.async_runner(self._api.get_model_artifact_by_params(name, mv.id))

    def get_registered_models(self) -> Pager[RegisteredModel]:
        """Get a pager for registered models.

        Returns:
            Iterable pager for registered models.
        """

        def rm_list(options: ListOptions) -> list[RegisteredModel]:
            return self.async_runner(self._api.get_registered_models(options))

        return Pager[RegisteredModel](rm_list)

    def get_model_versions(self, name: str) -> Pager[ModelVersion]:
        """Get a pager for model versions.

        Args:
            name: Name of the model.

        Returns:
            Iterable pager for model versions.

        Raises:
            StoreError: If the model does not exist.
        """
        if not (rm := self.get_registered_model(name)):
            msg = f"Model {name} does not exist"
            raise StoreError(msg)

        def rm_versions(options: ListOptions) -> list[ModelVersion]:
            # type checkers can't restrict the type inside a nested function:
            # https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
            assert rm.id
            return self.async_runner(self._api.get_model_versions(rm.id, options))

        return Pager[ModelVersion](rm_versions)