# Source code for pyod.utils.persistence

# -*- coding: utf-8 -*-
"""Cross-sklearn-version model persistence for PyOD.

This module is the recommended way to save and load PyOD detectors.
It wraps `joblib` with two capabilities the raw `joblib.dump` /
`joblib.load` path does not provide:

1. A versioned envelope written by `save()`. The envelope records the
   PyOD, sklearn, numpy, scipy, joblib, and Python versions in effect
   at save time. `load()` compares the envelope against the running
   environment and emits a `UserWarning` when any binary-format
   dependency drifts; `load(..., strict=True)` raises instead. This
   lets users detect dependency drift before it surprises them in
   production.
2. A `compat_load()` helper that loads legacy artifacts whose sklearn
   `Tree` node dtype no longer matches the running sklearn (a recurring
   user pain documented in issue #519). `compat_load` uses joblib's
   own unpickler with the BUILD-opcode dispatch entry patched so that
   sklearn `Tree` state is realigned to the running dtype before
   `sklearn.tree._tree.Tree.__setstate__` sees it.

`load()` automatically falls through to `compat_load()` when the
underlying `joblib.load` raises the specific sklearn dtype `ValueError`,
so users who only call `load()` get the rescue path transparently.

WARNING: pickle and joblib load arbitrary Python code. Load only from
trusted sources. The compat_load helper does not change this security
model.

See `docs/model_persistence.rst` for the user-facing guide.
"""
# Author: Yue Zhao <yzhao062@gmail.com>
# License: BSD 2 clause

from __future__ import annotations

import pickle
import platform
import sys
import warnings
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import joblib
import numpy as np
from joblib.numpy_pickle import NumpyUnpickler


# Public API of this module; every underscore-prefixed name below is an
# internal helper of save() / load() / compat_load().
__all__ = ["save", "load", "compat_load"]


# ----------------------------------------------------------------------
# Module constants
# ----------------------------------------------------------------------

# save() stamps this value into every envelope under the
# `_pyod_persistence_version` key; `_validate_schema_version` rejects
# envelopes whose recorded version is greater than this value or below 1.
_CURRENT_PERSISTENCE_VERSION = 1
"""Newest envelope schema version this module writes and reads."""

# Conservative, allowlist-driven dtype realignment for sklearn Tree
# node arrays. Adding a new entry is a deliberate maintenance act:
# it pairs with a regression test, a CHANGES.txt note, and a
# documentation update.
#
# Consumed by `_maybe_realign_tree_state`: a field that exists in the
# running sklearn node dtype but not in the saved dtype must have an entry
# here, otherwise the repair raises ValueError instead of zero-filling.
_TREE_NODE_FIELD_DEFAULTS: dict[str, Any] = {
    "missing_go_to_left": 0,
}
"""Tree-node fields the loader may zero-fill when missing from the saved
dtype. Defaults must match what sklearn's pre-existing behavior implied
for legacy models. `missing_go_to_left=0` mirrors the
"don't route on missingness" behavior of pre-1.3 sklearn."""

# Maps saved (old) field name -> running (new) field name. A saved field
# that is absent from the running dtype and has no usable entry here makes
# `_maybe_realign_tree_state` raise ValueError.
_TREE_NODE_FIELD_RENAMES: dict[str, str] = {}
"""Tree-node fields the loader may map from an old name to a new name.
Empty in v1 because sklearn has not renamed any tree fields
historically; populated only when an upstream rename is observed."""

# load() tests `str(exc).startswith(...)` with this value on the ValueError
# raised by `joblib.load` to decide whether to retry via compat_load().
_DTYPE_MISMATCH_PREFIX = "node array from the pickle has an incompatible dtype"
"""Exact-prefix string that triggers `load()`'s auto-fall-through to
`compat_load`. If sklearn changes the error text, fall-through stops
firing and the original error propagates — which is the safe failure
mode because it preserves diagnostic context."""

# Version drift checks performed by `load()`. Each row: envelope key,
# callable returning the running value, severity. `warn` entries emit
# UserWarning when drift is detected and escalate to ValueError under
# strict mode. `info` entries never raise; they are recorded for
# diagnostics only.
#
# Rows are walked in order by `_check_versions`, so this order is also the
# order in which drifted fields appear in the formatted message. The
# callables run at check time, not at import time; a callable that raises
# is treated by `_check_versions` as "no drift" for that row.
# `pyod_version` is recorded by save() but has no row here, so it is never
# compared.
_VERSION_CHECKS: list[tuple[str, Any, str]] = [
    ("sklearn_version", lambda: _running_version("sklearn"), "warn"),
    ("joblib_version", lambda: joblib.__version__, "warn"),
    ("numpy_version", lambda: np.__version__, "warn"),
    ("scipy_version", lambda: _running_version("scipy"), "warn"),
    ("python_version", lambda: platform.python_version(), "info"),
]


