"""Classification based on custom functions (e.g., rule-based classification)."""
__all__ = ("FallbackRuleClassifier", "RuleClassifier")
import abc
import warnings
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import accuracy_score
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_is_fitted
from ..core.exceptions import SKFBWarning
from ..utils._legacy import _fit_context, validate_params
from .base import RejectorMixin
class RuleClassificationWarning(SKFBWarning):
    """Warning category issued when validating a ``RuleClassifier`` fails.

    ``RuleClassifier.validate_predict`` emits a warning of this category when
    ``predict`` raises during validation, and then re-raises that exception.
    """
class RuleClassifier(BaseEstimator, ClassifierMixin, metaclass=abc.ABCMeta):
    """ABC that defines rule-based classification.

    Reimplement ``_predict`` by defining custom classification rules (e.g.,
    label a transaction fraudulent if a receiver is in a blacklist; otherwise,
    it's genuine). You can introduce new rule arguments either via reimplemented
    ``__init__`` or by passing them as ``kwargs`` during model instantiation.

    Parameters
    ----------
    validate : bool, default=False
        Whether to validate reimplemented prediction method. Validation means
        checking if prediction doesn't result in any exception.
    kwargs : dict, default=None
        Any (tunable) argument required to make predictions.
        Each argument is returned and can be set individually by ``get_params``
        and ``set_params``, respectively.

    Attributes
    ----------
    classes_ : array-like
        Set by ``fit`` to ``unique_labels(y)`` when ``y`` is given, or by
        ``partial_fit`` to ``classes`` when ``classes`` is given.
    is_fitted_ : bool
        Set to ``True`` at the end of ``fit`` and ``partial_fit``.
    """

    _parameter_constraints = {"validate": [bool], "kwargs": [None, dict]}

    def __init__(self, *, validate=False, kwargs=None):
        self.validate = validate
        self.kwargs = kwargs

    @_fit_context(prefer_skip_nested_validation=False)
    @validate_params(
        {
            "X": ["array-like", "sparse matrix"],
            "y": ["array-like", None],
            "sample_weight": ["array-like", None],
        },
        prefer_skip_nested_validation=True,
    )
    def fit(self, X, y=None, sample_weight=None):
        """Fits the estimator and sets fit attributes.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape (n_samples, n_features)
            The training input samples.
        y : array-like, shape (n_samples,) or (n_samples, n_outputs), default=None
            The target values. When given, ``classes_`` is set to its unique
            labels.
        sample_weight : array-like, default=None
            Forwarded to ``_fit``; the base implementation does not use it.

        Returns
        -------
        self : object
            Returns self.
        """
        if y is not None:
            self.classes_ = unique_labels(y)
        self._fit(X, y, sample_weight=sample_weight)
        self.is_fitted_ = True
        return self

    # pylint: disable=unused-argument
    def _fit(self, X, y, sample_weight=None):
        """Only fits the estimator.

        Should be reimplemented if a rule requires a learning mechanism.
        The base implementation only runs ``validate_predict`` on ``X`` when
        ``validate`` is enabled.
        """
        if self.validate:
            self.validate_predict(X)
        return self

    @_fit_context(prefer_skip_nested_validation=False)
    @validate_params(
        {
            "X": ["array-like", "sparse matrix"],
            "y": ["array-like", None],
            "classes": ["array-like", None],
            "sample_weight": ["array-like", None],
        },
        prefer_skip_nested_validation=True,
    )
    def partial_fit(self, X, y, classes=None, sample_weight=None):
        """Fits the estimator partially.

        Parameters
        ----------
        X : {array-like, sparse matrix}
            The training input samples.
        y : array-like or None
            The target values.
        classes : array-like, default=None
            Class labels; when given, stored as ``classes_`` unchanged.
        sample_weight : array-like, default=None
            Forwarded to ``_partial_fit``; the base implementation does not
            use it.

        Returns
        -------
        self : object
            Returns self.
        """
        if classes is not None:
            self.classes_ = classes
        self._partial_fit(X, y, classes=classes, sample_weight=sample_weight)
        self.is_fitted_ = True
        return self

    # pylint: disable=unused-argument
    def _partial_fit(self, X, y=None, classes=None, sample_weight=None):
        """Only fits the estimator partially.

        Should be reimplemented if a rule requires a learning mechanism.
        The base implementation does nothing.
        """
        return self

    @validate_params(
        {
            "X": ["array-like", "sparse matrix"],
        },
        prefer_skip_nested_validation=True,
    )
    def validate_predict(self, X):
        """Validates inference methods.

        Runs ``predict`` on ``X`` with the estimator temporarily flagged as
        fitted. On success, ``is_fitted_`` is left set to ``True``. On failure,
        ``is_fitted_`` is put back to whatever it was before the call.

        Parameters
        ----------
        X : {array-like, sparse matrix}
            Samples used to exercise ``predict``.

        Raises
        ------
        Exception
            Whatever ``predict`` raised; a ``RuleClassificationWarning`` is
            issued before the exception is re-raised.
        """
        missing = object()
        previous_state = getattr(self, "is_fitted_", missing)
        # ``predict`` refuses to run unless ``is_fitted_`` exists, so flag the
        # estimator as fitted for the duration of the check.
        self.is_fitted_ = True
        try:
            self.predict(X)
        except Exception:
            # Roll the flag back first, so the state is restored even if the
            # warning below is configured to be raised as an error.
            if previous_state is missing:
                delattr(self, "is_fitted_")
            else:
                self.is_fitted_ = previous_state
            warnings.warn(
                (
                    f"Validation of {self.__class__.__name__}.predict resulted in"
                    " errors; please, check your implementations"
                ),
                category=RuleClassificationWarning,
            )
            raise

    @validate_params(
        {
            "X": ["array-like", "sparse matrix"],
        },
        prefer_skip_nested_validation=True,
    )
    def predict(self, X):
        """Predicts hard labels.

        Parameters
        ----------
        X : indexable, length n_samples
            Input samples to classify.

        Returns
        -------
        y_pred : ndarray of shape (n_samples,)
            Predicted hard labels.
        """
        check_is_fitted(self, attributes="is_fitted_")
        return np.asarray(self._predict(X))

    @abc.abstractmethod
    def _predict(self, X):
        """An abstract method to make rule-based predictions.

        Should be reimplemented. If applicable, use ``self.kwargs`` as
        additional input parameters to your rule.

        Parameters
        ----------
        X : indexable, length n_samples
            Input samples to classify.
        """

    def get_params(self, deep=True):
        """Gets parameters for the estimator.

        Returns the parameters reported by the parent ``get_params`` with every
        entry of ``kwargs`` added at the top level. A ``kwargs`` entry
        overrides a same-named parameter in the returned mapping.
        """
        parameters = super().get_params(deep)
        parameters.update(self.kwargs or {})
        return parameters

    def set_params(self, **params):
        """Sets the parameters of the estimator.

        Names that are keys of ``kwargs`` are updated inside ``kwargs``; all
        ``params`` are then forwarded to the parent ``set_params``.
        """
        if self.kwargs:
            overrides = {
                key: value for key, value in params.items() if key in self.kwargs
            }
            if overrides:
                # Rebind rather than mutate in place, so the dict passed to
                # ``__init__`` by the caller is never modified.
                self.kwargs = {**self.kwargs, **overrides}
        return super().set_params(**params)
