Skip to content

sklab.experiment

sklab.experiment.Experiment dataclass

Bundle experiment inputs for an sklearn-style run.

Source code in src/sklab/experiment.py
@dataclass(slots=True)
class Experiment:
    """Bundle experiment inputs for an sklearn-style run."""

    # Estimator/pipeline template; cloned before fitting so the template
    # itself is never mutated.
    pipeline: Any
    # Run logger; the no-op default makes logging strictly optional.
    logger: LoggerProtocol = field(default_factory=NoOpLogger)
    # Scoring spec(s) consumed by evaluate/cross_validate/search.
    scoring: Scoring | Sequence[Scoring] | None = None
    # Default run name; a per-call run_name takes precedence.
    name: str | None = None
    # Tags attached to every logged run.
    tags: Mapping[str, str] | None = None
    # Last estimator fitted by fit/cross_validate/search; read by evaluate().
    _fitted_estimator: Any | None = None

    def fit(
        self,
        X: Any,
        y: Any | None = None,
        *,
        params: Mapping[str, Any] | None = None,
        run_name: str | None = None,
    ) -> FitResult:
        """Fit the pipeline on the provided data and log the run.

        Args:
            X: Training features.
            y: Optional training targets.
            params: Parameter overrides applied to the cloned estimator
                via ``set_params``.
            run_name: Per-call run name; falls back to ``self.name``.

        Returns:
            FitResult holding the fitted estimator and the merged
            parameters (``metrics`` is empty — fit does not score).
        """
        # Clone so repeated fit() calls never share fitted state.
        estimator = clone(self.pipeline)
        merged_params = _merge_params(estimator, params)
        if params:
            estimator.set_params(**params)
        with self.logger.start_run(
            name=run_name or self.name,
            config=merged_params,
            tags=self.tags,
        ) as run:
            estimator.fit(X, y)
            run.log_model(estimator, name="model")
        # Stash the estimator so evaluate() can score it later.
        self._fitted_estimator = estimator
        return FitResult(
            estimator=estimator, metrics={}, params=merged_params, raw=estimator
        )

    def evaluate(
        self,
        X: Any,
        y: Any | None = None,
        *,
        run_name: str | None = None,
    ) -> EvalResult:
        """Evaluate the fitted estimator using experiment scoring and log metrics.

        Requires a prior successful fit(), cross_validate(refit=True), or
        search(); ``check_is_fitted`` raises if nothing has been fitted.
        """
        check_is_fitted(self._fitted_estimator)
        # Scoring must be configured on the experiment to evaluate.
        scoring = _require_scoring(self.scoring)
        metrics = _score_estimator(self._fitted_estimator, X, y, scoring)
        with self.logger.start_run(
            name=run_name or self.name,
            config=None,
            tags=self.tags,
        ) as run:
            run.log_metrics(metrics)
        return EvalResult(metrics=metrics, raw=metrics)

    def cross_validate(
        self,
        X: Any,
        y: Any | None = None,
        *,
        cv: Any,
        refit: bool = True,
        run_name: str | None = None,
    ) -> CVResult:
        """Run sklearn cross-validation, aggregate metrics, and optionally refit.

        Args:
            X: Features.
            y: Optional targets.
            cv: Forwarded to sklearn's ``cross_validate`` (splitter, int, ...).
            refit: When True, refit a fresh clone on the full data afterwards.
            run_name: Per-call run name; falls back to ``self.name``.
        """
        scoring_dict = _require_scoring(self.scoring)
        scoring = _sklearn_scoring(scoring_dict)
        scores = sklearn_cross_validate(
            self.pipeline,
            X,
            y,
            scoring=scoring,
            cv=cv,
            return_train_score=False,
        )
        # sklearn prefixes each scorer's per-fold results with "test_".
        fold_metrics = {name: list(scores[f"test_{name}"]) for name in scoring.keys()}
        metrics = _aggregate_cv_metrics(fold_metrics)
        final_estimator = None
        if refit:
            # Refit on the full dataset so callers get a usable estimator.
            final_estimator = clone(self.pipeline)
            final_estimator.fit(X, y)
        with self.logger.start_run(
            name=run_name or self.name,
            config=None,
            tags=self.tags,
        ) as run:
            run.log_metrics(metrics)
            if final_estimator is not None:
                run.log_model(final_estimator, name="model")
        if final_estimator is not None:
            self._fitted_estimator = final_estimator
        return CVResult(
            metrics=metrics,
            fold_metrics=fold_metrics,
            estimator=final_estimator,
            raw=scores,
        )

    # Overloads narrow SearchResult.raw per searcher/config type.
    @overload
    def search(
        self,
        search: OptunaConfig | OptunaSearcher,
        X: Any,
        y: Any | None = None,
        *,
        cv: Any | None = None,
        n_trials: int | None = None,
        timeout: float | None = None,
        run_name: str | None = None,
    ) -> SearchResult[Study]: ...

    @overload
    def search(
        self,
        search: GridSearchConfig | GridSearchCV,
        X: Any,
        y: Any | None = None,
        *,
        cv: Any | None = None,
        n_trials: int | None = None,
        timeout: float | None = None,
        run_name: str | None = None,
    ) -> SearchResult[GridSearchCV]: ...

    @overload
    def search(
        self,
        search: RandomSearchConfig | RandomizedSearchCV,
        X: Any,
        y: Any | None = None,
        *,
        cv: Any | None = None,
        n_trials: int | None = None,
        timeout: float | None = None,
        run_name: str | None = None,
    ) -> SearchResult[RandomizedSearchCV]: ...

    @overload
    def search(
        self,
        search: SearcherProtocol | SearchConfigProtocol,
        X: Any,
        y: Any | None = None,
        *,
        cv: Any | None = None,
        n_trials: int | None = None,
        timeout: float | None = None,
        run_name: str | None = None,
    ) -> SearchResult[Any]: ...

    def search(
        self,
        search: SearcherProtocol | SearchConfigProtocol,
        X: Any,
        y: Any | None = None,
        *,
        cv: Any | None = None,
        n_trials: int | None = None,
        timeout: float | None = None,
        run_name: str | None = None,
    ) -> SearchResult[Any]:
        """Run a hyperparameter search using a searcher or config object.

        Args:
            search: A ready searcher, or a config from which
                ``_build_searcher`` constructs one.
            X: Features.
            y: Optional targets.
            cv: Optional cross-validation override passed to the builder.
            n_trials: Optional trial-count override passed to the builder.
            timeout: Optional time-budget override passed to the builder.
            run_name: Per-call run name; falls back to ``self.name``.
        """
        searcher = _build_searcher(
            search,
            pipeline=self.pipeline,
            scoring=self.scoring,
            cv=cv,
            n_trials=n_trials,
            timeout=timeout,
        )
        with self.logger.start_run(
            name=run_name or self.name,
            config=None,
            tags=self.tags,
        ) as run:
            searcher.fit(X, y)
            # getattr defaults tolerate searchers that lack these attributes
            # (e.g. when no refit/score is available on the searcher).
            best_params = getattr(searcher, "best_params_", {})
            best_score = getattr(searcher, "best_score_", None)
            run.log_params(best_params)
            if best_score is not None:
                run.log_metrics({"best_score": float(best_score)})
            best_estimator = getattr(searcher, "best_estimator_", None)
            if best_estimator is not None:
                run.log_model(best_estimator, name="model")
        if best_estimator is not None:
            self._fitted_estimator = best_estimator
        # Expose Study for Optuna searches, searcher for sklearn searches
        raw = searcher.study if isinstance(search, (OptunaConfig, OptunaSearcher)) else searcher
        return SearchResult(
            best_params=best_params,
            best_score=best_score,
            estimator=best_estimator,
            raw=raw,
        )

