1"""
2=======================================
3Visualizing the stock market structure
4=======================================
5
6This example employs several unsupervised learning techniques to extract
7the stock market structure from variations in historical quotes.
8
9The quantity that we use is the daily variation in quote price: quotes
10that are linked tend to cofluctuate during a day.
11
12.. _stock_market:
13
14Learning a graph structure
15--------------------------
16
17We use sparse inverse covariance estimation to find which quotes are
18correlated conditionally on the others. Specifically, sparse inverse
covariance gives us a graph, that is, a list of connections. For each
symbol, the symbols that it is connected to are those useful to explain
21its fluctuations.
22
23Clustering
24----------
25
26We use clustering to group together quotes that behave similarly. Here,
amongst the :ref:`various clustering techniques <clustering>` available
in scikit-learn, we use :ref:`affinity_propagation` as it does
29not enforce equal-size clusters, and it can choose automatically the
30number of clusters from the data.
31
32Note that this gives us a different indication than the graph, as the
33graph reflects conditional relations between variables, while the
34clustering reflects marginal properties: variables clustered together can
35be considered as having a similar impact at the level of the full stock
36market.
37
38Embedding in 2D space
39---------------------
40
41For visualization purposes, we need to lay out the different symbols on a
422D canvas. For this we use :ref:`manifold` techniques to retrieve 2D
43embedding.
44
45
46Visualization
47-------------
48
The output of the 3 models is combined in a 2D graph where the nodes
represent the stocks and the edges represent the links between them:
51
52- cluster labels are used to define the color of the nodes
53- the sparse covariance model is used to display the strength of the edges
- the 2D embedding is used to position the nodes in the plane
55
This example has a fair amount of visualization-related code, as
visualization is crucial here to display the graph. One of the
challenges is to position the labels while minimizing overlap. For this
we use a heuristic based on the direction of the nearest neighbor along
each axis.
61
62"""
63
64# Author: Gael Varoquaux gael.varoquaux@normalesup.org
65# License: BSD 3 clause
66
67import sys
68
69import numpy as np
70import matplotlib.pyplot as plt
71from matplotlib.collections import LineCollection
72
73import pandas as pd
74
75from sklearn import cluster, covariance, manifold
76
77
78# #############################################################################
79# Retrieve the data from Internet
80
81# The data is from 2003 - 2008. This is reasonably calm: (not too long ago so
82# that we get high-tech firms, and before the 2008 crash). This kind of
# historical data can be obtained from APIs like quandl.com and
84# alphavantage.co ones.
85
# Mapping from stock ticker symbol to human-readable company name.
# The symbols are used to fetch the historical quote CSVs below, and the
# names are used to label the nodes in the final plot. The sorted order of
# these keys fixes the row order of every data matrix built from the quotes.
symbol_dict = {
    "TOT": "Total",
    "XOM": "Exxon",
    "CVX": "Chevron",
    "COP": "ConocoPhillips",
    "VLO": "Valero Energy",
    "MSFT": "Microsoft",
    "IBM": "IBM",
    "TWX": "Time Warner",
    "CMCSA": "Comcast",
    "CVC": "Cablevision",
    "YHOO": "Yahoo",
    "DELL": "Dell",
    "HPQ": "HP",
    "AMZN": "Amazon",
    "TM": "Toyota",
    "CAJ": "Canon",
    "SNE": "Sony",
    "F": "Ford",
    "HMC": "Honda",
    "NAV": "Navistar",
    "NOC": "Northrop Grumman",
    "BA": "Boeing",
    "KO": "Coca Cola",
    "MMM": "3M",
    "MCD": "McDonald's",
    "PEP": "Pepsi",
    "K": "Kellogg",
    "UN": "Unilever",
    "MAR": "Marriott",
    "PG": "Procter Gamble",
    "CL": "Colgate-Palmolive",
    "GE": "General Electrics",
    "WFC": "Wells Fargo",
    "JPM": "JPMorgan Chase",
    "AIG": "AIG",
    "AXP": "American express",
    "BAC": "Bank of America",
    "GS": "Goldman Sachs",
    "AAPL": "Apple",
    "SAP": "SAP",
    "CSCO": "Cisco",
    "TXN": "Texas Instruments",
    "XRX": "Xerox",
    "WMT": "Wal-Mart",
    "HD": "Home Depot",
    "GSK": "GlaxoSmithKline",
    "PFE": "Pfizer",
    "SNY": "Sanofi-Aventis",
    "NVS": "Novartis",
    "KMB": "Kimberly-Clark",
    "R": "Ryder",
    "GD": "General Dynamics",
    "RTN": "Raytheon",
    "CVS": "CVS",
    "CAT": "Caterpillar",
    "DD": "DuPont de Nemours",
}
144
145
# Sort the symbols so the row order of the data matrices is deterministic;
# `symbols` and `names` are parallel 1D arrays of tickers and company names.
symbols, names = np.array(sorted(symbol_dict.items())).T

# The URL template is loop-invariant: build it once instead of re-creating
# the same string object on every iteration.
url_template = (
    "https://raw.githubusercontent.com/scikit-learn/examples-data/"
    "master/financial-data/{}.csv"
)