class FallbackRuleClassifier(RuleClassifier, RejectorMixin):
    """Rule-based classifier whose rules may also reject (fall back).

    Subclasses implement ``_predict`` with custom rules that can return a
    dedicated fallback label for samples they decline to classify (e.g., a
    transaction is fraudulent if the receiver is blacklisted; otherwise it is
    genuine when regular and anomalous when irregular). Extra rule arguments can
    be added through a reimplemented ``__init__`` or supplied as ``kwargs`` at
    instantiation.

    Parameters
    ----------
    validate : bool, default=False
        Whether to check at fit time that the reimplemented prediction method
        runs without raising.
    fallback_label : any, default=-1
        Label returned by fallback rules.
    kwargs : dict, default=None
        Any (tunable) argument required to make predictions. Each entry is
        exposed individually through ``get_params`` and ``set_params``.
    """

    def __init__(self, *, validate=False, fallback_label=-1, kwargs=None):
        self.fallback_label = fallback_label
        super().__init__(validate=validate, kwargs=kwargs)

    def _fit(self, X, y, sample_weight=None):
        """Resolves ``fallback_label_`` when labels are known, then defers to the parent."""
        labels_available = y is not None and hasattr(self, "classes_")
        if labels_available:
            known_classes = self.classes_
            resolved = self.validate_fallback_label(self.fallback_label, known_classes)
            self.fallback_label_ = resolved
        return super()._fit(X, y, sample_weight)

    def _partial_fit(self, X, y=None, classes=None, sample_weight=None):
        """Resolves ``fallback_label_`` from ``classes`` if given, then defers to the parent."""
        nothing_to_resolve = y is None or classes is None
        if not nothing_to_resolve:
            resolved = self.validate_fallback_label(self.fallback_label, classes)
            self.fallback_label_ = resolved
        return super()._partial_fit(X, y, classes, sample_weight)

    @_fit_context(prefer_skip_nested_validation=False)
    @validate_params(
        {
            "X": ["array-like", "sparse matrix"],
            "y": ["array-like"],
        },
        prefer_skip_nested_validation=True,
    )
    def score(self, X, y):
        """Computes accuracy over the samples that were not rejected.

        Parameters
        ----------
        X : indexable, length n_samples
            Input samples to evaluate.
        y : array-like of shape (n_samples,)
            True labels for `X` (excluding fallback label).

        Returns
        -------
        score : float
            Accuracy on samples not labeled as fallbacks.

        See also
        --------
        skfb.metrics.prediction_quality
        """
        # Imported here rather than at module level because of a circular
        # import; the original author flagged this as an open question.
        # pylint: disable=import-outside-toplevel
        from ..metrics._common import prediction_quality

        check_is_fitted(self, attributes="fallback_label_")
        predicted = self.predict(X)
        reject_label = self.fallback_label_
        return prediction_quality(y, predicted, accuracy_score, reject_label)