# Source code for model_registry.types.pager

"""Pager for iterating over items."""

from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator, Awaitable, Iterator
from dataclasses import dataclass, field
from typing import Callable, Generic, TypeVar, cast

from .base import BaseModel
from .options import ListOptions, OrderByField

T = TypeVar("T", bound=BaseModel)


@dataclass
class Pager(Generic[T], Iterator[T], AsyncIterator[T]):
    """Pager for iterating over items.

    Assumes that page_fn is a paged function that takes ListOptions and returns
    a list of items. ``page_fn`` may be a plain function (use ``for`` /
    ``next()``) or a coroutine function (use ``async for`` / ``anext()``); the
    matching ``next_page`` / ``next_item`` helpers are bound in
    ``__post_init__``.

    The pager never advances ``options.next_page_token`` itself (it only resets
    it to ``None``), so ``page_fn`` is expected to update it after each call;
    the pythonic iteration protocols use that token to detect when paging has
    wrapped back around to the first page.
    """

    page_fn: (
        Callable[[ListOptions], list[T]]
        | Callable[[ListOptions], Awaitable[list[T]]]
    )
    # Options passed to every page_fn call; mutated by the ordering/size helpers.
    options: ListOptions = field(default_factory=ListOptions)

    def __post_init__(self):
        self.restart()
        if asyncio.iscoroutinefunction(self.page_fn):
            # NOTE: special methods are looked up on the type, not the
            # instance, so this does not stop builtin next(pager) from reaching
            # Pager.__next__; it only affects explicit `pager.__next__` access.
            self.__next__ = NotImplemented
            self.next_page = self._anext_page
            self.next_item = self._anext_item
        else:
            self.__anext__ = NotImplemented
            self.next_page = self._next_page
            self.next_item = self._next_item

    def restart(self) -> Pager[T]:
        """Reset the pager.

        This keeps the current options and page function, but resets the
        internal state (including ``options.next_page_token``).

        Returns:
            The pager itself, for chaining.
        """
        # as MLMD loops over pages, we need to keep track of the first page or we'll loop forever
        self._start: str | None = None
        self._current_page: list[T] | None = None
        # tracks the next item on the current page
        self._i = 0
        self.options.next_page_token = None
        return self

    def order_by_creation_time(self) -> Pager[T]:
        """Order items by creation time.

        This resets the pager and returns it, for chaining.
        """
        self.options.order_by = OrderByField.CREATE_TIME
        return self.restart()

    def order_by_update_time(self) -> Pager[T]:
        """Order items by update time.

        This resets the pager and returns it, for chaining.
        """
        self.options.order_by = OrderByField.LAST_UPDATE_TIME
        return self.restart()

    def order_by_id(self) -> Pager[T]:
        """Order items by ID.

        This resets the pager and returns it, for chaining.
        """
        self.options.order_by = OrderByField.ID
        return self.restart()

    def page_size(self, n: int) -> Pager[T]:
        """Set the page size for each request.

        This resets the pager and returns it, for chaining.

        Raises:
            ValueError: If ``n`` is less than 1.
        """
        if n < 1:
            msg = f"Page size must be at least 1, got {n}"
            raise ValueError(msg)
        self.options.limit = n
        return self.restart()

    def ascending(self) -> Pager[T]:
        """Order items in ascending order.

        This resets the pager and returns it, for chaining.
        """
        self.options.is_asc = True
        return self.restart()

    def descending(self) -> Pager[T]:
        """Order items in descending order.

        This resets the pager and returns it, for chaining.
        """
        self.options.is_asc = False
        return self.restart()

    def _next_page(self) -> list[T]:
        """Get the next page of items (sync page_fn).

        This will automatically loop over pages.
        """
        return cast(list[T], self.page_fn(self.options))

    async def _anext_page(self) -> list[T]:
        """Get the next page of items (async page_fn).

        This will automatically loop over pages.
        """
        return await cast(Awaitable[list[T]], self.page_fn(self.options))

    def _needs_fetch(self) -> bool:
        """Return True when the next item requires requesting a new page.

        A fetch is needed when no (non-empty) page is loaded yet, or when the
        current page is exhausted and a first-page token was recorded. If no
        token was recorded, there is nothing further to request.
        """
        return not self._current_page or (
            self._i >= len(self._current_page) and self._start is not None
        )

    def _next_item(self) -> T:
        """Get the next item in the pager.

        This variant won't check for looping, so it's useful for manual
        iteration/scripting.

        NOTE: This won't check for looping, so use with caution.
        If you want to check for looping, use the pythonic `next()`.

        Raises:
            StopIteration: If the fetched page is empty or the current page is
                exhausted with no further page to request.
        """
        if self._needs_fetch():
            self._current_page = self._next_page()
            self._i = 0
        # An empty page (e.g. a query with no results) ends iteration instead
        # of failing an assertion (which would also vanish under `python -O`).
        if not self._current_page or self._i >= len(self._current_page):
            raise StopIteration
        item = self._current_page[self._i]
        self._i += 1
        return item

    async def _anext_item(self) -> T:
        """Get the next item in the pager.

        This variant won't check for looping, so it's useful for manual
        iteration/scripting.

        NOTE: This won't check for looping, so use with caution.
        If you want to check for looping, use the pythonic `anext()` /
        `async for`.

        Raises:
            StopAsyncIteration: If the fetched page is empty or the current
                page is exhausted with no further page to request.
        """
        if self._needs_fetch():
            self._current_page = await self._anext_page()
            self._i = 0
        # StopIteration must not be raised inside a coroutine: Python turns it
        # into RuntimeError("coroutine raised StopIteration"). The async
        # iteration protocol signals exhaustion with StopAsyncIteration.
        if not self._current_page or self._i >= len(self._current_page):
            raise StopAsyncIteration
        item = self._current_page[self._i]
        self._i += 1
        return item

    def __next__(self) -> T:
        """Return the next item, stopping once paging wraps to the first page."""
        # Decide *before* fetching whether this call will request a new page;
        # only freshly fetched pages can reveal a wrap-around.
        check_looping = self._needs_fetch()
        item = self._next_item()
        if self._start is None:
            # Remember the token seen after the first page.
            self._start = self.options.next_page_token
        elif check_looping and self.options.next_page_token == self._start:
            # We just re-fetched the first page: every page has been seen.
            raise StopIteration
        return item

    async def __anext__(self) -> T:
        """Return the next item, stopping once paging wraps to the first page."""
        # Same loop-detection logic as __next__, for the async protocol.
        check_looping = self._needs_fetch()
        item = await self._anext_item()
        if self._start is None:
            self._start = self.options.next_page_token
        elif check_looping and self.options.next_page_token == self._start:
            raise StopAsyncIteration
        return item

    def __iter__(self) -> Iterator[T]:
        return self

    def __aiter__(self) -> AsyncIterator[T]:
        return self