def _running_version(package_name: str) -> str:
    """Return ``__version__`` of a lazily imported dependency.

    Only ``"sklearn"`` and ``"scipy"`` are supported; any other name
    raises ``KeyError(package_name)``. The import happens inside the
    function so that merely importing this module does not import
    either package.
    """
    if package_name not in ("sklearn", "scipy"):
        raise KeyError(package_name)
    if package_name == "sklearn":
        import sklearn as module
    else:
        import scipy as module
    return module.__version__


# ----------------------------------------------------------------------
# save
# ----------------------------------------------------------------------

def save(model: Any, path: Any, metadata: dict | None = None) -> None:
    """Write a fitted PyOD detector to disk inside a versioned envelope.

    The detector is serialized with ``joblib.dump``, but wrapped in a
    header dict that the matching ``load()`` recognizes. The header
    records the schema version, the PyOD / sklearn / numpy / scipy /
    joblib / Python versions in effect, a UTC save timestamp, the
    fully-qualified class name of ``model``, and the caller's
    ``metadata``.

    Parameters
    ----------
    model : Any
        The fitted detector to save. Anything picklable will work; PyOD
        BaseDetector subclasses are the typical case.

    path : str or pathlib.Path
        Destination file path.

    metadata : dict or None
        Optional user-supplied metadata (training dataset id, feature
        schema hash, run id, anything). No schema is imposed; the value
        is stored in the envelope as-is.

    Returns
    -------
    None

    Notes
    -----
    Loading the file with raw ``joblib.load`` returns the envelope dict,
    not the model. Use ``load()`` from this module to unwrap.
    """
    import pyod

    # Collected first, in the same key order the envelope exposes them.
    recorded_versions = {
        "pyod_version": pyod.__version__,
        "sklearn_version": _running_version("sklearn"),
        "numpy_version": np.__version__,
        "scipy_version": _running_version("scipy"),
        "joblib_version": joblib.__version__,
        "python_version": platform.python_version(),
    }
    saved_at = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    model_type = type(model)

    envelope = {
        "_pyod_persistence_version": _CURRENT_PERSISTENCE_VERSION,
        **recorded_versions,
        "saved_at": saved_at,
        "model_class": f"{model_type.__module__}.{model_type.__name__}",
        "metadata": metadata,
        "model": model,
    }
    joblib.dump(envelope, path)
# ----------------------------------------------------------------------
# load
# ----------------------------------------------------------------------
def load(
        path: Any,
        strict: bool = False,
        return_metadata: bool = False) -> Any:
    """Load a PyOD detector saved by `save()` or by raw joblib.dump.

    Three kinds of input are handled:

    1. An envelope dict written by `save()`. Its recorded dependency
       versions are compared against the running environment. Drift in
       sklearn, joblib, numpy, or scipy emits a `UserWarning`;
       `strict=True` raises `ValueError` instead.
    2. A raw detector object written by `joblib.dump(clf, path)`.
       Returned as-is when `strict=False`; raises under `strict=True`
       because there is no envelope to verify.
    3. A file for which `joblib.load` raises the sklearn `Tree` node
       dtype `ValueError`. The file is re-read through
       `compat_load(path)` and the recovered object goes through the
       same envelope/legacy handling.

    Parameters
    ----------
    path : str or pathlib.Path
        Path to the artifact.

    strict : bool, default False
        When True, version drift in any `warn`-severity dependency
        raises `ValueError`. `info`-severity drift (Python version)
        never raises. Legacy artifacts without an envelope also raise
        under strict mode.

    return_metadata : bool, default False
        When True, return ``(model, envelope_without_model_field)``
        instead of just the model. For legacy artifacts the second
        element is ``None``.

    Returns
    -------
    model : Any
        The unpickled model. When `return_metadata=True`, returns
        ``(model, envelope_dict_or_None)``.

    Raises
    ------
    ValueError
        On schema-version mismatch, strict-mode drift, strict-mode
        legacy artifacts, or after a successful compat repair under
        strict mode. Any other `ValueError` from `joblib.load` is
        re-raised unchanged.
    """
    try:
        obj = joblib.load(path)
    except ValueError as exc:
        # Only the known sklearn Tree dtype error earns a second attempt;
        # every other ValueError keeps its original traceback.
        if str(exc).startswith(_DTYPE_MISMATCH_PREFIX):
            return _handle_compat_fallthrough(
                path, exc, strict, return_metadata)
        raise
    return _handle_loaded_object(
        obj, strict, return_metadata,
        after_compat=False, original_exc=None)