fit(X, y=None, *, params=None, run_name=None)

Fit the pipeline on the provided data and log the run.

Source code in src/sklab/experiment.py
def fit(
    self,
    X: Any,
    y: Any | None = None,
    *,
    params: Mapping[str, Any] | None = None,
    run_name: str | None = None,
) -> FitResult:
    """Fit the pipeline on the provided data and log the run.

    Args:
        X: Training features.
        y: Optional training targets.
        params: Parameter overrides applied via ``set_params``.
        run_name: Per-call run name; falls back to ``self.name``.

    Returns:
        FitResult with the fitted estimator and merged parameters
        (``metrics`` is empty — fit does not score).
    """
    # Clone so repeated fit() calls never share fitted state.
    estimator = clone(self.pipeline)
    merged_params = _merge_params(estimator, params)
    if params:
        estimator.set_params(**params)
    with self.logger.start_run(
        name=run_name or self.name,
        config=merged_params,
        tags=self.tags,
    ) as run:
        estimator.fit(X, y)
        run.log_model(estimator, name="model")
    # Stash the estimator so evaluate() can score it later.
    self._fitted_estimator = estimator
    return FitResult(
        estimator=estimator, metrics={}, params=merged_params, raw=estimator
    )

evaluate(X, y=None, *, run_name=None)

Evaluate the fitted estimator using experiment scoring and log metrics.

