Skip to content

sklab.search

sklab.search.GridSearchConfig dataclass

Quick config for sklearn GridSearchCV.

Source code in src/sklab/_search/sklearn.py
@dataclass(slots=True)
class GridSearchConfig:
    """Quick config for sklearn GridSearchCV.

    Each field is forwarded to the ``GridSearchCV`` constructor argument of
    the same name. A ``cv`` set on this config takes precedence over the
    ``cv`` handed to :meth:`create_searcher`; scoring is reconciled through
    ``_resolve_scoring``.
    """

    param_grid: Mapping[str, Any]
    scoring: Scoring | Sequence[Scoring] | None = None
    cv: Any | None = None
    refit: bool | str = True
    n_jobs: int | None = None
    verbose: int = 0
    pre_dispatch: str | int | None = "2*n_jobs"
    error_score: float | str = "raise"

    def create_searcher(
        self,
        *,
        pipeline: Any,
        scoring: Scoring | Sequence[Scoring] | None,
        cv: Any | None,
        n_trials: int | None,
        timeout: float | None,
    ) -> GridSearchCV:
        """Construct a ``GridSearchCV`` that wraps ``pipeline``.

        ``n_trials`` and ``timeout`` belong to the shared searcher-config
        interface and are not used here.
        """
        # Only fall back to the caller's cv when this config leaves it unset.
        effective_cv = cv if self.cv is None else self.cv
        search_kwargs = {
            "param_grid": self.param_grid,
            "scoring": _resolve_scoring(self.scoring, scoring),
            "cv": effective_cv,
            "refit": self.refit,
            "n_jobs": self.n_jobs,
            "verbose": self.verbose,
            "pre_dispatch": self.pre_dispatch,
            "error_score": self.error_score,
        }
        return GridSearchCV(pipeline, **search_kwargs)

create_searcher(*, pipeline, scoring, cv, n_trials, timeout)

Source code in src/sklab/_search/sklearn.py
def create_searcher(
    self,
    *,
    pipeline: Any,
    scoring: Scoring | Sequence[Scoring] | None,
    cv: Any | None,
    n_trials: int | None,
    timeout: float | None,
) -> GridSearchCV:
    """Construct a ``GridSearchCV`` that wraps ``pipeline``.

    ``n_trials`` and ``timeout`` belong to the shared searcher-config
    interface and are not used here.
    """
    # Only fall back to the caller's cv when this config leaves it unset.
    effective_cv = cv if self.cv is None else self.cv
    search_kwargs = {
        "param_grid": self.param_grid,
        "scoring": _resolve_scoring(self.scoring, scoring),
        "cv": effective_cv,
        "refit": self.refit,
        "n_jobs": self.n_jobs,
        "verbose": self.verbose,
        "pre_dispatch": self.pre_dispatch,
        "error_score": self.error_score,
    }
    return GridSearchCV(pipeline, **search_kwargs)

__init__(param_grid, scoring=None, cv=None, refit=True, n_jobs=None, verbose=0, pre_dispatch='2*n_jobs', error_score='raise')

sklab.search.RandomSearchConfig dataclass

Quick config for sklearn RandomizedSearchCV.