def _handle_loaded_object(
        obj: Any, strict: bool, return_metadata: bool, *,
        after_compat: bool,
        original_exc: BaseException | None) -> Any:
    """Route a loaded top-level object through envelope/legacy handlers.

    Shared by the non-fall-through path and by the post-`compat_load`
    path; the `after_compat` flag changes the strict-mode behavior
    (strict ALWAYS raises after a compat repair).

    Parameters
    ----------
    obj : Any
        Top-level object read from the artifact: either an envelope dict
        (see `_is_envelope`) or a raw legacy object.

    strict : bool
        Raise `ValueError` instead of warning.

    return_metadata : bool
        Return ``(model, envelope_without_model)`` (or ``(obj, None)``
        for legacy artifacts) instead of just the model.

    after_compat : bool
        True when `obj` was produced by `compat_load` after the plain
        `joblib.load` failed.

    original_exc : BaseException or None
        The exception that triggered the compat path; used as the cause
        of strict-mode errors raised after a repair.

    Raises
    ------
    ValueError
        Unsupported schema version, strict-mode drift, strict-mode
        legacy artifact, or any strict-mode load after a compat repair.
    """
    # Frames between warnings.warn and the user's call site:
    #   direct path: _handle_loaded_object <- load <- user          (3)
    #   compat path: _handle_loaded_object <- _handle_compat_fallthrough
    #                <- load <- user                                (4)
    # Without the extra level the compat-path warnings would be
    # attributed to load() itself rather than to the caller.
    stacklevel = 4 if after_compat else 3

    if _is_envelope(obj):
        _validate_schema_version(obj)
        drift = _check_versions(obj)
        warn_drift = [d for d in drift if d[3] == "warn"]
        if strict:
            if after_compat:
                if warn_drift:
                    raise ValueError(
                        _format_strict_compat_drift_msg(warn_drift)
                    ) from original_exc
                raise ValueError(
                    "load(strict=True): artifact required compatibility "
                    "repair (sklearn Tree dtype realignment) even "
                    "though recorded dependency versions match the "
                    "running environment. Re-save or re-fit the model "
                    "to remove the dependency on compat_load."
                ) from original_exc
            if warn_drift:
                raise ValueError(_format_drift_msg(warn_drift))
        else:
            if warn_drift:
                warnings.warn(
                    _format_drift_msg(warn_drift),
                    UserWarning, stacklevel=stacklevel)
            if after_compat:
                warnings.warn(
                    "load(): recovered model after sklearn Tree dtype "
                    "realignment. Re-save with save() to update the "
                    "envelope, or re-fit on the current sklearn for "
                    "the most reliable predictions.",
                    UserWarning, stacklevel=stacklevel)
        model = obj["model"]
        if return_metadata:
            envelope_no_model = {
                k: v for k, v in obj.items() if k != "model"}
            return model, envelope_no_model
        return model

    # Raw legacy detector
    if strict:
        if after_compat:
            raise ValueError(
                "load(strict=True): legacy artifact (no envelope) "
                "required compatibility repair. Strict mode cannot "
                "verify a legacy artifact and cannot return a repaired "
                "model. Re-save with save() or re-fit."
            ) from original_exc
        raise ValueError(
            "load(strict=True): artifact is a raw legacy save with no "
            "envelope; strict mode requires an envelope produced by "
            "save(). Use strict=False to load anyway.")
    if after_compat:
        warnings.warn(
            "load(): recovered a legacy artifact after sklearn Tree "
            "dtype realignment. Re-save with save() to opt in to the "
            "versioned envelope going forward, or re-fit on the "
            "current sklearn.",
            UserWarning, stacklevel=stacklevel)
    if return_metadata:
        return obj, None
    return obj


