# -*- coding: utf-8 -*-
"""Using AutoEncoder with Outlier Detection (PyTorch)
"""
# Author: Yue Zhao <zhaoy@cmu.edu>
# License: BSD 2 clause
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted
from torch import nn
from .base import BaseDetector
from ..utils.stat_models import pairwise_distances_no_broadcast
from ..utils.torch_utility import get_activation_by_name
[docs]class PyODDataset(torch.utils.data.Dataset):
"""PyOD Dataset class for PyTorch Dataloader
"""
def __init__(self, X, y=None, mean=None, std=None):
super(PyODDataset, self).__init__()
self.X = X
self.mean = mean
self.std = std
def __len__(self):
return self.X.shape[0]
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
sample = self.X[idx, :]
if self.mean is not None and self.std is not None:
sample = (sample - self.mean) / self.std
# assert_almost_equal (0, sample.mean(), decimal=1)
return torch.from_numpy(sample), idx
[docs]class InnerAutoencoder(nn.Module):
def __init__(self,
n_features,
hidden_neurons=(128, 64),
dropout_rate=0.2,
batch_norm=True,
hidden_activation='relu'):
# initialize the super class
super(InnerAutoencoder, self).__init__()
# save the default values
self.n_features = n_features
self.dropout_rate = dropout_rate
self.batch_norm = batch_norm
self.hidden_activation = hidden_activation
# create the dimensions for the input and hidden layers
self.layers_neurons_encoder_ = [self.n_features, *hidden_neurons]
self.layers_neurons_decoder_ = self.layers_neurons_encoder_[::-1]
# get the object for the activations functions
self.activation = get_activation_by_name(hidden_activation)
# initialize encoder and decoder as a sequential
self.encoder = nn.Sequential()
self.decoder = nn.Sequential()
# fill the encoder sequential with hidden layers
for idx, layer in enumerate(self.layers_neurons_encoder_[:-1]):
# create a linear layer of neurons
self.encoder.add_module(
"linear" + str(idx),
torch.nn.Linear(layer,self.layers_neurons_encoder_[idx + 1]))
# add a batch norm per layer if wanted (leave out first layer)
if batch_norm:
self.encoder.add_module("batch_norm" + str(idx),
nn.BatchNorm1d(
self.layers_neurons_encoder_[
idx + 1]))
# create the activation
self.encoder.add_module(self.hidden_activation + str(idx),
self.activation)
# create a dropout layer
self.encoder.add_module("dropout" + str(idx),
torch.nn.Dropout(dropout_rate))
# fill the decoder layer
for idx, layer in enumerate(self.layers_neurons_decoder_[:-1]):
# create a linear layer of neurons
self.decoder.add_module(
"linear" + str(idx),
torch.nn.Linear(layer,self.layers_neurons_decoder_[idx + 1]))
# create a batch norm per layer if wanted (only if it is not the
# last layer)
if batch_norm and idx < len(self.layers_neurons_decoder_[:-1]) - 1:
self.decoder.add_module("batch_norm" + str(idx),
nn.BatchNorm1d(
self.layers_neurons_decoder_[
idx + 1]))
# create the activation
self.decoder.add_module(self.hidden_activation + str(idx),
self.activation)
# create a dropout layer (only if it is not the last layer)
if idx < len(self.layers_neurons_decoder_[:-1]) - 1:
self.decoder.add_module("dropout" + str(idx),
torch.nn.Dropout(dropout_rate))
[docs] def forward(self, x):
# we could return the latent representation here after the encoder
# as the latent representation
x = self.encoder(x)
x = self.decoder(x)
return x
[docs]class AutoEncoder(BaseDetector):
"""Auto Encoder (AE) is a type of neural networks for learning useful data
representations in an unsupervised manner. Similar to PCA, AE could be used
to detect outlying objects in the data by calculating the reconstruction
errors. See :cite:`aggarwal2015outlier` Chapter 3 for details.
Notes
-----
This is the PyTorch version of AutoEncoder. See auto_encoder.py for
the TensorFlow version.
The documentation is not finished!
Parameters
----------
hidden_neurons : list, optional (default=[64, 32])
The number of neurons per hidden layers. So the network has the
structure as [n_features, 64, 32, 32, 64, n_features]
hidden_activation : str, optional (default='relu')
Activation function to use for hidden layers.
All hidden layers are forced to use the same type of activation.
See https://pytorch.org/docs/stable/nn.html for details.
Currently only
'relu': nn.ReLU()
'sigmoid': nn.Sigmoid()
'tanh': nn.Tanh()
are supported. See pyod/utils/torch_utility.py for details.
batch_norm : boolean, optional (default=True)
Whether to apply Batch Normalization,
See https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
learning_rate : float, optional (default=1e-3)
Learning rate for the optimizer. This learning_rate is given to
an Adam optimizer (torch.optim.Adam).
See https://pytorch.org/docs/stable/generated/torch.optim.Adam.html
epochs : int, optional (default=100)
Number of epochs to train the model.
batch_size : int, optional (default=32)
Number of samples per gradient update.
dropout_rate : float in (0., 1), optional (default=0.2)
The dropout to be used across all layers.
weight_decay : float, optional (default=1e-5)
The weight decay for Adam optimizer.
See https://pytorch.org/docs/stable/generated/torch.optim.Adam.html
preprocessing : bool, optional (default=True)
If True, apply standardization on the data.
loss_fn : obj, optional (default=torch.nn.MSELoss)
Optimizer instance which implements torch.nn._Loss.
One of https://pytorch.org/docs/stable/nn.html#loss-functions
or a custom loss. Custom losses are currently unstable.
verbose : int, optional (default=1)
Verbosity mode.
- 0 = silent
- 1 = progress bar
- 2 = one line per epoch.
For verbose >= 1, model summary may be printed.
!CURRENTLY NOT SUPPORTED.!
random_state : random_state: int, RandomState instance or None, optional
(default=None)
If int, random_state is the seed used by the random
number generator; If RandomState instance, random_state is the random
number generator; If None, the random number generator is the
RandomState instance used by `np.random`.
!CURRENTLY NOT SUPPORTED.!
contamination : float in (0., 0.5), optional (default=0.1)
The amount of contamination of the data set, i.e.
the proportion of outliers in the data set. When fitting this is used
to define the threshold on the decision function.
Attributes
----------
encoding_dim_ : int
The number of neurons in the encoding layer.
compression_rate_ : float
The ratio between the original feature and
the number of neurons in the encoding layer.
model_ : Keras Object
The underlying AutoEncoder in Keras.
history_: Keras Object
The AutoEncoder training history.
decision_scores_ : numpy array of shape (n_samples,)
The outlier scores of the training data.
The higher, the more abnormal. Outliers tend to have higher
scores. This value is available once the detector is
fitted.
threshold_ : float
The threshold is based on ``contamination``. It is the
``n_samples * contamination`` most abnormal samples in
``decision_scores_``. The threshold is calculated for generating
binary outlier labels.
labels_ : int, either 0 or 1
The binary labels of the training data. 0 stands for inliers
and 1 for outliers/anomalies. It is generated by applying
``threshold_`` on ``decision_scores_``.
"""
def __init__(self,
hidden_neurons=None,
hidden_activation='relu',
batch_norm=True,
learning_rate=1e-3,
epochs=100,
batch_size=32,
dropout_rate=0.2,
weight_decay=1e-5,
# validation_size=0.1,
preprocessing=True,
loss_fn=None,
# verbose=1,
# random_state=None,
contamination=0.1,
device=None):
super(AutoEncoder, self).__init__(contamination=contamination)
# save the initialization values
self.hidden_neurons = hidden_neurons
self.hidden_activation = hidden_activation
self.batch_norm = batch_norm
self.learning_rate = learning_rate
self.epochs = epochs
self.batch_size = batch_size
self.dropout_rate = dropout_rate
self.weight_decay = weight_decay
self.preprocessing = preprocessing
self.loss_fn = loss_fn
# self.verbose = verbose
self.device = device
# create default loss functions
if self.loss_fn is None:
self.loss_fn = torch.nn.MSELoss()
# create default calculation device (support GPU if available)
if self.device is None:
self.device = torch.device(
"cuda:0" if torch.cuda.is_available() else "cpu")
# default values for the amount of hidden neurons
if self.hidden_neurons is None:
self.hidden_neurons = [64, 32]
# noinspection PyUnresolvedReferences
[docs] def fit(self, X, y=None):
"""Fit detector. y is ignored in unsupervised methods.
Parameters
----------
X : numpy array of shape (n_samples, n_features)
The input samples.
y : Ignored
Not used, present for API consistency by convention.
Returns
-------
self : object
Fitted estimator.
"""
# validate inputs X and y (optional)
X = check_array(X)
self._set_n_classes(y)
n_samples, n_features = X.shape[0], X.shape[1]
# conduct standardization if needed
if self.preprocessing:
self.mean, self.std = np.mean(X, axis=0), np.std(X, axis=0)
train_set = PyODDataset(X=X, mean=self.mean, std=self.std)
else:
train_set = PyODDataset(X=X)
train_loader = torch.utils.data.DataLoader(train_set,
batch_size=self.batch_size,
shuffle=True)
# initialize the model
self.model = InnerAutoencoder(
n_features=n_features,
hidden_neurons=self.hidden_neurons,
dropout_rate=self.dropout_rate,
batch_norm=self.batch_norm,
hidden_activation=self.hidden_activation)
# move to device and print model information
self.model = self.model.to(self.device)
print(self.model)
# train the autoencoder to find the best one
self._train_autoencoder(train_loader)
self.model.load_state_dict(self.best_model_dict)
self.decision_scores_ = self.decision_function(X)
self._process_decision_scores()
return self
def _train_autoencoder(self, train_loader):
"""Internal function to train the autoencoder
Parameters
----------
train_loader : torch dataloader
Train data.
"""
optimizer = torch.optim.Adam(
self.model.parameters(), lr=self.learning_rate,
weight_decay=self.weight_decay)
self.best_loss = float('inf')
self.best_model_dict = None
for epoch in range(self.epochs):
overall_loss = []
for data, data_idx in train_loader:
data = data.to(self.device).float()
loss = self.loss_fn(data, self.model(data))
self.model.zero_grad()
loss.backward()
optimizer.step()
overall_loss.append(loss.item())
print('epoch {epoch}: training loss {train_loss} '.format(
epoch=epoch, train_loss=np.mean(overall_loss)))
# track the best model so far
if np.mean(overall_loss) <= self.best_loss:
# print("epoch {ep} is the current best; loss={loss}".format(ep=epoch, loss=np.mean(overall_loss)))
self.best_loss = np.mean(overall_loss)
self.best_model_dict = self.model.state_dict()
[docs] def decision_function(self, X):
"""Predict raw anomaly score of X using the fitted detector.
The anomaly score of an input sample is computed based on different
detector algorithms. For consistency, outliers are assigned with
larger anomaly scores.
Parameters
----------
X : numpy array of shape (n_samples, n_features)
The training input samples. Sparse matrices are accepted only
if they are supported by the base estimator.
Returns
-------
anomaly_scores : numpy array of shape (n_samples,)
The anomaly score of the input samples.
"""
check_is_fitted(self, ['model', 'best_model_dict'])
X = check_array(X)
# note the shuffle may be true but should be False
if self.preprocessing:
dataset = PyODDataset(X=X, mean=self.mean, std=self.std)
else:
dataset = PyODDataset(X=X)
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=self.batch_size,
shuffle=False)
# enable the evaluation mode
self.model.eval()
# construct the vector for holding the reconstruction error
outlier_scores = np.zeros([X.shape[0], ])
with torch.no_grad():
for data, data_idx in dataloader:
data_cuda = data.to(self.device).float()
# this is the outlier score
outlier_scores[data_idx] = pairwise_distances_no_broadcast(
data, self.model(data_cuda).cpu().numpy())
return outlier_scores