Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/statsmodels/graphics/factorplots.py : 10%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- coding: utf-8 -*-
2"""
3Authors: Josef Perktold, Skipper Seabold, Denis A. Engemann
4"""
5from statsmodels.compat.python import iterkeys, lrange, iteritems
6import numpy as np
8from statsmodels.graphics.plottools import rainbow
9import statsmodels.graphics.utils as utils
def interaction_plot(x, trace, response, func=np.mean, ax=None, plottype='b',
                     xlabel=None, ylabel=None, colors=None, markers=None,
                     linestyles=None, legendloc='best', legendtitle=None,
                     **kwargs):
    """
    Interaction plot for factor level statistics.

    Note. If categorical factors are supplied levels will be internally
    recoded to integers. This ensures matplotlib compatibility. Uses
    a DataFrame to calculate an `aggregate` statistic for each level of the
    factor or group given by `trace`.

    Parameters
    ----------
    x : array_like
        The `x` factor levels constitute the x-axis. If a named
        `pandas.Series` is given its name will be used in `xlabel` if
        `xlabel` is None. String levels are recoded to integer positions
        and shown as tick labels.
    trace : array_like
        The `trace` factor levels will be drawn as lines in the plot.
        If `trace` is a named `pandas.Series` its name will be used as the
        `legendtitle` if `legendtitle` is None.
    response : array_like
        The response or dependent variable. If a named `pandas.Series` is
        given its name will be used in the y-axis label if `ylabel` is None.
    func : function or str
        Anything accepted by `pandas.DataFrame.aggregate`. This is applied to
        the response variable grouped by the trace and `x` levels. Its
        ``__name__`` (or, e.g. for ``'mean'``, its string form) appears in
        the y-axis label.
    ax : axes, optional
        Matplotlib axes instance
    plottype : str {'line', 'scatter', 'both'}, optional
        The type of plot to return. Can be 'l', 's', or 'b'
    xlabel : str, optional
        Label to use for `x`. Default is 'X'. If `x` is a named
        `pandas.Series` it will use the series name.
    ylabel : str, optional
        Name used for `response` in the y-axis label, which always reads
        '<func name> of <ylabel>'. Defaults to the name of `response` if it
        is a named `pandas.Series`, otherwise 'response'.
    colors : list, optional
        If given, must have length == number of levels in trace.
    markers : list, optional
        If given, must have length == number of levels in trace
    linestyles : list, optional
        If given, must have length == number of levels in trace.
    legendloc : {None, str, int}
        Location passed to the legend command.
    legendtitle : {None, str}
        Title of the legend.
    **kwargs
        These will be passed to the underlying ``ax.plot`` or ``ax.scatter``
        call. If you want to control the overall plotting options, use
        kwargs.

    Returns
    -------
    Figure
        The figure given by `ax.figure` or a new instance.

    Raises
    ------
    ValueError
        If `colors`, `markers` or `linestyles` do not have one entry per
        trace level, or if `plottype` is not recognised.

    Examples
    --------
    >>> import numpy as np
    >>> np.random.seed(12345)
    >>> weight = np.random.randint(1,4,size=60)
    >>> duration = np.random.randint(1,3,size=60)
    >>> days = np.log(np.random.randint(1,30, size=60))
    >>> fig = interaction_plot(weight, duration, days,
    ...             colors=['red','blue'], markers=['D','^'], ms=10)
    >>> import matplotlib.pyplot as plt
    >>> plt.show()

    .. plot::

       import numpy as np
       from statsmodels.graphics.factorplots import interaction_plot
       np.random.seed(12345)
       weight = np.random.randint(1,4,size=60)
       duration = np.random.randint(1,3,size=60)
       days = np.log(np.random.randint(1,30, size=60))
       fig = interaction_plot(weight, duration, days,
                   colors=['red','blue'], markers=['D','^'], ms=10)
       import matplotlib.pyplot as plt
       #plt.show()
    """
    from pandas import DataFrame
    fig, ax = utils.create_mpl_ax(ax)

    def _name_or(obj, default):
        # Unnamed pandas objects have ``name is None``; fall back to the
        # default rather than rendering the literal text "None".
        name = getattr(obj, 'name', None)
        return default if name is None else name

    response_name = ylabel or _name_or(response, 'response')
    # A string aggregator such as 'mean' has no __name__.
    func_name = getattr(func, '__name__', str(func))
    ylabel = '%s of %s' % (func_name, response_name)
    xlabel = xlabel or _name_or(x, 'X')
    legendtitle = legendtitle or _name_or(trace, 'Trace')

    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    x_values = x_levels = None
    # Inspect the first element positionally: ``x[0]`` is a label lookup on
    # a pandas Series and fails when its index does not contain 0.
    if isinstance(np.asarray(x)[0], str):
        if not hasattr(x, 'dtype'):
            # Plain lists/tuples need an array so they can be recoded.
            x = np.asarray(x)
        x_levels = list(np.unique(x))
        x_values = list(range(len(x_levels)))
        x = _recode(x, dict(zip(x_levels, x_values)))

    data = DataFrame(dict(x=x, trace=trace, response=response))
    plot_data = data.groupby(['trace', 'x']).aggregate(func).reset_index()

    # check plot args: one style entry per trace level
    n_trace = len(plot_data['trace'].unique())

    linestyles = ['-'] * n_trace if linestyles is None else linestyles
    markers = ['.'] * n_trace if markers is None else markers
    colors = rainbow(n_trace) if colors is None else colors

    if len(linestyles) != n_trace:
        raise ValueError("Must be a linestyle for each trace level")
    if len(markers) != n_trace:
        raise ValueError("Must be a marker for each trace level")
    if len(colors) != n_trace:
        raise ValueError("Must be a color for each trace level")

    if plottype in ('both', 'b'):
        kind = 'both'
    elif plottype in ('line', 'l'):
        kind = 'line'
    elif plottype in ('scatter', 's'):
        kind = 'scatter'
    else:
        raise ValueError("Plot type %s not understood" % plottype)

    # Group by the scalar column name (not a 1-element list) so iteration
    # does not trigger pandas' length-1-grouper FutureWarning.
    for i, (_, group) in enumerate(plot_data.groupby('trace')):
        # trace label
        label = str(group['trace'].values[0])
        if kind == 'scatter':
            ax.scatter(group['x'], group['response'], color=colors[i],
                       label=label, marker=markers[i], **kwargs)
        else:
            style = dict(color=colors[i], label=label,
                         linestyle=linestyles[i])
            if kind == 'both':
                style['marker'] = markers[i]
            ax.plot(group['x'], group['response'], **style, **kwargs)

    ax.legend(loc=legendloc, title=legendtitle)
    ax.margins(.1)

    if x_levels is not None:
        # Show the original string levels at their recoded integer positions.
        ax.set_xticks(x_values)
        ax.set_xticklabels(x_levels)
    return fig