def _handle_compat_fallthrough(
        path: Any, original_exc: BaseException,
        strict: bool, return_metadata: bool) -> Any:
    """Retry a failed load through `compat_load` and post-process it.

    Any exception raised by `compat_load` is re-raised with
    `original_exc` (the error from the first `joblib.load` attempt) as
    its cause, so the traceback shows why the compat path was taken.
    """
    try:
        obj = compat_load(path)
    except Exception as compat_exc:
        raise compat_exc from original_exc
    return _handle_loaded_object(
        obj, strict, return_metadata,
        after_compat=True, original_exc=original_exc)


def _is_envelope(obj: Any) -> bool:
    """True iff `obj` is a dict carrying both envelope marker keys."""
    return (isinstance(obj, dict)
            and "_pyod_persistence_version" in obj
            and "model" in obj)


def _validate_schema_version(envelope: dict) -> None:
    """Raise `ValueError` unless the envelope schema version is supported.

    Supported versions are the integers in
    ``[1, _CURRENT_PERSISTENCE_VERSION]``.
    """
    v = envelope.get("_pyod_persistence_version")
    # bool is a subclass of int, so True would otherwise be accepted as
    # schema version 1.
    if isinstance(v, bool) or not isinstance(v, int):
        raise ValueError(
            "load(): envelope has unsupported "
            f"_pyod_persistence_version={v!r}; expected an integer.")
    if v > _CURRENT_PERSISTENCE_VERSION:
        raise ValueError(
            f"load(): envelope schema version {v} is newer than this "
            f"PyOD release supports (max {_CURRENT_PERSISTENCE_VERSION}). "
            "Upgrade PyOD to read this artifact.")
    if v < 1:
        raise ValueError(
            f"load(): envelope schema version {v} is unrecognized.")
    # v in [1, _CURRENT_PERSISTENCE_VERSION] — supported by this release.


def _check_versions(envelope: dict) -> list[tuple[str, str, str, str]]:
    """Return a list of (field, saved, running, severity) tuples for
    every recorded dependency that drifted from the running version.

    Fields absent from the envelope (or recorded as None) are skipped.
    """
    drift = []
    for field, runner, severity in _VERSION_CHECKS:
        saved = envelope.get(field)
        if saved is None:
            continue
        try:
            running = runner()
        except Exception:
            # Optional dep missing at load time; treat as no drift.
            continue
        if saved != running:
            drift.append((field, saved, running, severity))
    return drift


def _join_warn_drift(drift: list[tuple[str, str, str, str]]) -> str:
    """Render the `warn`-severity rows of `drift` as one comma list.

    Returns an empty string when no row has severity ``"warn"``.
    """
    return ", ".join(
        f"{field}={saved!r} (running {running!r})"
        for field, saved, running, severity in drift
        if severity == "warn")


def _format_drift_msg(drift: list[tuple[str, str, str, str]]) -> str:
    """Build the drift warning/error text; "" if nothing warn-worthy."""
    parts = _join_warn_drift(drift)
    if not parts:
        return ""
    return (
        "load(): dependency drift detected between saved envelope and "
        f"running environment: {parts}. Predictions may differ from "
        "what the model was trained to produce. Consider re-fitting on "
        "the current environment.")


def _format_strict_compat_drift_msg(
        drift: list[tuple[str, str, str, str]]) -> str:
    """Build the strict-mode error text for "repaired AND drifted"."""
    parts = _join_warn_drift(drift)
    return (
        "load(strict=True): artifact required compatibility repair "
        "(sklearn Tree dtype realignment) and recorded dependency "
        f"versions also drifted: {parts}. Re-save or re-fit the "
        "model.")


