Source code for model_registry.core

"""Client for the model registry."""

from __future__ import annotations

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import TypeVar, cast

from typing_extensions import overload

from mr_openapi import (
    ApiClient,
    Configuration,
    ModelRegistryServiceApi,
)
from mr_openapi import (
    exceptions as mr_exceptions,
)

from ._utils import required_args
from .types import (
    Artifact,
    ListOptions,
    ModelArtifact,
    ModelVersion,
    RegisteredModel,
)

# Type variable bound to ``Artifact`` so that ``upsert_model_version_artifact``
# is typed to return the same concrete artifact subclass it was passed.
ArtifactT = TypeVar("ArtifactT", bound=Artifact)


@dataclass
class ModelRegistryAPIClient:
    """Async client for the model registry REST API.

    Every request method opens a fresh ``ApiClient`` through
    :meth:`get_client` and closes it once the request completes.

    Attributes:
        config: OpenAPI client configuration (host, access token, CA bundle).
    """

    config: Configuration

    @classmethod
    def secure_connection(
        cls,
        server_address: str,
        port: int = 443,
        *,
        user_token: str,
        custom_ca: str | None = None,
    ) -> ModelRegistryAPIClient:
        """Create a client for a TLS-protected server.

        Args:
            server_address: Server address.
            port: Server port. Defaults to 443.

        Keyword Args:
            user_token: User access token, set as the configuration's
                ``access_token``.
            custom_ca: Path to a PEM-encoded CA certificate file used to
                verify the server, set as the configuration's
                ``ssl_ca_cert``. Optional.

        Returns:
            A client whose host is ``f"{server_address}:{port}"``.
        """
        return cls(
            Configuration(
                f"{server_address}:{port}",
                access_token=user_token,
                ssl_ca_cert=custom_ca,
            )
        )

    @classmethod
    def insecure_connection(
        cls,
        server_address: str,
        port: int,
        user_token: str | None = None,
    ) -> ModelRegistryAPIClient:
        """Create a client without a custom CA configuration.

        Args:
            server_address: Server address.
            port: Server port.
            user_token: Optional user access token, set as the
                configuration's ``access_token``.

        Returns:
            A client whose host is ``f"{server_address}:{port}"``.
        """
        return cls(
            Configuration(host=f"{server_address}:{port}", access_token=user_token)
        )

    @asynccontextmanager
    async def get_client(self) -> AsyncIterator[ModelRegistryServiceApi]:
        """Yield a service API client backed by a fresh ``ApiClient``.

        The underlying ``ApiClient`` is closed on exit, even if the body of
        the ``async with`` raises.
        """
        api_client = ApiClient(self.config)
        client = ModelRegistryServiceApi(api_client)
        try:
            yield client
        finally:
            await api_client.close()

    async def upsert_registered_model(
        self, registered_model: RegisteredModel
    ) -> RegisteredModel:
        """Upsert a registered model.

        Updates the model when ``registered_model.id`` is set, otherwise
        creates it.

        Args:
            registered_model: Registered model.

        Returns:
            The registered model as returned by the server.
        """
        async with self.get_client() as client:
            if registered_model.id:
                rm = await client.update_registered_model(
                    registered_model.id, registered_model.update()
                )
            else:
                rm = await client.create_registered_model(registered_model.create())

        return RegisteredModel.from_basemodel(rm)

    async def get_registered_model_by_id(self, id: str) -> RegisteredModel | None:
        """Fetch a registered model by its ID.

        Args:
            id: Registered model ID.

        Returns:
            The registered model, or ``None`` if the server reports it was
            not found.
        """
        async with self.get_client() as client:
            try:
                rm = await client.get_registered_model(id)
            except mr_exceptions.NotFoundException:
                return None

        return RegisteredModel.from_basemodel(rm)

    @overload
    async def get_registered_model_by_params(
        self, name: str
    ) -> RegisteredModel | None: ...

    @overload
    async def get_registered_model_by_params(
        self, *, external_id: str
    ) -> RegisteredModel | None: ...

    @required_args(("name",), ("external_id",))
    async def get_registered_model_by_params(
        self, name: str | None = None, external_id: str | None = None
    ) -> RegisteredModel | None:
        """Fetch a registered model by its name or external ID.

        Args:
            name: Registered model name.
            external_id: Registered model external ID.

        Returns:
            The registered model, or ``None`` if the server reports it was
            not found.
        """
        async with self.get_client() as client:
            try:
                rm = await client.find_registered_model(
                    name=name, external_id=external_id
                )
            except mr_exceptions.NotFoundException:
                return None

        return RegisteredModel.from_basemodel(rm)

    async def get_registered_models(
        self, options: ListOptions | None = None
    ) -> list[RegisteredModel]:
        """Fetch registered models.

        Args:
            options: Options for listing registered models. If given, its
                ``next_page_token`` is updated in place from the response.

        Returns:
            Registered models (empty if the response has no items).
        """
        async with self.get_client() as client:
            rm_list = await client.get_registered_models(
                **(options or ListOptions()).as_options()
            )

        if options:
            options.next_page_token = rm_list.next_page_token

        return [RegisteredModel.from_basemodel(rm) for rm in rm_list.items or []]

    async def upsert_model_version(
        self, model_version: ModelVersion, registered_model_id: str | None = None
    ) -> ModelVersion:
        """Upsert a model version.

        Updates the version when ``model_version.id`` is set, otherwise
        creates it under ``registered_model_id``.

        Args:
            model_version: Model version to upsert.
            registered_model_id: ID of the registered model this version will
                be associated to. Can be None when updating an existing model
                version.

        Returns:
            The model version as returned by the server.

        Raises:
            ValueError: When creating a new version without a
                ``registered_model_id``.
        """
        async with self.get_client() as client:
            if model_version.id:
                mv = await client.update_model_version(
                    model_version.id, model_version.update()
                )
            elif registered_model_id:
                mv = await client.create_model_version(
                    model_version.create(registered_model_id=registered_model_id)
                )
            else:
                msg = f"Registered model ID required for creating a new model version: {model_version}"
                raise ValueError(msg)

        return ModelVersion.from_basemodel(mv)

    async def get_model_version_by_id(
        self, model_version_id: str
    ) -> ModelVersion | None:
        """Fetch a model version by its ID.

        Args:
            model_version_id: Model version ID.

        Returns:
            The model version, or ``None`` if the server reports it was not
            found.
        """
        async with self.get_client() as client:
            try:
                mv = await client.get_model_version(model_version_id)
            except mr_exceptions.NotFoundException:
                return None

        return ModelVersion.from_basemodel(mv)

    async def get_model_versions(
        self, registered_model_id: str, options: ListOptions | None = None
    ) -> list[ModelVersion]:
        """Fetch model versions by registered model ID.

        Args:
            registered_model_id: Registered model ID.
            options: Options for listing model versions. If given, its
                ``next_page_token`` is updated in place from the response.

        Returns:
            Model versions (empty if the response has no items).
        """
        async with self.get_client() as client:
            mv_list = await client.get_registered_model_versions(
                registered_model_id, **(options or ListOptions()).as_options()
            )

        if options:
            options.next_page_token = mv_list.next_page_token

        return [ModelVersion.from_basemodel(mv) for mv in mv_list.items or []]

    @overload
    async def get_model_version_by_params(
        self, registered_model_id: str, name: str
    ) -> ModelVersion | None: ...

    @overload
    async def get_model_version_by_params(
        self, *, external_id: str
    ) -> ModelVersion | None: ...

    @required_args(
        (
            "registered_model_id",
            "name",
        ),
        ("external_id",),
    )
    async def get_model_version_by_params(
        self,
        registered_model_id: str | None = None,
        name: str | None = None,
        external_id: str | None = None,
    ) -> ModelVersion | None:
        """Fetch a model version by associated parameters.

        Either fetches by external ID, or by registered model ID together
        with the version name.

        Args:
            registered_model_id: Registered model ID.
            name: Model version name.
            external_id: Model version external ID.

        Returns:
            The model version, or ``None`` if the server reports it was not
            found.
        """
        async with self.get_client() as client:
            try:
                mv = await client.find_model_version(
                    name=name,
                    external_id=external_id,
                    parent_resource_id=registered_model_id,
                )
            except mr_exceptions.NotFoundException:
                return None

        return ModelVersion.from_basemodel(mv)

    async def upsert_model_artifact(
        self, model_artifact: ModelArtifact
    ) -> ModelArtifact:
        """Upsert a model artifact.

        Updates the artifact when ``model_artifact.id`` is set, otherwise
        creates it. Creation here does not attach it to a model version; see
        :meth:`upsert_model_version_artifact` for that.

        Args:
            model_artifact: Model artifact to upsert.

        Returns:
            The model artifact as returned by the server.
        """
        async with self.get_client() as client:
            if model_artifact.id:
                ma = await client.update_model_artifact(
                    model_artifact.id, model_artifact.update()
                )
            else:
                ma = await client.create_model_artifact(model_artifact.create())

        return ModelArtifact.from_basemodel(ma)

    async def upsert_model_version_artifact(
        self, artifact: ArtifactT, model_version_id: str
    ) -> ArtifactT:
        """Upsert an artifact on a model version.

        Args:
            artifact: Artifact to upsert; it is sent wrapped via
                ``artifact.wrap()``.
            model_version_id: ID of the model version this artifact will be
                associated to.

        Returns:
            The artifact returned by the server, converted with
            ``Artifact.validate_artifact``. It is typed as the input's
            artifact type through ``cast``; this is not checked at runtime.
        """
        async with self.get_client() as client:
            return cast(
                ArtifactT,
                Artifact.validate_artifact(
                    await client.upsert_model_version_artifact(
                        model_version_id, artifact.wrap()
                    )
                ),
            )

    async def get_model_artifact_by_id(self, id: str) -> ModelArtifact | None:
        """Fetch a model artifact by its ID.

        Args:
            id: Model artifact ID.

        Returns:
            The model artifact, or ``None`` if the server reports it was not
            found.
        """
        async with self.get_client() as client:
            try:
                ma = await client.get_model_artifact(id)
            except mr_exceptions.NotFoundException:
                return None

        return ModelArtifact.from_basemodel(ma)

    @overload
    async def get_model_artifact_by_params(
        self,
        name: str,
        model_version_id: str,
    ) -> ModelArtifact | None: ...

    @overload
    async def get_model_artifact_by_params(
        self, *, external_id: str
    ) -> ModelArtifact | None: ...

    @required_args(
        (
            "name",
            "model_version_id",
        ),
        ("external_id",),
    )
    async def get_model_artifact_by_params(
        self,
        name: str | None = None,
        model_version_id: str | None = None,
        external_id: str | None = None,
    ) -> ModelArtifact | None:
        """Fetch a model artifact by external ID, or by name and model version.

        Args:
            name: Model artifact name.
            model_version_id: ID of the associated model version.
            external_id: Model artifact external ID.

        Returns:
            The model artifact, or ``None`` if the server reports it was not
            found.
        """
        async with self.get_client() as client:
            try:
                ma = await client.find_model_artifact(
                    name=name,
                    parent_resource_id=model_version_id,
                    external_id=external_id,
                )
            except mr_exceptions.NotFoundException:
                return None

        return ModelArtifact.from_basemodel(ma)

    async def get_model_artifacts(
        self,
        model_version_id: str | None = None,
        options: ListOptions | None = None,
    ) -> list[ModelArtifact]:
        """Fetch model artifacts, optionally restricted to one model version.

        Args:
            model_version_id: ID of the associated model version. If omitted,
                lists model artifacts across all versions.
            options: Options for listing model artifacts. If given, its
                ``next_page_token`` is updated in place from the response.

        Returns:
            Model artifacts.
        """
        async with self.get_client() as client:
            if model_version_id:
                art_list = await client.get_model_version_artifacts(
                    model_version_id, **(options or ListOptions()).as_options()
                )
                if options:
                    options.next_page_token = art_list.next_page_token
                # The version endpoint returns every artifact kind; keep only
                # model artifacts. The filtering happens after paging, so a
                # page can contain fewer items than the server returned.
                converted = (
                    Artifact.validate_artifact(art) for art in art_list.items or []
                )
                return [art for art in converted if isinstance(art, ModelArtifact)]

            ma_list = await client.get_model_artifacts(
                **(options or ListOptions()).as_options()
            )

        if options:
            options.next_page_token = ma_list.next_page_token

        return [ModelArtifact.from_basemodel(ma) for ma in ma_list.items or []]

    async def get_model_version_artifacts(
        self,
        model_version_id: str,
        options: ListOptions | None = None,
    ) -> list[Artifact]:
        """Fetch all artifacts of a model version, of any artifact kind.

        Args:
            model_version_id: ID of the associated model version.
            options: Options for listing artifacts. If given, its
                ``next_page_token`` is updated in place from the response.

        Returns:
            Artifacts, each converted with ``Artifact.validate_artifact``.
        """
        async with self.get_client() as client:
            art_list = await client.get_model_version_artifacts(
                model_version_id, **(options or ListOptions()).as_options()
            )

        if options:
            options.next_page_token = art_list.next_page_token

        return [Artifact.validate_artifact(art) for art in art_list.items or []]