SurvivalDNN Usage Example¶

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
np.random.seed(42)

Data¶

In this simple example, suppose a widget factory produces widgets of varying length, width, and widget factor. We observe these widget-level characteristics as features $X$, as well as the widgets' failure times $Y$. The goal is to draw inferences or make estimates about the distribution $P(Y > t|X)$.

In [2]:
def generate_synthetic_data(n=100):
    # Observed features
    length = np.random.uniform(1., 8., size=n)
    width = np.random.uniform(0.5, 4., size=n)
    widget_factor = np.random.uniform(0.2, 2., size=n)
    X = np.stack([length, width, widget_factor], axis=-1)

    # True DGP (unobserved)
    epsilon = np.random.normal(0, 0.1, size=n)
    U = 1 - 0.1*length - 0.1*width - widget_factor + epsilon
    rate = np.exp(U)

    # Observed outcomes
    Y = np.exp(-rate) * 10

    return X, Y

X, Y = generate_synthetic_data(n=10000)
In [3]:
plt.hist(Y)
plt.title('Distribution of Widget Time-Until-Failure')
plt.xlabel('Years')
plt.show()
[Figure: histogram of widget time-until-failure, in years]

Model¶

In [4]:
from survivaldnn import SurvivalDNNModel
In [5]:
model = SurvivalDNNModel()

Discretization¶

One important consideration is numSupport, which determines the number of support points used to discretize the outcome space. The output of the neural network is of the form $$ P(Y > t_1|X),\; P(Y > t_2|X),\; \ldots,\; P(Y > t_{\text{numSupport}}|X). $$ If numSupport is too small, too many observations are combined into each bin and we lose precision. If it is too large, some bins may contain very few observations and we lose efficiency or get noisy estimates.

In [6]:
# Internally, this is used to determine the support points
# t_1, t_2, ..., t_{numSupport}
support = model.discretize_outcome_support(Y, numSupport=10)
support
Out[6]:
array([1.32513751, 3.83058098, 4.6781604 , 5.39642685, 6.00376299,
       6.57914055, 7.09105098, 7.54190942, 7.99987867, 9.00493482])
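The support points track the spread of $Y$. One common way to choose such points is via empirical quantiles, sketched below with plain numpy; this is only an illustration of the idea, not necessarily how discretize_outcome_support is implemented internally.

# Hypothetical quantile-based discretization -- an illustration of the
# general idea, not SurvivalDNN's actual implementation
numSupport = 10
probs = np.linspace(0, 1, numSupport + 2)[1:-1]  # interior quantile levels
np.quantile(Y, probs)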

Compilation¶

Next we compile the neural network model by specifying its architecture, regularization, loss function, and other hyperparameters. The default is a fully-connected residual network with batch normalization, leaky ReLU activations, and the Adam optimizer. For additional documentation, see help(model.compile).

In [7]:
numFeatures = X.shape[-1]
model.compile(numFeatures=numFeatures,
              numSupport=100,
              loss='loglik',
              architecture='resnet',
              layers=5)
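For intuition about the default architecture, a fully-connected residual block with batch normalization and leaky ReLU might look like the Keras sketch below. This is a generic illustration of the pattern, not SurvivalDNN's actual internals.

# Illustrative only: a generic fully-connected residual block with
# batch normalization and leaky ReLU (not SurvivalDNN's internals)
import tensorflow as tf

def residual_block(x, units):
    h = tf.keras.layers.Dense(units)(x)
    h = tf.keras.layers.BatchNormalization()(h)
    h = tf.keras.layers.LeakyReLU()(h)
    return tf.keras.layers.Add()([x, h])  # skip connection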

Train¶

Finally we can train the model on the dataset. The previous step in this example specified that we are minimizing the negative log-likelihood. For more prediction-focused tasks, where the emphasis is on point estimates, alternative loss functions include mse or custom callables, as sketched below.
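A custom callable might look like the following. The exact (y_true, y_pred) signature expected by SurvivalDNNModel is an assumption here, based on common Keras conventions; check help(model.compile) before relying on it.

# Hypothetical custom loss; the (y_true, y_pred) signature is assumed,
# following common Keras conventions
import tensorflow as tf

def my_mse(y_true, y_pred):
    return tf.reduce_mean(tf.square(y_true - y_pred))

# model.compile(..., loss=my_mse)  # instead of loss='loglik'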

In [8]:
# Standardize features
X = (X - X.mean(axis=0)) / X.std(axis=0)
In [9]:
# Train the model
model.fit(X, Y,
          epochs=2000)

Inference¶

After training, we can then estimate statistics of interest from the learned distribution $\hat{P}(Y > t|X)$.

Point Estimates/Predictions¶

Estimation of $\hat{\mathbb{E}}_n[Y|X=x]$ can be done with .predict.

In the widget factory application, this is the expected failure time for a given widget with features $x$.
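This works because, for a nonnegative outcome, the mean is recoverable from the survival curve alone: $$ \hat{\mathbb{E}}_n[Y|X=x] = \int_0^\infty \hat{P}(Y > t|X=x)\,dt, $$ which can be approximated by a sum over the discretized support points.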

In [10]:
Y_hat = model.predict(X)
In [11]:
def compare_predictions(res):
    # Scatter true vs. predicted failure times, with the 45-degree
    # line indicating perfect prediction
    fig, ax = plt.subplots(figsize=(12, 7))
    ax.set_xlim((0, 10))
    ax.set_ylim((0, 10))
    res.plot.scatter(x='True', y='Predicted', ax=ax)
    a, b = res.min().min(), res.max().max()
    plt.plot([a, b], [a, b], 'r--', lw=2, label='Perfect Prediction')
    plt.legend()
    plt.show()
In [12]:
res = pd.DataFrame(np.stack([Y, Y_hat], axis=-1), columns=['True', 'Predicted'])
compare_predictions(res)
[Figure: scatter of true vs. predicted failure times, with the perfect-prediction line]

Conditional Point Estimates/Predictions¶

In some scenarios, we know that at least $t$ time has already passed without failure, and we may wish to incorporate this information to improve our estimates. Estimation of $\hat{\mathbb{E}}_n[Y|X=x, Y>t]$ can be done with .predict_conditional.

In the widget factory application, this is the expected failure time for a given widget with features $x$ given that at least $t$ time has already passed.
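The same identity as above, restricted to the surviving tail, gives $$ \hat{\mathbb{E}}_n[Y|X=x, Y>t] = t + \frac{1}{\hat{P}(Y > t|X=x)} \int_t^\infty \hat{P}(Y > u|X=x)\,du, $$ so conditioning amounts to renormalizing the tail of the same estimated survival curve; no refitting should be required.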

In [13]:
# Suppose we inspected widgets halfway and
# observed they haven't failed yet
elapsed = Y / 2
Y_hat_cond = model.predict_conditional(X, elapsed)
In [14]:
res = pd.DataFrame(np.stack([Y, Y_hat_cond], axis=-1), columns=['True', 'Predicted'])
compare_predictions(res)
[Figure: scatter of true vs. conditionally predicted failure times, with the perfect-prediction line]

Survival Function¶

The survival function $S(t|X) = P(Y \geq t|X)$ can be estimated and plotted with .predict_survival_function and .plot_survival_function.

In [15]:
survFunc, support = model.predict_survival_function(X[:100,:])
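The returned arrays can also be used directly. For instance, the sketch below approximately recovers point predictions by integrating each survival curve over the support grid; it assumes survFunc has shape (n_observations, len(support)) and ignores any probability mass below the first support point.

# Sketch: approximate E[Y|X] by trapezoidal integration of each
# discretized survival curve over the support grid; assumes survFunc
# has shape (n_observations, len(support))
dt = np.diff(support)
Y_hat_manual = np.sum(0.5 * (survFunc[:, 1:] + survFunc[:, :-1]) * dt, axis=-1)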
In [16]:
model.plot_survival_function(X[:100,:])
[Figure: estimated survival curves for the first 100 widgets]

Cumulative Distribution Function¶

The closely related CDF $P(Y < t|X)$ can be plotted with .plot_distribution.
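Since the CDF at each support point is just the complement of the survival function, it can also be recovered manually from the arrays returned by .predict_survival_function above:

# The CDF is the complement of the discretized survival curve
cdf = 1.0 - survFunc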

In [17]:
model.plot_distribution(X[:100,:])
[Figure: estimated CDF curves for the first 100 widgets]