# ----------------------------------------------------------------------
# compat_load
# ----------------------------------------------------------------------
def compat_load(path: Any, mmap_mode: str | None = None) -> Any:
    """Load an artifact whose sklearn Tree node dtype no longer matches.

    Works like `joblib.load`, except that the unpickler's BUILD handler
    is overridden: whenever the object about to receive its state is a
    sklearn `Tree` and that state is a dict, the state is first passed
    through `_maybe_realign_tree_state`, so `Tree.__setstate__` sees a
    `nodes` array in the running sklearn dtype.

    Realignment is name-based and bounded by `_TREE_NODE_FIELD_DEFAULTS`
    plus `_TREE_NODE_FIELD_RENAMES`. Unknown added/removed fields and
    dtype kind/signedness/itemsize/shape changes raise `ValueError`.
    Same-name byte-order-only differences realign safely.

    A `UserWarning` recommending re-fit is emitted ONLY when at least
    one Tree was actually realigned; loading an artifact that needs no
    repair is silent.

    Parameters
    ----------
    path : str, pathlib.Path, or file-like
        The artifact to load.

    mmap_mode : str or None, default None
        Forwarded to joblib's underlying load path. Supported values
        mirror joblib's: None, 'r', 'r+', 'w+', 'c'.

    Returns
    -------
    obj : Any
        The raw top-level object from the file (a fitted detector for
        legacy raw saves; an envelope dict for files written by
        `save()`). Callers that need envelope unwrapping should use
        `load()`.
    """
    realigned_count = 0

    class _CompatNumpyUnpickler(NumpyUnpickler):
        # Private copy: the patched BUILD entry must not leak into the
        # dispatch table of joblib's own NumpyUnpickler.
        dispatch = NumpyUnpickler.dispatch.copy()

        def load_build(self):
            nonlocal realigned_count
            # At BUILD time the state sits on top of the stack and the
            # instance it will be applied to sits right below it.
            if len(self.stack) >= 2:
                inst, state = self.stack[-2], self.stack[-1]
                if _is_sklearn_tree(inst) and isinstance(state, dict):
                    repaired = _maybe_realign_tree_state(state)
                    if repaired is not state:
                        self.stack[-1] = repaired
                        realigned_count += 1
            return super().load_build()

        dispatch[pickle.BUILD[0]] = load_build

    obj = _load_with_unpickler(path, _CompatNumpyUnpickler, mmap_mode)

    if realigned_count <= 0:
        return obj

    warnings.warn(
        f"compat_load: realigned {realigned_count} sklearn "
        "Tree(s) to the current sklearn dtype. Predictions on "
        "inputs WITH missing values may differ from what the "
        "original model would have produced because zero-filled "
        "defaults for newly-added node fields may not match the "
        "original training behavior. Re-fit on the current sklearn "
        "is recommended for the most reliable predictions.",
        UserWarning, stacklevel=2)
    return obj
def _load_with_unpickler(
        path: Any, unpickler_cls: type, mmap_mode: str | None) -> Any:
    """Mirror of joblib.load that swaps in a custom unpickler class.

    Re-uses joblib's `_validate_fileobject_and_memmap` so compressed
    files and mmap_mode follow the same code path as `joblib.load`.

    Requires joblib >= 1.5 because earlier versions lacked
    `_validate_fileobject_and_memmap` and used a different
    `NumpyUnpickler` constructor signature.

    Parameters
    ----------
    path : str, pathlib.Path, or file-like
        Artifact location, or an object exposing ``read``.

    unpickler_cls : type
        Class instantiated as ``unpickler_cls(filename, file_handle,
        ensure_native_byte_order, mmap_mode=...)`` and asked to
        ``load()``.

    mmap_mode : str or None
        Passed to joblib's file-object validation.

    Raises
    ------
    ImportError
        When the required joblib internals cannot be imported.
    """
    try:
        from joblib.numpy_pickle import (
            _validate_fileobject_and_memmap, load_compatibility)
    except ImportError as exc:
        raise ImportError(
            "compat_load requires joblib>=1.5 because it reuses "
            "joblib's validated file-object and mmap loader, which were "
            "added in joblib 1.5. Upgrade joblib, or use joblib.load "
            "for artifacts that do not need sklearn Tree dtype repair."
        ) from exc

    # Mirror joblib.load's normalization of Path and file-like input.
    if isinstance(path, Path):
        path = str(path)

    # Byte-order normalization is only requested when not memory-mapping.
    ensure_native_byte_order = mmap_mode is None

    def _run(file_handle, filename, validated_mmap_mode):
        if isinstance(file_handle, str):
            # Joblib pre-0.10 legacy format path.
            return load_compatibility(file_handle)
        unpickler = unpickler_cls(
            filename, file_handle, ensure_native_byte_order,
            mmap_mode=validated_mmap_mode)
        return unpickler.load()

    if hasattr(path, "read"):
        fobj = path
        filename = getattr(fobj, "name", "")
        with _validate_fileobject_and_memmap(
                fobj, filename, mmap_mode) as (fh, validated):
            return _run(fh, filename, validated)

    with open(path, "rb") as f:
        with _validate_fileobject_and_memmap(
                f, path, mmap_mode) as (fh, validated):
            return _run(fh, path, validated)


def _is_sklearn_tree(inst: Any) -> bool:
    """True iff `inst` is `sklearn.tree._tree.Tree`.

    Imported lazily so loading non-tree artifacts does not require
    sklearn to be importable on this code path."""
    try:
        from sklearn.tree._tree import Tree
    except Exception:
        return False
    return isinstance(inst, Tree)