Source code in src/sklab/experiment.py
def evaluate(
    self,
    X: Any,
    y: Any | None = None,
    *,
    run_name: str | None = None,
) -> EvalResult:
    """Evaluate the fitted estimator using experiment scoring and log metrics.

    Requires a prior successful fit(), cross_validate(refit=True), or
    search(); ``check_is_fitted`` raises if nothing has been fitted.
    """
    check_is_fitted(self._fitted_estimator)
    # Scoring must be configured on the experiment to evaluate.
    scoring = _require_scoring(self.scoring)
    metrics = _score_estimator(self._fitted_estimator, X, y, scoring)
    with self.logger.start_run(
        name=run_name or self.name,
        config=None,
        tags=self.tags,
    ) as run:
        run.log_metrics(metrics)
    return EvalResult(metrics=metrics, raw=metrics)

cross_validate(X, y=None, *, cv, refit=True, run_name=None)

Run sklearn cross-validation, aggregate metrics, and optionally refit.

Source code in src/sklab/experiment.py
def cross_validate(
    self,
    X: Any,
    y: Any | None = None,
    *,
    cv: Any,
    refit: bool = True,
    run_name: str | None = None,
) -> CVResult:
    """Run sklearn cross-validation, aggregate metrics, and optionally refit.

    Args:
        X: Features.
        y: Optional targets.
        cv: Forwarded to sklearn's ``cross_validate`` (splitter, int, ...).
        refit: When True, refit a fresh clone on the full data afterwards.
        run_name: Per-call run name; falls back to ``self.name``.
    """
    scoring_dict = _require_scoring(self.scoring)
    scoring = _sklearn_scoring(scoring_dict)
    scores = sklearn_cross_validate(
        self.pipeline,
        X,
        y,
        scoring=scoring,
        cv=cv,
        return_train_score=False,
    )
    # sklearn prefixes each scorer's per-fold results with "test_".
    fold_metrics = {name: list(scores[f"test_{name}"]) for name in scoring.keys()}
    metrics = _aggregate_cv_metrics(fold_metrics)
    final_estimator = None
    if refit:
        # Refit on the full dataset so callers get a usable estimator.
        final_estimator = clone(self.pipeline)
        final_estimator.fit(X, y)
    with self.logger.start_run(
        name=run_name or self.name,
        config=None,
        tags=self.tags,
    ) as run:
        run.log_metrics(metrics)
        if final_estimator is not None:
            run.log_model(final_estimator, name="model")
    if final_estimator is not None:
        self._fitted_estimator = final_estimator
    return CVResult(
        metrics=metrics,
        fold_metrics=fold_metrics,
        estimator=final_estimator,
        raw=scores,
    )

search(search, X, y=None, *, cv=None, n_trials=None, timeout=None, run_name=None)

search(
    search: OptunaConfig | OptunaSearcher,
    X: Any,
    y: Any | None = None,
    *,
    cv: Any | None = None,
    n_trials: int | None = None,
    timeout: float | None = None,
    run_name: str | None = None,
) -> SearchResult[Study]
search(
    search: GridSearchConfig | GridSearchCV,
    X: Any,
    y: Any | None = None,
    *,
    cv: Any | None = None,
    n_trials: int | None = None,
    timeout: float | None = None,
    run_name: str | None = None,
) -> SearchResult[GridSearchCV]
search(
    search: RandomSearchConfig | RandomizedSearchCV,
    X: Any,
    y: Any | None = None,
    *,
    cv: Any | None = None,
    n_trials: int | None = None,
    timeout: float | None = None,
    run_name: str | None = None,
) -> SearchResult[RandomizedSearchCV]
search(
    search: SearcherProtocol | SearchConfigProtocol,
    X: Any,
    y: Any | None = None,
    *,
    cv: Any | None = None,
    n_trials: int | None = None,
    timeout: float | None = None,
    run_name: str | None = None,
) -> SearchResult[Any]

Run a hyperparameter search using a searcher or config object.

Source code in src/sklab/experiment.py
def search(
    self,
    search: SearcherProtocol | SearchConfigProtocol,
    X: Any,
    y: Any | None = None,
    *,
    cv: Any | None = None,
    n_trials: int | None = None,
    timeout: float | None = None,
    run_name: str | None = None,
) -> SearchResult[Any]:
    """Run a hyperparameter search using a searcher or config object.

    Args:
        search: A ready searcher, or a config from which
            ``_build_searcher`` constructs one.
        X: Features.
        y: Optional targets.
        cv: Optional cross-validation override passed to the builder.
        n_trials: Optional trial-count override passed to the builder.
        timeout: Optional time-budget override passed to the builder.
        run_name: Per-call run name; falls back to ``self.name``.
    """
    searcher = _build_searcher(
        search,
        pipeline=self.pipeline,
        scoring=self.scoring,
        cv=cv,
        n_trials=n_trials,
        timeout=timeout,
    )
    with self.logger.start_run(
        name=run_name or self.name,
        config=None,
        tags=self.tags,
    ) as run:
        searcher.fit(X, y)
        # getattr defaults tolerate searchers that lack these attributes.
        best_params = getattr(searcher, "best_params_", {})
        best_score = getattr(searcher, "best_score_", None)
        run.log_params(best_params)
        if best_score is not None:
            run.log_metrics({"best_score": float(best_score)})
        best_estimator = getattr(searcher, "best_estimator_", None)
        if best_estimator is not None:
            run.log_model(best_estimator, name="model")
    if best_estimator is not None:
        self._fitted_estimator = best_estimator
    # Expose Study for Optuna searches, searcher for sklearn searches
    raw = searcher.study if isinstance(search, (OptunaConfig, OptunaSearcher)) else searcher
    return SearchResult(
        best_params=best_params,
        best_score=best_score,
        estimator=best_estimator,
        raw=raw,
    )

