Creating visualizations to better understand your data and models (Part 1)

by Gaurav Kaushik, September 27th, 2018
One of my favorite things about being a data scientist is creating new tools that make it easier to interpret data and models. I especially like to think about new ways to visualize data that could help solve a tough problem and be useful to my team. Visualizations and interactive interfaces tend to elevate how we work, and short-term investment in new tooling accelerates our analyses and enhances our understanding of the data.

A while back, I shared screenshots on Twitter of some tools I was building, and the feedback was really positive (on Twitter — imagine that!). Recently, I stumbled upon my old tweet and decided to polish up the code and share the tools with the community.

This is the first post in a three-part series about creating visualizations for dissecting data and models. Code for Part 1 is now on GitHub, which includes a Jupyter Notebook that walks through each visualization and tool as well as accompanying command-line-executable code. Keep an eye on that repository or Twitter for Parts 2 and 3.


Finally picked up @BokehPlots, wished I had sooner. Polishing up some test graphs -- next up, exploring actual RNA Seq data. #dataviz

— @gaurav_bio


Principal Component Analysis

A stalwart of the data scientist toolkit is Principal Component Analysis (PCA), which transforms data into linearly uncorrelated features, or principal components (PCs). Importantly, these new features are meaningful — they indicate the axes in the data along which the variance is greatest. We can reason that variance is where much of the ‘signal’ or information in the data is kept, because it can explain why different groups in the data differ, or why an outcome predicted from the data varies. Variance is vital.

The first principal component, or PC1, ‘explains the most variance’ in the original dataset — this also means that features that correlate with PC1 contribute to a large amount of variance in the data.

As you go from PC1 to the final PC, the amount of variance explained decreases while cumulative explained variance approaches 100%. If you wish to reduce the number of dimensions for analysis, you can choose a specific number of PCs based on how much variance is explained (e.g. 95%). You can also understand how much variance is explained in just two or three easy-to-visualize dimensions.
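As an aside, scikit-learn can make this cut for you: passing a float between 0 and 1 as n_components keeps however many PCs are needed to reach that fraction of explained variance. A minimal sketch, assuming X is your (rescaled) data matrix:

from sklearn.decomposition import PCA

# retain as many PCs as needed to explain 95% of the variance
pca = PCA(n_components=0.95)
X_reduced = pca.fit_transform(X)

# number of PCs that were retained
print(pca.n_components_)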

When analyzing data with a lot of variables, I like to understand how many and which features are useful. I work in genomics, where you can easily have thousands of features per instance, and understanding how many features I need makes it easier to train and interpret models on that data. In genomics, features are cheap and instances are precious. For this reason, I often use PCA as a guide to selecting which features to focus on.

Obligatory figure showing (left) random data in two dimensions and (right) PC1 and PC2 of the data. Black lines indicate the axes along which the most variance exists. Note that these axes are orthogonal. In PCA space (right), the axes are vectors along PC1 and PC2.

Interactive PCA

The first visualization I want to walk through helps me quickly unpack a PCA analysis and share it with others. Specifically, it plots the explained variance per PC as well as the cumulative explained variance as you go from PC1 to the final PC.

Here’s an example output, generated with the classic iris dataset:


Screenshot of an interactive visualization for PCA. Seagreen bars indicate explained variance for that PC. The orange line shows the cumulative variance over N PCs. The bold dashed line indicates the number of PCs at which explained variance exceeds 95% (or 0.95). The tool also includes dashed lines for N=2 and N=3, to quickly understand how much variance is explained within easy-to-visualize dimensions.

To create this tool, I used scikit-learn for the PCA analysis and bokeh for interactive visualization.

Performing a PCA analysis with scikit-learn is a snap:

from sklearn.decomposition import PCA

# create 2D PCA space from original data (X)
pca = PCA(2)
X_pca = pca.fit_transform(X)

From pca, we can unpack the explained variance per PC as a ratio from 0 to 1 and calculate the cumulative explained variance from PC1 to the final PC:

import numpy as np

# explained variance per PC, as a ratio from 0 to 1
pca_evr = pca.explained_variance_ratio_

# cumulative explained variance from PC1 to the final PC
cumsum_ = np.cumsum(pca_evr)

Finally, we can find the first PC at which >95% of the variance in the data is explained, and the explained variance ratio for the first 2 and 3 components:

# PC where >95% variance explained
np.argmax(cumsum_ >= 0.95) + 1

# Variance explained with first 2 PCs
cumsum_[1]

# Variance explained with first 3 PCs
cumsum_[2]

With these variables, we have the necessary information to generate the above plot in bokeh.

For this post, I won’t get too deep into how bokeh works (I’ll save that for a future post) — instead, I’ll highlight the most critical components of the code provided on GitHub.

The crux of using bokeh is understanding the ColumnDataSource. This object allows you to package up data and labels together for use in various bokeh plots. For a 2D lineplot, for example, you can create a ColumnDataSource with a simple dictionary:

{
    'x': x_values,
    'y': y_values,
    'labels': labels
}

In this dictionary, the values are all lists or arrays of the same length. The dictionary can also be passed directly to ColumnDataSource:

ColumnDataSource(data=dict(x=x_values, y=y_values, labels=labels))

The ColumnDataSource is a useful intermediate step because it can then be reused for multiple plots. In the case of a line plot where we want to create a circle at each point, we can create one ColumnDataSource and reuse it for both plots:

from bokeh.plotting import figure, ColumnDataSource, show

# create ColumnDataSource
cds = ColumnDataSource(data=dict(x=x_values, y=y_values, labels=labels))

# instantiate plot
p = figure(title='PCA Analysis')

# line plot - 'x' and 'y' refer to keys in the source (cds)
p.line('x', 'y', line_width=1, color='#F79737', source=cds)

# circle plot
p.circle('x', 'y', size=10, color='#FF4C00', source=cds)

# show plot
show(p)

The code repository goes into greater detail, including customizing the bokeh toolbar, adding labels when hovering over data points, and more.
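As a taste of the hover labels, a tooltip can be attached in a couple of lines (a sketch, assuming the ColumnDataSource above with its 'labels' field):

from bokeh.models import HoverTool

# '@labels' and '@y' refer to columns in the plot's ColumnDataSource
p.add_tools(HoverTool(tooltips=[('PC', '@labels'), ('value', '@y')]))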

Understanding How Features Contribute to PCs

One benefit of transforming data to ‘PCA-space’ is that it allows you to identify the features that are most informative. The previous tool helps you better understand the overall structure of the data, but doesn’t reveal which features contribute to the PCs with the most information.

One method for understanding which features are ‘important’ is to examine how each feature contributes to each principal component. To do this, we can take the dot product of our original data and our principal components.

The dot product of two vectors is the product of their magnitudes and the cosine of the angle between them. When two vectors are orthogonal, the dot product is zero, since cos(90°) = 0. When two vectors lie along the same axis, the dot product is the product of their magnitudes.

Assuming our data is rescaled, the relative magnitudes of its dot product with the principal components will indicate the co-linearity or correlation of individual features and PCs. In other words, if a feature is nearly co-linear with a PC, the magnitude of the dot product will be relatively large.

The dot product we take is between the transpose of our original dataset and its PCA transform, which yields an N×N matrix (where N is the number of features/PCs).

The dot product of the transpose of the original dataset (X_T) and its PCA (X_pca) generates a new square matrix that informs how each feature correlates to each principal component.
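In code, both the dot product and the column-wise Z-normalization mentioned below take only a few lines (a sketch, assuming X has already been rescaled and X_pca comes from the PCA fit above):

import numpy as np

# X: rescaled data, shape (n_samples, n_features)
# X_pca: data projected onto the PCs, shape (n_samples, n_pcs)
corr = np.dot(X.T, X_pca)

# Z-normalize each column so contributions are comparable across PCs
corr_z = (corr - corr.mean(axis=0)) / corr.std(axis=0)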

For the iris dataset, this results in the following heatmap (for easier visualization, I also Z-normalize each column in the resulting matrix):

Z-normalized correlation matrix between the original iris dataset and PC-space.

Above, we can see that petal length and sepal width correlate strongly with PC1.

The feedback I received on this heatmap was that it took a while to interpret. Co-workers suggested a visualization that would make it easier to immediately understand which features helped explain the most variance in the dataset. For iris, you could interrogate each cell one-by-one, but at scale we’ll need a better visualization.

To address this, I decided to re-normalize the heatmap with the explained variance of each PC. I essentially took the dot product of our new matrix and the explained variance as a vector. This would immediately reveal not only how each feature correlates with each PC, but how they contribute to the variance in the dataset.
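As a sketch, this re-weighting amounts to scaling each PC's column in corr_z (from the snippet above) by its explained variance ratio, which is equivalent to multiplying by a diagonal matrix of the ratios:

# scale each PC's column by the fraction of variance that PC explains
weighted = corr_z * pca.explained_variance_ratio_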

For iris, we get the following heatmap, which I think is much easier to grok:

Z-normalized correlation matrix between the original iris dataset and PC-space, normalized by explained variance. The significance of petal length and sepal width is easily seen.

For datasets where feature selection is critical, this visualization (and the stack-rank of features that contribute to the PCA) immediately reveals which features to focus on.
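That stack-rank can be read straight off the weighted matrix; here's one way to produce it (a sketch, where feature_names is a hypothetical list of column names for X):

# rank features by their total absolute variance-weighted contribution
scores = np.abs(weighted).sum(axis=1)
ranked = np.argsort(scores)[::-1]

for i in ranked[:10]:
    print(feature_names[i], round(scores[i], 3))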

Boston Housing

To me, it’s important that tools are reusable and scale properly with different datasets. To demonstrate how the feature importance visualizations scale with (modestly) larger datasets, let’s take a look at the Boston Housing dataset.

Here’s how each feature in Boston Housing correlates with each principal component:

Z-normalized correlation matrix between the original Boston housing dataset and PC-space.

Once again, we see which features contribute to each PC, but beyond the first few PCs the plot is harder to interpret.

Now let’s look at the version that’s normalized by explained variance:

Z-normalized correlation matrix between the original Boston housing dataset and PC-space, normalized by explained variance.

With this version, the picture becomes much clearer: TAX and B contribute the most variance, followed by ZN and perhaps AGE.

The Cancer Genome Atlas Breast Cancer Dataset

The Cancer Genome Atlas (TCGA) breast cancer RNA-Seq dataset (I’m using an old freeze from 2015) has 20,532 features (genes for which expression is measured) but only 204 samples of either a primary tumor or normal tissue. This makes for a very fun machine learning problem — can we build a reliable model that uses gene expression to distinguish normal tissue from cancer tissue?

If we try to train a model on all 20,532 features, we’re going to run into many issues. Instead, let’s take a look at whether we can generate some interpretable principal components for visualization and model building.

First, let’s look at the interactive PCA plot with these data:

PCA on the TCGA-BRCA RNA-Seq dataset. It takes only 3 dimensions to explain ~1/3 of the variance in the data.

From this, we see that it takes ‘only’ 129 principal components to explain 95% of the variance. Importantly, 33.0% of the variance is explained by just three principal components.

Let’s dive deeper. We can also use the heatmaps from before to inspect the data — but have you ever tried to create a 20,532 x 20,532 dimensional heatmap? This kills the kernel. Instead, we can modify the code to cap the number of features/PCs for visualization at a reasonable number (a future, more elegant solution might only generate a heatmap for features/PCs that explain up to a certain amount of variance, say 95%).
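A crude version of that cap might look like this (a sketch; MAX_DIM is an illustrative name, and weighted is the variance-weighted matrix from earlier):

# keep only the top-scoring features and the first MAX_DIM PCs
MAX_DIM = 20
keep = np.argsort(np.abs(weighted).sum(axis=1))[::-1][:MAX_DIM]
heatmap_data = weighted[keep, :MAX_DIM]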

Below is a heatmap showing correlation between the ‘top 20’ features and principal components:

Z-normalized correlation matrix between the original TCGA-BRCA RNA-seq dataset and PC-space, normalized by explained variance. PC1 and PC2 look to be the most meaningful.

In addition, the code gives us a printout of the ‘top’ features, for easier inspection:

Stack rank of 20 features that contribute to PCs that explain the most variance in the TCGA-BRCA RNA-Seq dataset. You see a number of extracellular matrix and basement membrane proteins, which makes sense because ECM remodeling is often observed in breast cancer.

From these, we see that PC1 has strong contributions from four extracellular matrix proteins and PC2 has contributions from the gene EEF1A1. From a molecular biology standpoint, this makes sense — extracellular matrix remodeling is frequently seen in breast cancer [1, 2], and EEF1A1 expression is associated with poor outcomes in certain breast cancer patients [3].

Now let’s step back — we seem to see strong signal in PC1 and PC2. And from the first plot, we see that they explain ~25% of the variance in the data. So what does it look like when we plot PC1 vs. PC2?

2D PCA plot of the TCGA-BRCA RNA-Seq data, showing Primary Tumor samples (red) and Solid Tissue Normal samples (blue).

The above plot shows all Primary Tumor (red) and Solid Tissue Normal samples (blue) along PC1 and PC2. You can almost draw a line with the naked eye between both groups — which strongly suggests that we can successfully train a classifier with just two principal components using a simple linear model.

That’s pretty incredible when you think about it. From over 20,000 genes, we can define two linear, uncorrelated features that explain enough variance in the data to allow us to differentiate between two groups of interest. Further, we already have some indication of genes that help discriminate normal tissue from breast cancer tissue, and could have value for prognosis, diagnosis, or as therapeutic targets.
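To make that concrete, a quick sanity check might look like the following (a sketch, assuming y holds the tumor/normal labels and X_pca comes from the fit above):

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

# fit a simple linear classifier on PC1 and PC2 alone
clf = LogisticRegression()
print(cross_val_score(clf, X_pca[:, :2], y, cv=5).mean())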

I hope this first post has been helpful! In Part 2, I’ll shift to interpreting machine learning models and decision boundaries. Keep an eye on Twitter or GitHub for when it drops!

Additional Resources

If you’d like to learn more about principal component analysis and related topics, I suggest the following articles and videos.