def _recode(x, levels):
    """ Recode categorical data to an integer factor.

    Parameters
    ----------
    x : array_like
        Categorically coded data of string or object dtype. A
        `pandas.Series`, numpy array, list or tuple is accepted.
    levels : dict
        Mapping of every label occurring in `x` to its integer coding.

    Returns
    -------
    out : numpy.ndarray or pandas.Series
        Integer codes, one per element of `x`. If `x` is a `pandas.Series`
        a Series with the same name and index is returned, so it stays
        aligned with other indexed inputs.

    Raises
    ------
    ValueError
        If `x` is not of string/object dtype, if `levels` is not a dict, or
        if the distinct labels of `x` differ from the keys of `levels`.
    """
    from pandas import Series

    is_series = isinstance(x, Series)
    name = x.name if is_series else None
    index = x.index if is_series else None
    # np.asarray also accepts plain lists/tuples, which have no ``.dtype``.
    x = np.asarray(x)

    if x.dtype.type not in (np.str_, np.object_):
        raise ValueError('This is not a categorial factor.'
                         ' Array of str type required.')
    if not isinstance(levels, dict):
        raise ValueError('This is not a valid value for levels.'
                         ' Dict required.')
    # array_equal compares shapes first; an elementwise ``==`` between level
    # sets of different sizes would broadcast or raise a shape error.
    if not np.array_equal(np.unique(x), np.unique(list(levels))):
        raise ValueError('The levels do not match the array values.')

    # Builtin ``int`` is what the removed ``np.int`` alias pointed to.
    out = np.empty(x.shape[0], dtype=int)
    for level, coding in levels.items():
        out[x == level] = coding

    if is_series:
        out = Series(out, name=name, index=index)
    return out