__init__(pipeline, logger=NoOpLogger(), scoring=None, name=None, tags=None, _fitted_estimator=None)

sklab.experiment.FitResult dataclass

Result of a single fit run.

Attributes:

Name Type Description
estimator Any

The fitted pipeline/estimator.

metrics Mapping[str, float]

Empty dict (fit doesn't compute metrics).

params Mapping[str, Any]

Merged parameters used for fitting.

raw Any

The fitted estimator (same as estimator, for API consistency).

Source code in src/sklab/_results.py
# Returned by Experiment.fit().
@dataclass(slots=True)
class FitResult:
    """Result of a single fit run.

    Attributes:
        estimator: The fitted pipeline/estimator.
        metrics: Empty dict (fit doesn't compute metrics).
        params: Merged parameters used for fitting.
        raw: The fitted estimator (same as estimator, for API consistency).
    """

    estimator: Any
    metrics: Mapping[str, float]
    params: Mapping[str, Any]
    raw: Any

sklab.experiment.EvalResult dataclass

Result of evaluating a fitted estimator on a dataset.

Attributes:

Name Type Description
metrics Mapping[str, float]

Computed metric scores.

raw Mapping[str, float]

The metrics dict (same as metrics, for API consistency).

Source code in src/sklab/_results.py
# Returned by Experiment.evaluate().
@dataclass(slots=True)
class EvalResult:
    """Result of evaluating a fitted estimator on a dataset.

    Attributes:
        metrics: Computed metric scores.
        raw: The metrics dict (same as metrics, for API consistency).
    """

    metrics: Mapping[str, float]
    raw: Mapping[str, float]

sklab.experiment.CVResult dataclass

Result of a cross-validation run.

Attributes:

Name Type Description
metrics Mapping[str, float]

Aggregated metrics (mean/std across folds).

fold_metrics Mapping[str, list[float]]

Per-fold metric values.

estimator Any | None

Final refitted estimator (if refit=True), else None.

raw Mapping[str, Any]

Full sklearn cross_validate() dict, including fit_time, score_time, and test scores for each fold.

Source code in src/sklab/_results.py
# Returned by Experiment.cross_validate().
@dataclass(slots=True)
class CVResult:
    """Result of a cross-validation run.

    Attributes:
        metrics: Aggregated metrics (mean/std across folds).
        fold_metrics: Per-fold metric values.
        estimator: Final refitted estimator (if refit=True), else None.
        raw: Full sklearn cross_validate() dict, including fit_time,
            score_time, and test scores for each fold.
    """

    metrics: Mapping[str, float]
    fold_metrics: Mapping[str, list[float]]
    estimator: Any | None
    raw: Mapping[str, Any]

sklab.experiment.SearchResult dataclass

Bases: Generic[RawT]

Result of a hyperparameter search run.

Attributes:

Name Type Description
best_params Mapping[str, Any]

Best hyperparameters found.

best_score float | None

Best cross-validation score achieved.

estimator Any | None

Best estimator refitted on full data (if refit=True).

raw RawT

The underlying search object. For OptunaConfig, this is the Optuna Study with full trial history. For sklearn searchers (GridSearchCV, RandomizedSearchCV), this is the fitted searcher with cv_results_ and other attributes.

Source code in src/sklab/_results.py
# Returned by Experiment.search(); RawT is narrowed by the search() overloads.
@dataclass(slots=True)
class SearchResult(Generic[RawT]):
    """Result of a hyperparameter search run.

    Attributes:
        best_params: Best hyperparameters found.
        best_score: Best cross-validation score achieved.
        estimator: Best estimator refitted on full data (if refit=True).
        raw: The underlying search object. For OptunaConfig, this is the
            Optuna Study with full trial history. For sklearn searchers
            (GridSearchCV, RandomizedSearchCV), this is the fitted searcher
            with cv_results_ and other attributes.
    """

    best_params: Mapping[str, Any]
    best_score: float | None
    estimator: Any | None
    raw: RawT