Source code in src/sklab/_search/sklearn.py
@dataclass(slots=True)
class RandomSearchConfig:
    """Quick config for sklearn RandomizedSearchCV.

    Each field is forwarded to the ``RandomizedSearchCV`` constructor
    argument of the same name. A ``cv`` set on this config takes precedence
    over the ``cv`` handed to :meth:`create_searcher`; scoring is reconciled
    through ``_resolve_scoring``.
    """

    param_distributions: Mapping[str, Any]
    n_iter: int | None = None
    scoring: Scoring | Sequence[Scoring] | None = None
    cv: Any | None = None
    refit: bool | str = True
    n_jobs: int | None = None
    random_state: int | None = None
    verbose: int = 0
    pre_dispatch: str | int | None = "2*n_jobs"
    error_score: float | str = "raise"

    def create_searcher(
        self,
        *,
        pipeline: Any,
        scoring: Scoring | Sequence[Scoring] | None,
        cv: Any | None,
        n_trials: int | None,
        timeout: float | None,
    ) -> RandomizedSearchCV:
        """Construct a ``RandomizedSearchCV`` that wraps ``pipeline``.

        The number of sampled candidates is taken from this config's
        ``n_iter``, else the caller's ``n_trials``, else 20. Only ``None``
        counts as "not set". Any other value, including an invalid one such
        as 0, is forwarded unchanged so that sklearn validates it rather
        than it being silently replaced. ``timeout`` is not used here.
        """
        resolved = _resolve_scoring(self.scoring, scoring)
        # Explicit None checks (not truthiness) so that 0 is not mistaken
        # for "unset"; this matches how cv is resolved below.
        if self.n_iter is not None:
            resolved_n_iter = self.n_iter
        elif n_trials is not None:
            resolved_n_iter = n_trials
        else:
            resolved_n_iter = 20
        return RandomizedSearchCV(
            pipeline,
            param_distributions=self.param_distributions,
            n_iter=resolved_n_iter,
            scoring=resolved,
            cv=self.cv if self.cv is not None else cv,
            refit=self.refit,
            n_jobs=self.n_jobs,
            random_state=self.random_state,
            verbose=self.verbose,
            pre_dispatch=self.pre_dispatch,
            error_score=self.error_score,
        )

create_searcher(*, pipeline, scoring, cv, n_trials, timeout)

Source code in src/sklab/_search/sklearn.py
def create_searcher(
    self,
    *,
    pipeline: Any,
    scoring: Scoring | Sequence[Scoring] | None,
    cv: Any | None,
    n_trials: int | None,
    timeout: float | None,
) -> RandomizedSearchCV:
    """Construct a ``RandomizedSearchCV`` that wraps ``pipeline``.

    The number of sampled candidates is taken from this config's ``n_iter``,
    else the caller's ``n_trials``, else 20. Only ``None`` counts as "not
    set". Any other value, including an invalid one such as 0, is forwarded
    unchanged so that sklearn validates it rather than it being silently
    replaced. ``timeout`` is not used here.
    """
    resolved = _resolve_scoring(self.scoring, scoring)
    # Explicit None checks (not truthiness) so that 0 is not mistaken for
    # "unset"; this matches how cv is resolved below.
    if self.n_iter is not None:
        resolved_n_iter = self.n_iter
    elif n_trials is not None:
        resolved_n_iter = n_trials
    else:
        resolved_n_iter = 20
    return RandomizedSearchCV(
        pipeline,
        param_distributions=self.param_distributions,
        n_iter=resolved_n_iter,
        scoring=resolved,
        cv=self.cv if self.cv is not None else cv,
        refit=self.refit,
        n_jobs=self.n_jobs,
        random_state=self.random_state,
        verbose=self.verbose,
        pre_dispatch=self.pre_dispatch,
        error_score=self.error_score,
    )

__init__(param_distributions, n_iter=None, scoring=None, cv=None, refit=True, n_jobs=None, random_state=None, verbose=0, pre_dispatch='2*n_jobs', error_score='raise')

sklab.search.OptunaConfig dataclass

Configuration for Optuna-based hyperparameter search.

Use this to configure how Experiment.search() explores the hyperparameter space using Optuna's optimization algorithms.

Parameters

search_space: A callable that defines the hyperparameter search space. It receives an Optuna Trial object and returns a mapping of parameter names to suggested values. Use the trial.suggest_* methods to sample values.

Example:

    def search_space(trial: Trial) -> dict[str, Any]:
        return {
            "classifier__C": trial.suggest_float("C", 0.01, 100, log=True),
            "classifier__kernel": trial.suggest_categorical("kernel", ["rbf", "linear"]),
        }

n_trials: Number of trials to run. Each trial evaluates one hyperparameter configuration. Default: 50.

direction: Optimization direction, either Direction.MAXIMIZE (default) or Direction.MINIMIZE. Since Direction is a StrEnum, you can also pass "maximize" or "minimize" directly. Use maximize for metrics like accuracy and minimize for metrics like log_loss.

callbacks: Optional sequence of callbacks invoked after each trial completes. Each callback receives the Study and FrozenTrial objects. Useful for early stopping, logging, or custom pruning logic. See the Optuna callbacks tutorial.

study_factory: Optional factory function that creates a custom Study. It receives direction as a keyword argument and returns a Study. Use this when you need:

- A custom sampler (e.g., RandomSampler, CmaEsSampler)
- A custom pruner (e.g., HyperbandPruner)
- Persistent storage (database URL for resumable studies)
- A named study for tracking across runs

Example:

    def my_study_factory(direction: str) -> Study:
        return optuna.create_study(
            direction=direction,
            sampler=optuna.samplers.TPESampler(seed=42),
            pruner=optuna.pruners.HyperbandPruner(),
            storage="sqlite:///optuna.db",
            study_name="my-experiment",
            load_if_exists=True,
        )

If None, optuna.create_study(direction=direction) is used with its
defaults (TPE sampler, median pruner, in-memory storage).

scoring: Scorer used to evaluate trials. If None, the first scorer from the Experiment's scoring is used. It can be a string (e.g., "accuracy"), a ScorerName enum member, or a callable.

References

- Trial: https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html
- Study: https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html
- FrozenTrial: https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.FrozenTrial.html
- sampler: https://optuna.readthedocs.io/en/stable/reference/samplers/index.html
- pruner: https://optuna.readthedocs.io/en/stable/reference/pruners.html
- Optuna callbacks tutorial: https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/007_optuna_callback.html

Source code in src/sklab/_search/optuna.py
@dataclass(slots=True)
class OptunaConfig:
    """Configuration for Optuna-based hyperparameter search.

    Use this to configure how ``Experiment.search()`` explores the hyperparameter
    space using Optuna's optimization algorithms.

    Parameters
    ----------
    search_space
        A callable that defines the hyperparameter search space. Receives an
        Optuna `Trial`_ object and returns a mapping of parameter names to
        suggested values. Use ``trial.suggest_*`` methods to sample values.

        Example::

            def search_space(trial: Trial) -> dict[str, Any]:
                return {
                    "classifier__C": trial.suggest_float("C", 0.01, 100, log=True),
                    "classifier__kernel": trial.suggest_categorical("kernel", ["rbf", "linear"]),
                }

    n_trials
        Number of trials to run. Each trial evaluates one hyperparameter
        configuration. Default: 50. An ``n_trials`` passed to
        :meth:`create_searcher` that is not ``None`` overrides this value.

    direction
        Optimization direction: ``Direction.MAXIMIZE`` (default) or
        ``Direction.MINIMIZE``. Since ``Direction`` is a ``StrEnum``, you can
        also pass ``"maximize"`` or ``"minimize"`` directly. Use maximize for
        metrics like accuracy; minimize for metrics like log_loss.

    callbacks
        Optional sequence of callbacks invoked after each trial completes.
        Each callback receives the `Study`_ and `FrozenTrial`_ objects.
        Useful for early stopping, logging, or custom pruning logic.
        See `Optuna callbacks tutorial`_.

    study_factory
        Optional factory function to create a custom `Study`_. Receives
        ``direction`` as a keyword argument and returns a Study. Use this
        when you need:

        - A custom `sampler`_ (e.g., ``RandomSampler``, ``CmaEsSampler``)
        - A custom `pruner`_ (e.g., ``HyperbandPruner``)
        - Persistent storage (database URL for resumable studies)
        - A named study for tracking across runs

        Example::

            def my_study_factory(direction: str) -> Study:
                return optuna.create_study(
                    direction=direction,
                    sampler=optuna.samplers.TPESampler(seed=42),
                    pruner=optuna.pruners.HyperbandPruner(),
                    storage="sqlite:///optuna.db",
                    study_name="my-experiment",
                    load_if_exists=True,
                )

        If None, uses ``optuna.create_study(direction=direction)`` with
        defaults (TPE sampler, median pruner, in-memory storage).

    scoring
        Scorer to use for evaluating trials. If None, uses the first scorer
        from the Experiment's scoring. Can be a string (e.g., ``"accuracy"``),
        a ScorerName enum, or a callable.

    References
    ----------
    .. _Trial: https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html
    .. _Study: https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html
    .. _FrozenTrial: https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.FrozenTrial.html
    .. _sampler: https://optuna.readthedocs.io/en/stable/reference/samplers/index.html
    .. _pruner: https://optuna.readthedocs.io/en/stable/reference/pruners.html
    .. _Optuna callbacks tutorial: https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/007_optuna_callback.html
    """

    search_space: Callable[[Trial], Mapping[str, Any]]
    n_trials: int = 50
    # Plain "maximize"/"minimize" strings are accepted too (see docstring).
    direction: Direction | str = Direction.MAXIMIZE
    callbacks: Sequence[Callable[[Study, FrozenTrial], None]] | None = None
    study_factory: Callable[..., Study] | None = None
    scoring: Scoring | None = None

    def create_searcher(
        self,
        *,
        pipeline: Any,
        scoring: Scoring | Sequence[Scoring] | None,
        cv: Any | None,
        n_trials: int | None,
        timeout: float | None,
    ) -> OptunaSearcher:
        """Build an ``OptunaSearcher`` for ``pipeline`` from this config.

        ``n_trials`` overrides this config's ``n_trials`` whenever it is not
        ``None``. ``cv`` and ``timeout`` are passed through unchanged. The
        caller's ``scoring`` is forwarded as ``experiment_scoring`` and this
        config's ``scoring`` as ``config_scoring``.
        """
        # Explicit None check: truthiness would turn an explicit 0 into the
        # config default.
        resolved_n_trials = n_trials if n_trials is not None else self.n_trials
        return OptunaSearcher(
            pipeline=pipeline,
            experiment_scoring=scoring,
            cv=cv,
            n_trials=resolved_n_trials,
            timeout=timeout,
            search_space=self.search_space,
            direction=self.direction,
            callbacks=self.callbacks,
            study_factory=self.study_factory,
            config_scoring=self.scoring,
        )