def _current_tree_node_dtype() -> np.dtype:
    """Discover the running sklearn's Tree node dtype dynamically.

    Reads `sklearn.tree._tree.NODE_DTYPE` when available (sklearn >=
    1.0). Falls back to introspecting an empty `Tree` instance.
    """
    from sklearn.tree import _tree
    if hasattr(_tree, "NODE_DTYPE"):
        return _tree.NODE_DTYPE
    n_classes = np.array([1], dtype=np.intp)
    t = _tree.Tree(1, n_classes, 1)
    return t.nodes.dtype


def _maybe_realign_tree_state(state: dict) -> dict:
    """Return `state` unchanged if no realignment is needed, otherwise
    return a new state dict with the `nodes` ndarray realigned to the
    running sklearn dtype.

    The input dict is never mutated. `state` is returned as-is when it
    has no ndarray under ``"nodes"`` or when the dtype already matches.

    Raises `ValueError` on unsafe differences: a saved field the running
    dtype lacks (and no known rename), a running field the saved dtype
    lacks (and no registered default), or a shared field whose dtype
    differs in more than byte order.
    """
    nodes = state.get("nodes")
    if not isinstance(nodes, np.ndarray):
        return state
    current = _current_tree_node_dtype()
    if nodes.dtype == current:
        return state

    saved_names = set(nodes.dtype.names or ())
    current_names = set(current.names or ())

    # Resolve recognized renames first. A field that is "added" from
    # current's perspective is treated as truly new only if no known
    # rename maps a removed saved field to it; otherwise the rename
    # carries the data forward and the default-required check does not
    # apply.
    raw_added = current_names - saved_names
    raw_removed = saved_names - current_names

    # The validation loops below iterate in sorted order so that, when
    # several fields are offending, the one named in the error is the
    # same on every run.
    rename_map: dict[str, str] = {}
    for old_name in sorted(raw_removed):
        target = _TREE_NODE_FIELD_RENAMES.get(old_name)
        if target is None or target not in current_names:
            raise ValueError(
                "compat_load: saved Tree node dtype has field "
                f"{old_name!r} which the running sklearn does not "
                "recognize. Dropping it could change predictions. "
                "Re-fit on the current sklearn.")
        rename_map[old_name] = target
    rename_targets = set(rename_map.values())

    # Fields genuinely added in current (not produced by any rename).
    added = raw_added - rename_targets
    for name in sorted(added):
        if name not in _TREE_NODE_FIELD_DEFAULTS:
            raise ValueError(
                "compat_load: saved Tree node dtype is missing field "
                f"{name!r} which the running sklearn requires. No "
                "default is registered for this field, so silent "
                "zero-fill is unsafe. Re-fit the model on the current "
                "sklearn or add a compatibility entry to "
                "_TREE_NODE_FIELD_DEFAULTS with the correct default.")

    # Shared field dtypes must match modulo byte order.
    shared = saved_names & current_names
    for name in sorted(shared):
        if not _dtypes_compatible_modulo_endian(
                nodes.dtype.fields[name][0], current.fields[name][0]):
            raise ValueError(
                "compat_load: saved Tree node field "
                f"{name!r} has dtype {nodes.dtype.fields[name][0]!r} "
                f"but the running sklearn expects "
                f"{current.fields[name][0]!r}. Anything beyond a byte-"
                "order difference (kind, signedness, itemsize, shape) "
                "could change predictions. Re-fit on the current "
                "sklearn.")

    # Realign: copy field by field into a zeroed array of the running
    # dtype, then overwrite genuinely new fields with their defaults.
    new = np.zeros(len(nodes), dtype=current)
    for name in shared:
        new[name] = nodes[name]
    for old_name, new_name in rename_map.items():
        new[new_name] = nodes[old_name]
    for name in added:
        new[name] = _TREE_NODE_FIELD_DEFAULTS[name]

    return {**state, "nodes": new}


def _dtypes_compatible_modulo_endian(
        a: np.dtype, b: np.dtype) -> bool:
    """True iff `a` and `b` differ at most in byte order.

    Both dtypes are normalized to native byte order and compared for
    equality, so any other difference (kind, signedness, itemsize,
    shape, sub-dtype, datetime unit, ...) yields False.
    """
    if a == b:
        return True
    return a.newbyteorder("=") == b.newbyteorder("=")