# Download one CSV of historical quotes per symbol (columns include
# "open" and "close" daily prices).
quotes = []
for symbol in symbols:
    print("Fetching quote history for %r" % symbol, file=sys.stderr)
    quotes.append(pd.read_csv(url_template.format(symbol)))

# Stack per-symbol series into (n_symbols, n_days) matrices.
close_prices = np.vstack([q["close"] for q in quotes])
open_prices = np.vstack([q["open"] for q in quotes])

# The daily variations of the quotes are what carry most information
variation = close_prices - open_prices
163
164
165# #############################################################################
166# Learn a graphical structure from the correlations
167edge_model = covariance.GraphicalLassoCV()
168
169# standardize the time series: using correlations rather than covariance
170# is more efficient for structure recovery
171X = variation.copy().T
172X /= X.std(axis=0)
173edge_model.fit(X)
174
175# #############################################################################
176# Cluster using affinity propagation
177
178_, labels = cluster.affinity_propagation(edge_model.covariance_, random_state=0)
179n_labels = labels.max()
180
181for i in range(n_labels + 1):
182    print("Cluster %i: %s" % ((i + 1), ", ".join(names[labels == i])))
183
184# #############################################################################
185# Find a low-dimension embedding for visualization: find the best position of
186# the nodes (the stocks) on a 2D plane
187
188# We use a dense eigen_solver to achieve reproducibility (arpack is
189# initiated with random vectors that we don't control). In addition, we
190# use a large number of neighbors to capture the large-scale structure.
191node_position_model = manifold.LocallyLinearEmbedding(
192    n_components=2, eigen_solver="dense", n_neighbors=6
193)
194
195embedding = node_position_model.fit_transform(X.T).T
196
197# #############################################################################
198# Visualization
199plt.figure(1, facecolor="w", figsize=(10, 8))
200plt.clf()
201ax = plt.axes([0.0, 0.0, 1.0, 1.0])
202plt.axis("off")
203
204# Display a graph of the partial correlations
205partial_correlations = edge_model.precision_.copy()
206d = 1 / np.sqrt(np.diag(partial_correlations))
207partial_correlations *= d
208partial_correlations *= d[:, np.newaxis]
209non_zero = np.abs(np.triu(partial_correlations, k=1)) > 0.02
210
211# Plot the nodes using the coordinates of our embedding
212plt.scatter(
213    embedding[0], embedding[1], s=100 * d ** 2, c=labels, cmap=plt.cm.nipy_spectral
214)
215
216# Plot the edges
217start_idx, end_idx = np.where(non_zero)
218# a sequence of (*line0*, *line1*, *line2*), where::
219#            linen = (x0, y0), (x1, y1), ... (xm, ym)
220segments = [
221    [embedding[:, start], embedding[:, stop]] for start, stop in zip(start_idx, end_idx)
222]
223values = np.abs(partial_correlations[non_zero])
224lc = LineCollection(
225    segments, zorder=0, cmap=plt.cm.hot_r, norm=plt.Normalize(0, 0.7 * values.max())
226)
227lc.set_array(values)
228lc.set_linewidths(15 * values)
229ax.add_collection(lc)
230
# Add a label to each node. The challenge here is that we want to
# position the labels to avoid overlap with other labels.
# Heuristic: for each node, look at the direction of its nearest neighbor
# along each axis and push the label the opposite way.
for index, (name, label, (x, y)) in enumerate(zip(names, labels, embedding.T)):

    # Offsets of every other node relative to this one; the node's own
    # offset is set to 1 so it can never be its own "nearest neighbor".
    dx = x - embedding[0]
    dx[index] = 1
    dy = y - embedding[1]
    dy[index] = 1
    # x-offset of the vertically nearest node, and y-offset of the
    # horizontally nearest node: these decide which side is "free".
    this_dx = dx[np.argmin(np.abs(dy))]
    this_dy = dy[np.argmin(np.abs(dx))]
    # Place the label on the side away from the nearest neighbor, with a
    # small nudge so the text does not sit on top of the marker.
    if this_dx > 0:
        horizontalalignment = "left"
        x = x + 0.002
    else:
        horizontalalignment = "right"
        x = x - 0.002
    if this_dy > 0:
        verticalalignment = "bottom"
        y = y + 0.002
    else:
        verticalalignment = "top"
        y = y - 0.002
    # Semi-transparent box colored by cluster, matching the node colors.
    plt.text(
        x,
        y,
        name,
        size=10,
        horizontalalignment=horizontalalignment,
        verticalalignment=verticalalignment,
        bbox=dict(
            facecolor="w",
            edgecolor=plt.cm.nipy_spectral(label / float(n_labels)),
            alpha=0.6,
        ),
    )
266
# Set the plot limits with padding proportional to the point-cloud extent.
# Extra room on the left (15%) leaves space for right-aligned labels.
# np.ptp(...) is used instead of the ndarray.ptp() method, which was
# removed in NumPy 2.0.
plt.xlim(
    embedding[0].min() - 0.15 * np.ptp(embedding[0]),
    embedding[0].max() + 0.10 * np.ptp(embedding[0]),
)
plt.ylim(
    embedding[1].min() - 0.03 * np.ptp(embedding[1]),
    embedding[1].max() + 0.03 * np.ptp(embedding[1]),
)

plt.show()
277