create_searcher(*, pipeline, scoring, cv, n_trials, timeout)

Source code in src/sklab/_search/optuna.py
def create_searcher(
    self,
    *,
    pipeline: Any,
    scoring: Scoring | Sequence[Scoring] | None,
    cv: Any | None,
    n_trials: int | None,
    timeout: float | None,
) -> OptunaSearcher:
    """Build an ``OptunaSearcher`` for ``pipeline`` from this config.

    ``n_trials`` overrides this config's ``n_trials`` whenever it is not
    ``None``. ``cv`` and ``timeout`` are passed through unchanged. The
    caller's ``scoring`` is forwarded as ``experiment_scoring`` and this
    config's ``scoring`` as ``config_scoring``.
    """
    # Explicit None check: truthiness would turn an explicit 0 into the
    # config default.
    resolved_n_trials = n_trials if n_trials is not None else self.n_trials
    return OptunaSearcher(
        pipeline=pipeline,
        experiment_scoring=scoring,
        cv=cv,
        n_trials=resolved_n_trials,
        timeout=timeout,
        search_space=self.search_space,
        direction=self.direction,
        callbacks=self.callbacks,
        study_factory=self.study_factory,
        config_scoring=self.scoring,
    )

__init__(search_space, n_trials=50, direction=Direction.MAXIMIZE, callbacks=None, study_factory=None, scoring=None)

sklab.search.SearcherProtocol

Bases: Protocol

Minimal interface required by Experiment.search.

Source code in src/sklab/adapters/search.py
@runtime_checkable
class SearcherProtocol(Protocol):
    """Minimal interface required by Experiment.search."""

    def fit(self, X: Any, y: Any | None = None) -> Any:  # noqa: N803
        ...

    best_params_: Mapping[str, Any] | None
    best_score_: float | None
    best_estimator_: Any | None

sklab.search.SearchConfigProtocol

Bases: Protocol

Config that can build a searcher for Experiment.search.

Source code in src/sklab/adapters/search.py
@runtime_checkable
class SearchConfigProtocol(Protocol):
    """Factory interface: a config object that can build a searcher.

    Implementations receive the pipeline along with caller-supplied scoring,
    cv, trial budget and timeout. They may ignore any of these, and they
    return an object satisfying ``SearcherProtocol``.
    """

    def create_searcher(
        self,
        *,
        pipeline: Any,
        scoring: Scoring | Sequence[Scoring] | None,
        cv: Any | None,
        n_trials: int | None,
        timeout: float | None,
    ) -> SearcherProtocol:
        """Return a searcher configured for ``pipeline``."""
        ...