# pandas/_testing.py
import bz2
from collections import Counter
from contextlib import contextmanager
from datetime import datetime
from functools import wraps
import gzip
import os
from shutil import rmtree
import string
import tempfile
from typing import Any, List, Optional, Union, cast
import warnings
import zipfile

import numpy as np
from numpy.random import rand, randn

from pandas._config.localization import (  # noqa:F401
    can_set_locale,
    get_locales,
    set_locale,
)

import pandas._libs.testing as _testing
from pandas._typing import FilePathOrBuffer, FrameOrSeries
from pandas.compat import _get_lzma_file, _import_lzma

from pandas.core.dtypes.common import (
    is_bool,
    is_categorical_dtype,
    is_datetime64_dtype,
    is_datetime64tz_dtype,
    is_extension_array_dtype,
    is_interval_dtype,
    is_list_like,
    is_number,
    is_period_dtype,
    is_sequence,
    is_timedelta64_dtype,
    needs_i8_conversion,
)
from pandas.core.dtypes.missing import array_equivalent

import pandas as pd
from pandas import (
    Categorical,
    CategoricalIndex,
    DataFrame,
    DatetimeIndex,
    Index,
    IntervalIndex,
    MultiIndex,
    RangeIndex,
    Series,
    bdate_range,
)
from pandas.core.algorithms import take_1d
from pandas.core.arrays import (
    DatetimeArray,
    ExtensionArray,
    IntervalArray,
    PeriodArray,
    TimedeltaArray,
    period_array,
)

from pandas.io.common import urlopen
from pandas.io.formats.printing import pprint_thing

lzma = _import_lzma()

N = 30
K = 4
_RAISE_NETWORK_ERROR_DEFAULT = False

# set testing_mode
_testing_mode_warnings = (DeprecationWarning, ResourceWarning)


def set_testing_mode():
    # set the testing mode filters
    testing_mode = os.environ.get("PANDAS_TESTING_MODE", "None")
    if "deprecate" in testing_mode:
        warnings.simplefilter("always", _testing_mode_warnings)


def reset_testing_mode():
    # reset the testing mode filters
    testing_mode = os.environ.get("PANDAS_TESTING_MODE", "None")
    if "deprecate" in testing_mode:
        warnings.simplefilter("ignore", _testing_mode_warnings)


set_testing_mode()


def reset_display_options():
    """
    Reset the display options for printing and representing objects.
    """
    pd.reset_option("^display.", silent=True)


def round_trip_pickle(
    obj: Any, path: Optional[FilePathOrBuffer] = None
) -> FrameOrSeries:
    """
    Pickle an object and then read it again.

    Parameters
    ----------
    obj : any object
        The object to pickle and then re-read.
    path : str, path object or file-like object, default None
        The path where the pickled object is written and then read.

    Returns
    -------
    pandas object
        The original object that was pickled and then re-read.
    """
    _path = path
    if _path is None:
        _path = f"__{rands(10)}__.pickle"
    with ensure_clean(_path) as temp_path:
        pd.to_pickle(obj, temp_path)
        return pd.read_pickle(temp_path)
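

def _example_round_trip_pickle():
    # A minimal usage sketch (hypothetical values, not part of the original
    # module): pickle a small Series to a temporary file and confirm the
    # round trip preserves it.
    ser = Series(range(5), name="demo")
    result = round_trip_pickle(ser)
    assert_series_equal(result, ser)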


def round_trip_pathlib(writer, reader, path: Optional[str] = None):
    """
    Write an object to file specified by a pathlib.Path and read it back.

    Parameters
    ----------
    writer : callable bound to pandas object
        IO writing function (e.g. DataFrame.to_csv)
    reader : callable
        IO reading function (e.g. pd.read_csv)
    path : str, default None
        The path where the object is written and then read.

    Returns
    -------
    pandas object
        The original object that was serialized and then re-read.
    """
    import pytest

    Path = pytest.importorskip("pathlib").Path
    if path is None:
        path = "___pathlib___"
    with ensure_clean(path) as path:
        writer(Path(path))
        obj = reader(Path(path))
    return obj


def round_trip_localpath(writer, reader, path: Optional[str] = None):
    """
    Write an object to file specified by a py.path LocalPath and read it back.

    Parameters
    ----------
    writer : callable bound to pandas object
        IO writing function (e.g. DataFrame.to_csv)
    reader : callable
        IO reading function (e.g. pd.read_csv)
    path : str, default None
        The path where the object is written and then read.

    Returns
    -------
    pandas object
        The original object that was serialized and then re-read.
    """
    import pytest

    LocalPath = pytest.importorskip("py.path").local
    if path is None:
        path = "___localpath___"
    with ensure_clean(path) as path:
        writer(LocalPath(path))
        obj = reader(LocalPath(path))
    return obj


@contextmanager
def decompress_file(path, compression):
    """
    Open a compressed file and return a file object.

    Parameters
    ----------
    path : str
        The path where the file is read from.

    compression : {'gzip', 'bz2', 'zip', 'xz', None}
        Name of the decompression protocol to use.

    Returns
    -------
    file object
    """
    if compression is None:
        f = open(path, "rb")
    elif compression == "gzip":
        f = gzip.open(path, "rb")
    elif compression == "bz2":
        f = bz2.BZ2File(path, "rb")
    elif compression == "xz":
        f = _get_lzma_file(lzma)(path, "rb")
    elif compression == "zip":
        zip_file = zipfile.ZipFile(path)
        zip_names = zip_file.namelist()
        if len(zip_names) == 1:
            f = zip_file.open(zip_names.pop())
        else:
            raise ValueError(f"ZIP file {path} error. Only one file per ZIP.")
    else:
        raise ValueError(f"Unrecognized compression type: {compression}")

    try:
        yield f
    finally:
        f.close()
        if compression == "zip":
            zip_file.close()


def write_to_compressed(compression, path, data, dest="test"):
    """
    Write data to a compressed file.

    Parameters
    ----------
    compression : {'gzip', 'bz2', 'zip', 'xz'}
        The compression type to use.
    path : str
        The file path to write the data.
    data : str
        The data to write.
    dest : str, default "test"
        The destination file (for ZIP only)

    Raises
    ------
    ValueError : An invalid compression value was passed in.
    """
    if compression == "zip":
        import zipfile

        compress_method = zipfile.ZipFile
    elif compression == "gzip":
        import gzip

        compress_method = gzip.GzipFile
    elif compression == "bz2":
        import bz2

        compress_method = bz2.BZ2File
    elif compression == "xz":
        compress_method = _get_lzma_file(lzma)
    else:
        raise ValueError(f"Unrecognized compression type: {compression}")

    if compression == "zip":
        mode = "w"
        args = (dest, data)
        method = "writestr"
    else:
        mode = "wb"
        args = (data,)
        method = "write"

    with compress_method(path, mode=mode) as f:
        getattr(f, method)(*args)
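

def _example_write_to_compressed():
    # A minimal usage sketch (hypothetical filename, not part of the original
    # module): write gzip-compressed bytes with write_to_compressed, then
    # read them back through decompress_file.
    with ensure_clean("__demo__.gz") as path:
        write_to_compressed("gzip", path, b"pandas testing helpers")
        with decompress_file(path, compression="gzip") as f:
            assert f.read() == b"pandas testing helpers"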


def assert_almost_equal(
    left,
    right,
    check_dtype: Union[bool, str] = "equiv",
    check_less_precise: Union[bool, int] = False,
    **kwargs,
):
    """
    Check that the left and right objects are approximately equal.

    By approximately equal, we refer to objects that are numbers or that
    contain numbers which may be equivalent to specific levels of precision.

    Parameters
    ----------
    left : object
    right : object
    check_dtype : bool or {'equiv'}, default 'equiv'
        Check dtype if both a and b are the same type. If 'equiv' is passed in,
        then `RangeIndex` and `Int64Index` are also considered equivalent
        when doing type checking.
    check_less_precise : bool or int, default False
        Specify comparison precision. 5 digits (False) or 3 digits (True)
        after decimal points are compared. If int, then specify the number
        of digits to compare.

        When comparing two numbers, if the first number has magnitude less
        than 1e-5, we compare the two numbers directly and check whether
        they are equivalent within the specified precision. Otherwise, we
        compare the **ratio** of the second number to the first number and
        check whether it is equivalent to 1 within the specified precision.
    """
    if isinstance(left, pd.Index):
        assert_index_equal(
            left,
            right,
            check_exact=False,
            exact=check_dtype,
            check_less_precise=check_less_precise,
            **kwargs,
        )

    elif isinstance(left, pd.Series):
        assert_series_equal(
            left,
            right,
            check_exact=False,
            check_dtype=check_dtype,
            check_less_precise=check_less_precise,
            **kwargs,
        )

    elif isinstance(left, pd.DataFrame):
        assert_frame_equal(
            left,
            right,
            check_exact=False,
            check_dtype=check_dtype,
            check_less_precise=check_less_precise,
            **kwargs,
        )

    else:
        # Other sequences.
        if check_dtype:
            if is_number(left) and is_number(right):
                # Do not compare numeric classes, like np.float64 and float.
                pass
            elif is_bool(left) and is_bool(right):
                # Do not compare bool classes, like np.bool_ and bool.
                pass
            else:
                if isinstance(left, np.ndarray) or isinstance(right, np.ndarray):
                    obj = "numpy array"
                else:
                    obj = "Input"
                assert_class_equal(left, right, obj=obj)
        _testing.assert_almost_equal(
            left,
            right,
            check_dtype=check_dtype,
            check_less_precise=check_less_precise,
            **kwargs,
        )
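

def _example_assert_almost_equal():
    # A minimal sketch (hypothetical values): floating-point rounding noise
    # passes the approximate check even though exact equality would fail.
    assert (0.1 + 0.2) != 0.3
    assert_almost_equal(0.1 + 0.2, 0.3)
    # The same tolerance applies element-wise once Series are dispatched.
    assert_almost_equal(Series([0.1 + 0.2]), Series([0.3]))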


def _check_isinstance(left, right, cls):
    """
    Helper method for our assert_* methods that ensures that
    the two objects being compared have the right type before
    proceeding with the comparison.

    Parameters
    ----------
    left : The first object being compared.
    right : The second object being compared.
    cls : The class type to check against.

    Raises
    ------
    AssertionError : Either `left` or `right` is not an instance of `cls`.
    """
    cls_name = cls.__name__

    if not isinstance(left, cls):
        raise AssertionError(
            f"{cls_name} Expected type {cls}, found {type(left)} instead"
        )
    if not isinstance(right, cls):
        raise AssertionError(
            f"{cls_name} Expected type {cls}, found {type(right)} instead"
        )


def assert_dict_equal(left, right, compare_keys: bool = True):

    _check_isinstance(left, right, dict)
    _testing.assert_dict_equal(left, right, compare_keys=compare_keys)


def randbool(size=(), p: float = 0.5):
    return rand(*size) <= p


RANDS_CHARS = np.array(list(string.ascii_letters + string.digits), dtype=(np.str_, 1))
RANDU_CHARS = np.array(
    list("".join(map(chr, range(1488, 1488 + 26))) + string.digits),
    dtype=(np.unicode_, 1),
)


def rands_array(nchars, size, dtype="O"):
    """
    Generate an array of random ASCII strings.
    """
    retval = (
        np.random.choice(RANDS_CHARS, size=nchars * np.prod(size))
        .view((np.str_, nchars))
        .reshape(size)
    )
    if dtype is None:
        return retval
    else:
        return retval.astype(dtype)


def randu_array(nchars, size, dtype="O"):
    """
    Generate an array of random unicode strings.
    """
    retval = (
        np.random.choice(RANDU_CHARS, size=nchars * np.prod(size))
        .view((np.unicode_, nchars))
        .reshape(size)
    )
    if dtype is None:
        return retval
    else:
        return retval.astype(dtype)


def rands(nchars):
    """
    Generate one random ASCII string.

    See `rands_array` if you want to create an array of random strings.
    """
    return "".join(np.random.choice(RANDS_CHARS, nchars))


def randu(nchars):
    """
    Generate one random unicode string.

    See `randu_array` if you want to create an array of random unicode strings.
    """
    return "".join(np.random.choice(RANDU_CHARS, nchars))
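

def _example_rands():
    # A small sketch (seed chosen arbitrarily, for illustration only): rands
    # draws from ASCII letters and digits, returning a plain str of the
    # requested length.
    np.random.seed(1234)
    s = rands(10)
    assert isinstance(s, str) and len(s) == 10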


def close(fignum=None):
    from matplotlib.pyplot import get_fignums, close as _close

    if fignum is None:
        for fignum in get_fignums():
            _close(fignum)
    else:
        _close(fignum)


# -----------------------------------------------------------------------------
# contextmanager to ensure the file cleanup


@contextmanager
def ensure_clean(filename=None, return_filelike=False):
    """
    Get a temporary file path and remove it when the context exits.

    Parameters
    ----------
    filename : str (optional)
        If None, creates a temporary file which is then removed when out of
        scope. If passed, creates a temporary file with filename as ending.
    return_filelike : bool (default False)
        If True, returns a file-like which is *always* cleaned. Necessary for
        savefig and other functions which want to append extensions.
    """
    filename = filename or ""
    fd = None

    if return_filelike:
        f = tempfile.TemporaryFile(suffix=filename)
        try:
            yield f
        finally:
            f.close()
    else:
        # don't generate tempfile if using a path with directory specified
        if len(os.path.dirname(filename)):
            raise ValueError("Can't pass a qualified name to ensure_clean()")

        try:
            fd, filename = tempfile.mkstemp(suffix=filename)
        except UnicodeEncodeError:
            import pytest

            pytest.skip("no unicode file names on this system")

        try:
            yield filename
        finally:
            try:
                os.close(fd)
            except OSError:
                print(f"Couldn't close file descriptor: {fd} (file: {filename})")
            try:
                if os.path.exists(filename):
                    os.remove(filename)
            except OSError as e:
                print(f"Exception on removing file: {e}")
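

def _example_ensure_clean():
    # A minimal usage sketch (hypothetical filename): the temporary file
    # exists inside the block and is removed once the block exits.
    with ensure_clean("__demo__.csv") as path:
        DataFrame({"a": [1, 2]}).to_csv(path)
        assert os.path.exists(path)
    assert not os.path.exists(path)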


@contextmanager
def ensure_clean_dir():
    """
    Get a temporary directory path and remove it when the context exits.

    Yields
    ------
    Temporary directory path
    """
    directory_name = tempfile.mkdtemp(suffix="")
    try:
        yield directory_name
    finally:
        try:
            rmtree(directory_name)
        except OSError:
            pass


@contextmanager
def ensure_safe_environment_variables():
    """
    Get a context manager to safely set environment variables.

    All changes will be undone on close, hence environment variables set
    within this contextmanager will neither persist nor change global state.
    """
    saved_environ = dict(os.environ)
    try:
        yield
    finally:
        os.environ.clear()
        os.environ.update(saved_environ)
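

def _example_ensure_safe_environment_variables():
    # A minimal sketch (hypothetical variable name): a variable set inside the
    # block does not leak out, because the saved environment is restored.
    with ensure_safe_environment_variables():
        os.environ["PANDAS_DEMO_FLAG"] = "1"
    # assuming PANDAS_DEMO_FLAG was not set beforehand, it is gone again here
    assert "PANDAS_DEMO_FLAG" not in os.environ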


# -----------------------------------------------------------------------------
# Comparators


def equalContents(arr1, arr2) -> bool:
    """
    Checks if the set of unique elements of arr1 and arr2 are equivalent.
    """
    return frozenset(arr1) == frozenset(arr2)


def assert_index_equal(
    left: Index,
    right: Index,
    exact: Union[bool, str] = "equiv",
    check_names: bool = True,
    check_less_precise: Union[bool, int] = False,
    check_exact: bool = True,
    check_categorical: bool = True,
    obj: str = "Index",
) -> None:
    """
    Check that left and right Index are equal.

    Parameters
    ----------
    left : Index
    right : Index
    exact : bool or {'equiv'}, default 'equiv'
        Whether to check the Index class, dtype and inferred_type
        are identical. If 'equiv', then RangeIndex can be substituted for
        Int64Index as well.
    check_names : bool, default True
        Whether to check the names attribute.
    check_less_precise : bool or int, default False
        Specify comparison precision. Only used when check_exact is False.
        5 digits (False) or 3 digits (True) after decimal points are compared.
        If int, then specify the digits to compare.
    check_exact : bool, default True
        Whether to compare numbers exactly.
    check_categorical : bool, default True
        Whether to compare internal Categorical exactly.
    obj : str, default 'Index'
        Specify object name being compared, internally used to show appropriate
        assertion message.
    """
    __tracebackhide__ = True

    def _check_types(l, r, obj="Index"):
        if exact:
            assert_class_equal(l, r, exact=exact, obj=obj)

            # Skip exact dtype checking when `check_categorical` is False
            if check_categorical:
                assert_attr_equal("dtype", l, r, obj=obj)

            # allow string-like to have different inferred_types
            if l.inferred_type in ("string", "unicode"):
                assert r.inferred_type in ("string", "unicode")
            else:
                assert_attr_equal("inferred_type", l, r, obj=obj)

    def _get_ilevel_values(index, level):
        # accept level number only
        unique = index.levels[level]
        level_codes = index.codes[level]
        filled = take_1d(unique._values, level_codes, fill_value=unique._na_value)
        values = unique._shallow_copy(filled, name=index.names[level])
        return values

    # instance validation
    _check_isinstance(left, right, Index)

    # class / dtype comparison
    _check_types(left, right, obj=obj)

    # level comparison
    if left.nlevels != right.nlevels:
        msg1 = f"{obj} levels are different"
        msg2 = f"{left.nlevels}, {left}"
        msg3 = f"{right.nlevels}, {right}"
        raise_assert_detail(obj, msg1, msg2, msg3)

    # length comparison
    if len(left) != len(right):
        msg1 = f"{obj} lengths are different"
        msg2 = f"{len(left)}, {left}"
        msg3 = f"{len(right)}, {right}"
        raise_assert_detail(obj, msg1, msg2, msg3)

    # MultiIndex special comparison for more informative error messages
    if left.nlevels > 1:
        left = cast(MultiIndex, left)
        right = cast(MultiIndex, right)

        for level in range(left.nlevels):
            # cannot use get_level_values here because it can change dtype
            llevel = _get_ilevel_values(left, level)
            rlevel = _get_ilevel_values(right, level)

            lobj = f"MultiIndex level [{level}]"
            assert_index_equal(
                llevel,
                rlevel,
                exact=exact,
                check_names=check_names,
                check_less_precise=check_less_precise,
                check_exact=check_exact,
                obj=lobj,
            )
            # get_level_values may change dtype
            _check_types(left.levels[level], right.levels[level], obj=obj)

    # skip exact index checking when `check_categorical` is False
    if check_exact and check_categorical:
        if not left.equals(right):
            diff = np.sum((left.values != right.values).astype(int)) * 100.0 / len(left)
            msg = f"{obj} values are different ({np.round(diff, 5)} %)"
            raise_assert_detail(obj, msg, left, right)
    else:
        _testing.assert_almost_equal(
            left.values,
            right.values,
            check_less_precise=check_less_precise,
            check_dtype=exact,
            obj=obj,
            lobj=left,
            robj=right,
        )

    # metadata comparison
    if check_names:
        assert_attr_equal("names", left, right, obj=obj)
    if isinstance(left, pd.PeriodIndex) or isinstance(right, pd.PeriodIndex):
        assert_attr_equal("freq", left, right, obj=obj)
    if isinstance(left, pd.IntervalIndex) or isinstance(right, pd.IntervalIndex):
        assert_interval_array_equal(left.values, right.values)

    if check_categorical:
        if is_categorical_dtype(left) or is_categorical_dtype(right):
            assert_categorical_equal(left.values, right.values, obj=f"{obj} category")
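

def _example_assert_index_equal():
    # A minimal sketch: with the default exact="equiv", a RangeIndex and an
    # Int64Index holding the same values are considered equal.
    assert_index_equal(RangeIndex(3), Index([0, 1, 2]))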


def assert_class_equal(left, right, exact: Union[bool, str] = True, obj="Input"):
    """
    Checks classes are equal.
    """
    __tracebackhide__ = True

    def repr_class(x):
        if isinstance(x, Index):
            # return Index as it is to include values in the error message
            return x

        try:
            return type(x).__name__
        except AttributeError:
            return repr(type(x))

    if exact == "equiv":
        if type(left) != type(right):
            # allow equivalence of Int64Index/RangeIndex
            types = {type(left).__name__, type(right).__name__}
            if len(types - {"Int64Index", "RangeIndex"}):
                msg = f"{obj} classes are not equivalent"
                raise_assert_detail(obj, msg, repr_class(left), repr_class(right))
    elif exact:
        if type(left) != type(right):
            msg = f"{obj} classes are different"
            raise_assert_detail(obj, msg, repr_class(left), repr_class(right))


def assert_attr_equal(attr, left, right, obj="Attributes"):
    """
    Check that attributes are equal. Both objects must have the attribute.

    Parameters
    ----------
    attr : str
        Attribute name being compared.
    left : object
    right : object
    obj : str, default 'Attributes'
        Specify object name being compared, internally used to show appropriate
        assertion message
    """
    __tracebackhide__ = True

    left_attr = getattr(left, attr)
    right_attr = getattr(right, attr)

    if left_attr is right_attr:
        return True
    elif (
        is_number(left_attr)
        and np.isnan(left_attr)
        and is_number(right_attr)
        and np.isnan(right_attr)
    ):
        # np.nan
        return True

    try:
        result = left_attr == right_attr
    except TypeError:
        # datetimetz on rhs may raise TypeError
        result = False
    if not isinstance(result, bool):
        result = result.all()

    if result:
        return True
    else:
        msg = f'Attribute "{attr}" are different'
        raise_assert_detail(obj, msg, left_attr, right_attr)


def assert_is_valid_plot_return_object(objs):
    import matplotlib.pyplot as plt

    if isinstance(objs, (pd.Series, np.ndarray)):
        for el in objs.ravel():
            msg = (
                "one of 'objs' is not a matplotlib Axes instance, "
                f"type encountered {repr(type(el).__name__)}"
            )
            assert isinstance(el, (plt.Axes, dict)), msg
    else:
        msg = (
            "objs is neither an ndarray of Artist instances nor a single "
            "Artist instance, tuple, or dict, 'objs' is a "
            f"{repr(type(objs).__name__)}"
        )
        assert isinstance(objs, (plt.Artist, tuple, dict)), msg


def isiterable(obj):
    return hasattr(obj, "__iter__")


def assert_is_sorted(seq):
    """Assert that the sequence is sorted."""
    if isinstance(seq, (Index, Series)):
        seq = seq.values
    # sorting does not change precisions
    assert_numpy_array_equal(seq, np.sort(np.array(seq)))


def assert_categorical_equal(
    left, right, check_dtype=True, check_category_order=True, obj="Categorical"
):
    """Test that Categoricals are equivalent.

    Parameters
    ----------
    left : Categorical
    right : Categorical
    check_dtype : bool, default True
        Whether to check that the integer dtype of the codes is the same.
    check_category_order : bool, default True
        Whether the order of the categories should be compared, which
        implies identical integer codes. If False, only the resulting
        values are compared. The ordered attribute is
        checked regardless.
    obj : str, default 'Categorical'
        Specify object name being compared, internally used to show appropriate
        assertion message
    """
    _check_isinstance(left, right, Categorical)

    if check_category_order:
        assert_index_equal(left.categories, right.categories, obj=f"{obj}.categories")
        assert_numpy_array_equal(
            left.codes, right.codes, check_dtype=check_dtype, obj=f"{obj}.codes",
        )
    else:
        assert_index_equal(
            left.categories.sort_values(),
            right.categories.sort_values(),
            obj=f"{obj}.categories",
        )
        assert_index_equal(
            left.categories.take(left.codes),
            right.categories.take(right.codes),
            obj=f"{obj}.values",
        )

    assert_attr_equal("ordered", left, right, obj=obj)
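

def _example_assert_categorical_equal():
    # A minimal sketch: with check_category_order=False, two Categoricals that
    # render the same values pass even though their category orders differ.
    left = Categorical(["a", "b"], categories=["a", "b"])
    right = Categorical(["a", "b"], categories=["b", "a"])
    assert_categorical_equal(left, right, check_category_order=False)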


def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray"):
    """Test that two IntervalArrays are equivalent.

    Parameters
    ----------
    left, right : IntervalArray
        The IntervalArrays to compare.
    exact : bool or {'equiv'}, default 'equiv'
        Whether to check the Index class, dtype and inferred_type
        are identical. If 'equiv', then RangeIndex can be substituted for
        Int64Index as well.
    obj : str, default 'IntervalArray'
        Specify object name being compared, internally used to show appropriate
        assertion message
    """
    _check_isinstance(left, right, IntervalArray)

    assert_index_equal(left.left, right.left, exact=exact, obj=f"{obj}.left")
    assert_index_equal(left.right, right.right, exact=exact, obj=f"{obj}.right")
    assert_attr_equal("closed", left, right, obj=obj)


def assert_period_array_equal(left, right, obj="PeriodArray"):
    _check_isinstance(left, right, PeriodArray)

    assert_numpy_array_equal(left._data, right._data, obj=f"{obj}.values")
    assert_attr_equal("freq", left, right, obj=obj)


def assert_datetime_array_equal(left, right, obj="DatetimeArray"):
    __tracebackhide__ = True
    _check_isinstance(left, right, DatetimeArray)

    assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data")
    assert_attr_equal("freq", left, right, obj=obj)
    assert_attr_equal("tz", left, right, obj=obj)


def assert_timedelta_array_equal(left, right, obj="TimedeltaArray"):
    __tracebackhide__ = True
    _check_isinstance(left, right, TimedeltaArray)
    assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data")
    assert_attr_equal("freq", left, right, obj=obj)


def raise_assert_detail(obj, message, left, right, diff=None):
    __tracebackhide__ = True

    if isinstance(left, np.ndarray):
        left = pprint_thing(left)
    elif is_categorical_dtype(left):
        left = repr(left)

    if isinstance(right, np.ndarray):
        right = pprint_thing(right)
    elif is_categorical_dtype(right):
        right = repr(right)

    msg = f"""{obj} are different

{message}
[left]:  {left}
[right]: {right}"""

    if diff is not None:
        msg += f"\n[diff]: {diff}"

    raise AssertionError(msg)


def assert_numpy_array_equal(
    left,
    right,
    strict_nan=False,
    check_dtype=True,
    err_msg=None,
    check_same=None,
    obj="numpy array",
):
    """
    Check that two 'np.ndarray' objects are equivalent.

    Parameters
    ----------
    left, right : numpy.ndarray or iterable
        The two arrays to be compared.
    strict_nan : bool, default False
        If True, consider NaN and None to be different.
    check_dtype : bool, default True
        Check dtype if both a and b are np.ndarray.
    err_msg : str, default None
        If provided, used as assertion message.
    check_same : None|'copy'|'same', default None
        Ensure left and right refer/do not refer to the same memory area.
    obj : str, default 'numpy array'
        Specify object name being compared, internally used to show appropriate
        assertion message.
    """
    __tracebackhide__ = True

    # instance validation
    # Show a detailed error message when classes are different
    assert_class_equal(left, right, obj=obj)
    # both classes must be an np.ndarray
    _check_isinstance(left, right, np.ndarray)

    def _get_base(obj):
        return obj.base if getattr(obj, "base", None) is not None else obj

    left_base = _get_base(left)
    right_base = _get_base(right)

    if check_same == "same":
        if left_base is not right_base:
            raise AssertionError(f"{repr(left_base)} is not {repr(right_base)}")
    elif check_same == "copy":
        if left_base is right_base:
            raise AssertionError(f"{repr(left_base)} is {repr(right_base)}")

    def _raise(left, right, err_msg):
        if err_msg is None:
            if left.shape != right.shape:
                raise_assert_detail(
                    obj, f"{obj} shapes are different", left.shape, right.shape,
                )

            diff = 0
            for l, r in zip(left, right):
                # count up differences
                if not array_equivalent(l, r, strict_nan=strict_nan):
                    diff += 1

            diff = diff * 100.0 / left.size
            msg = f"{obj} values are different ({np.round(diff, 5)} %)"
            raise_assert_detail(obj, msg, left, right)

        raise AssertionError(err_msg)

    # compare shape and values
    if not array_equivalent(left, right, strict_nan=strict_nan):
        _raise(left, right, err_msg)

    if check_dtype:
        if isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
            assert_attr_equal("dtype", left, right, obj=obj)
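

def _example_assert_numpy_array_equal():
    # A minimal sketch: check_same="copy" additionally asserts that the two
    # arrays do not share memory.
    arr = np.arange(3)
    assert_numpy_array_equal(arr, arr.copy(), check_same="copy")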


def assert_extension_array_equal(
    left, right, check_dtype=True, check_less_precise=False, check_exact=False
):
    """Check that left and right ExtensionArrays are equal.

    Parameters
    ----------
    left, right : ExtensionArray
        The two arrays to compare
    check_dtype : bool, default True
        Whether to check if the ExtensionArray dtypes are identical.
    check_less_precise : bool or int, default False
        Specify comparison precision. Only used when check_exact is False.
        5 digits (False) or 3 digits (True) after decimal points are compared.
        If int, then specify the digits to compare.
    check_exact : bool, default False
        Whether to compare numbers exactly.

    Notes
    -----
    Missing values are checked separately from valid values.
    A mask of missing values is computed for each and checked to match.
    The remaining all-valid values are cast to object dtype and checked.
    """
    assert isinstance(left, ExtensionArray), "left is not an ExtensionArray"
    assert isinstance(right, ExtensionArray), "right is not an ExtensionArray"
    if check_dtype:
        assert_attr_equal("dtype", left, right, obj="ExtensionArray")

    if hasattr(left, "asi8") and type(right) == type(left):
        # Avoid slow object-dtype comparisons
        assert_numpy_array_equal(left.asi8, right.asi8)
        return

    left_na = np.asarray(left.isna())
    right_na = np.asarray(right.isna())
    assert_numpy_array_equal(left_na, right_na, obj="ExtensionArray NA mask")

    left_valid = np.asarray(left[~left_na].astype(object))
    right_valid = np.asarray(right[~right_na].astype(object))
    if check_exact:
        assert_numpy_array_equal(left_valid, right_valid, obj="ExtensionArray")
    else:
        _testing.assert_almost_equal(
            left_valid,
            right_valid,
            check_dtype=check_dtype,
            check_less_precise=check_less_precise,
            obj="ExtensionArray",
        )
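

def _example_assert_extension_array_equal():
    # A minimal sketch (nullable integer dtype): the NA masks are compared
    # first, then the remaining valid values.
    left = pd.array([1, 2, None], dtype="Int64")
    right = pd.array([1, 2, None], dtype="Int64")
    assert_extension_array_equal(left, right)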


# This could be refactored to use the NDFrame.equals method
def assert_series_equal(
    left,
    right,
    check_dtype=True,
    check_index_type="equiv",
    check_series_type=True,
    check_less_precise=False,
    check_names=True,
    check_exact=False,
    check_datetimelike_compat=False,
    check_categorical=True,
    check_category_order=True,
    obj="Series",
):
    """
    Check that left and right Series are equal.

    Parameters
    ----------
    left : Series
    right : Series
    check_dtype : bool, default True
        Whether to check the Series dtype is identical.
    check_index_type : bool or {'equiv'}, default 'equiv'
        Whether to check the Index class, dtype and inferred_type
        are identical.
    check_series_type : bool, default True
        Whether to check the Series class is identical.
    check_less_precise : bool or int, default False
        Specify comparison precision. Only used when check_exact is False.
        5 digits (False) or 3 digits (True) after decimal points are compared.
        If int, then specify the digits to compare.

        When comparing two numbers, if the first number has magnitude less
        than 1e-5, we compare the two numbers directly and check whether
        they are equivalent within the specified precision. Otherwise, we
        compare the **ratio** of the second number to the first number and
        check whether it is equivalent to 1 within the specified precision.
    check_names : bool, default True
        Whether to check the Series and Index names attribute.
    check_exact : bool, default False
        Whether to compare numbers exactly.
    check_datetimelike_compat : bool, default False
        Compare datetime-like which is comparable ignoring dtype.
    check_categorical : bool, default True
        Whether to compare internal Categorical exactly.
    check_category_order : bool, default True
        Whether to compare category order of internal Categoricals.

        .. versionadded:: 1.0.2
    obj : str, default 'Series'
        Specify object name being compared, internally used to show appropriate
        assertion message.
    """
    __tracebackhide__ = True

    # instance validation
    _check_isinstance(left, right, Series)

    if check_series_type:
        # TODO: There are some tests using rhs is sparse
        # lhs is dense. Should use assert_class_equal in future
        assert isinstance(left, type(right))
        # assert_class_equal(left, right, obj=obj)

    # length comparison
    if len(left) != len(right):
        msg1 = f"{len(left)}, {left.index}"
        msg2 = f"{len(right)}, {right.index}"
        raise_assert_detail(obj, "Series lengths are different", msg1, msg2)

    # index comparison
    assert_index_equal(
        left.index,
        right.index,
        exact=check_index_type,
        check_names=check_names,
        check_less_precise=check_less_precise,
        check_exact=check_exact,
        check_categorical=check_categorical,
        obj=f"{obj}.index",
    )

    if check_dtype:
        # We want to skip exact dtype checking when `check_categorical`
        # is False. We'll still raise if only one is a `Categorical`,
        # regardless of `check_categorical`
        if (
            is_categorical_dtype(left)
            and is_categorical_dtype(right)
            and not check_categorical
        ):
            pass
        else:
            assert_attr_equal("dtype", left, right, obj=f"Attributes of {obj}")

    if check_exact:
        assert_numpy_array_equal(
            left._internal_get_values(),
            right._internal_get_values(),
            check_dtype=check_dtype,
            obj=str(obj),
        )
    elif check_datetimelike_compat:
        # we want to check only if we have compat dtypes
        # e.g. integer and M|m are NOT compat, but we can simply check
        # the values in that case
        if needs_i8_conversion(left) or needs_i8_conversion(right):

            # datetimelike may have different objects (e.g. datetime.datetime
            # vs Timestamp) but will compare equal
            if not Index(left.values).equals(Index(right.values)):
                msg = (
                    f"[datetimelike_compat=True] {left.values} "
                    f"is not equal to {right.values}."
                )
                raise AssertionError(msg)
        else:
            assert_numpy_array_equal(
                left._internal_get_values(),
                right._internal_get_values(),
                check_dtype=check_dtype,
            )
    elif is_interval_dtype(left) or is_interval_dtype(right):
        assert_interval_array_equal(left.array, right.array)
    elif is_extension_array_dtype(left.dtype) and is_datetime64tz_dtype(left.dtype):
        # .values is an ndarray, but ._values is the ExtensionArray.
        # TODO: Use .array
        assert is_extension_array_dtype(right.dtype)
        assert_extension_array_equal(left._values, right._values)
    elif (
        is_extension_array_dtype(left)
        and not is_categorical_dtype(left)
        and is_extension_array_dtype(right)
        and not is_categorical_dtype(right)
    ):
        assert_extension_array_equal(left.array, right.array)
    else:
        _testing.assert_almost_equal(
            left._internal_get_values(),
            right._internal_get_values(),
            check_less_precise=check_less_precise,
            check_dtype=check_dtype,
            obj=str(obj),
        )

    # metadata comparison
    if check_names:
        assert_attr_equal("name", left, right, obj=obj)

    if check_categorical:
        if is_categorical_dtype(left) or is_categorical_dtype(right):
            assert_categorical_equal(
                left.values,
                right.values,
                obj=f"{obj} category",
                check_category_order=check_category_order,
            )
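

def _example_assert_series_equal():
    # A minimal sketch: differing Series names are tolerated when
    # check_names=False; with the default check_names=True this would raise.
    assert_series_equal(
        Series([1, 2], name="a"), Series([1, 2], name="b"), check_names=False
    )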


# This could be refactored to use the NDFrame.equals method
def assert_frame_equal(
    left,
    right,
    check_dtype=True,
    check_index_type="equiv",
    check_column_type="equiv",
    check_frame_type=True,
    check_less_precise=False,
    check_names=True,
    by_blocks=False,
    check_exact=False,
    check_datetimelike_compat=False,
    check_categorical=True,
    check_like=False,
    obj="DataFrame",
):
    """
    Check that left and right DataFrame are equal.

    This function is intended to compare two DataFrames and output any
    differences. It is mostly intended for use in unit tests.
    Additional parameters allow varying the strictness of the
    equality checks performed.

    Parameters
    ----------
    left : DataFrame
        First DataFrame to compare.
    right : DataFrame
        Second DataFrame to compare.
    check_dtype : bool, default True
        Whether to check the DataFrame dtype is identical.
    check_index_type : bool or {'equiv'}, default 'equiv'
        Whether to check the Index class, dtype and inferred_type
        are identical.
    check_column_type : bool or {'equiv'}, default 'equiv'
        Whether to check the columns class, dtype and inferred_type
        are identical. Is passed as the ``exact`` argument of
        :func:`assert_index_equal`.
    check_frame_type : bool, default True
        Whether to check the DataFrame class is identical.
    check_less_precise : bool or int, default False
        Specify comparison precision. Only used when check_exact is False.
        5 digits (False) or 3 digits (True) after decimal points are compared.
        If int, then specify the digits to compare.

        When comparing two numbers, if the first number has magnitude less
        than 1e-5, we compare the two numbers directly and check whether
        they are equivalent within the specified precision. Otherwise, we
        compare the **ratio** of the second number to the first number and
        check whether it is equivalent to 1 within the specified precision.
    check_names : bool, default True
        Whether to check that the `names` attribute for both the `index`
        and `column` attributes of the DataFrame is identical.
    by_blocks : bool, default False
        Specify how to compare internal data. If False, compare by columns.
        If True, compare by blocks.
    check_exact : bool, default False
        Whether to compare numbers exactly.
    check_datetimelike_compat : bool, default False
        Compare datetime-like which is comparable ignoring dtype.
    check_categorical : bool, default True
        Whether to compare internal Categorical exactly.
    check_like : bool, default False
        If True, ignore the order of index & columns.
        Note: index labels must match their respective rows
        (same as in columns) - same labels must be with the same data.
    obj : str, default 'DataFrame'
        Specify object name being compared, internally used to show appropriate
        assertion message.

    See Also
    --------
    assert_series_equal : Equivalent method for asserting Series equality.
    DataFrame.equals : Check DataFrame equality.

    Examples
    --------
    This example shows comparing two DataFrames that are equal
    but with columns of differing dtypes.

    >>> from pandas._testing import assert_frame_equal
    >>> df1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
    >>> df2 = pd.DataFrame({'a': [1, 2], 'b': [3.0, 4.0]})

    df1 equals itself.

    >>> assert_frame_equal(df1, df1)

    df1 differs from df2 as column 'b' is of a different type.

    >>> assert_frame_equal(df1, df2)
    Traceback (most recent call last):
    ...
    AssertionError: Attributes of DataFrame.iloc[:, 1] (column name="b") are different

    Attribute "dtype" are different
    [left]:  int64
    [right]: float64

    Ignore differing dtypes in columns with check_dtype.

    >>> assert_frame_equal(df1, df2, check_dtype=False)
    """
    __tracebackhide__ = True

    # instance validation
    _check_isinstance(left, right, DataFrame)

    if check_frame_type:
        assert isinstance(left, type(right))
        # assert_class_equal(left, right, obj=obj)

    # shape comparison
    if left.shape != right.shape:
        raise_assert_detail(
            obj, f"{obj} shape mismatch", f"{repr(left.shape)}", f"{repr(right.shape)}",
        )

    if check_like:
        left, right = left.reindex_like(right), right

    # index comparison
    assert_index_equal(
        left.index,
        right.index,
        exact=check_index_type,
        check_names=check_names,
        check_less_precise=check_less_precise,
        check_exact=check_exact,
        check_categorical=check_categorical,
        obj=f"{obj}.index",
    )

    # column comparison
    assert_index_equal(
        left.columns,
        right.columns,
        exact=check_column_type,
        check_names=check_names,
        check_less_precise=check_less_precise,
        check_exact=check_exact,
        check_categorical=check_categorical,
        obj=f"{obj}.columns",
    )

    # compare by blocks
    if by_blocks:
        rblocks = right._to_dict_of_blocks()
        lblocks = left._to_dict_of_blocks()
        for dtype in list(set(list(lblocks.keys()) + list(rblocks.keys()))):
            assert dtype in lblocks
            assert dtype in rblocks
            assert_frame_equal(
                lblocks[dtype], rblocks[dtype], check_dtype=check_dtype, obj=obj
            )

    # compare by columns
    else:
        for i, col in enumerate(left.columns):
            assert col in right
            lcol = left.iloc[:, i]
            rcol = right.iloc[:, i]
            assert_series_equal(
                lcol,
                rcol,
                check_dtype=check_dtype,
                check_index_type=check_index_type,
                check_less_precise=check_less_precise,
                check_exact=check_exact,
                check_names=check_names,
                check_datetimelike_compat=check_datetimelike_compat,
                check_categorical=check_categorical,
                obj=f'{obj}.iloc[:, {i}] (column name="{col}")',
            )


def assert_equal(left, right, **kwargs):
    """
    Wrapper for tm.assert_*_equal to dispatch to the appropriate test function.

    Parameters
    ----------
    left, right : Index, Series, DataFrame, ExtensionArray, or np.ndarray
        The two items to be compared.
    **kwargs
        All keyword arguments are passed through to the underlying assert method.
    """
    __tracebackhide__ = True

    if isinstance(left, pd.Index):
        assert_index_equal(left, right, **kwargs)
    elif isinstance(left, pd.Series):
        assert_series_equal(left, right, **kwargs)
    elif isinstance(left, pd.DataFrame):
        assert_frame_equal(left, right, **kwargs)
    elif isinstance(left, IntervalArray):
        assert_interval_array_equal(left, right, **kwargs)
    elif isinstance(left, PeriodArray):
        assert_period_array_equal(left, right, **kwargs)
    elif isinstance(left, DatetimeArray):
        assert_datetime_array_equal(left, right, **kwargs)
    elif isinstance(left, TimedeltaArray):
        assert_timedelta_array_equal(left, right, **kwargs)
    elif isinstance(left, ExtensionArray):
        assert_extension_array_equal(left, right, **kwargs)
    elif isinstance(left, np.ndarray):
        assert_numpy_array_equal(left, right, **kwargs)
    elif isinstance(left, str):
        assert kwargs == {}
        assert left == right
    else:
        raise NotImplementedError(type(left))
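

def _example_assert_equal():
    # A minimal sketch: the dispatcher picks the matching assert_* helper
    # based on the type of `left`.
    assert_equal(Index([1, 2]), Index([1, 2]))
    assert_equal(np.array([1.0, 2.0]), np.array([1.0, 2.0]))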


def box_expected(expected, box_cls, transpose=True):
    """
    Helper function to wrap the expected output of a test in a given box_class.

    Parameters
    ----------
    expected : np.ndarray, Index, Series
    box_cls : {Index, Series, DataFrame}

    Returns
    -------
    subclass of box_cls
    """
    if box_cls is pd.Index:
        expected = pd.Index(expected)
    elif box_cls is pd.Series:
        expected = pd.Series(expected)
    elif box_cls is pd.DataFrame:
        expected = pd.Series(expected).to_frame()
        if transpose:
            # for vector operations, we need a DataFrame to be a single-row,
            # not a single-column, in order to operate against non-DataFrame
            # vectors of the same length.
            expected = expected.T
    elif box_cls is PeriodArray:
        # the PeriodArray constructor is not as flexible as period_array
        expected = period_array(expected)
    elif box_cls is DatetimeArray:
        expected = DatetimeArray(expected)
    elif box_cls is TimedeltaArray:
        expected = TimedeltaArray(expected)
    elif box_cls is np.ndarray:
        expected = np.array(expected)
    elif box_cls is to_array:
        expected = to_array(expected)
    else:
        raise NotImplementedError(box_cls)
    return expected


def to_array(obj):
    # temporary implementation until we get pd.array in place
    if is_period_dtype(obj):
        return period_array(obj)
    elif is_datetime64_dtype(obj) or is_datetime64tz_dtype(obj):
        return DatetimeArray._from_sequence(obj)
    elif is_timedelta64_dtype(obj):
        return TimedeltaArray._from_sequence(obj)
    else:
        return np.array(obj)
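

def _example_to_array():
    # A minimal sketch: datetime64 input is boxed into a DatetimeArray, while
    # anything unrecognized falls through to a plain ndarray.
    arr = to_array(DatetimeIndex(["2000-01-01", "2000-01-02"]))
    assert isinstance(arr, DatetimeArray)
    assert isinstance(to_array([1, 2, 3]), np.ndarray)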


# -----------------------------------------------------------------------------
# Sparse


def assert_sp_array_equal(
    left,
    right,
    check_dtype=True,
    check_kind=True,
    check_fill_value=True,
    consolidate_block_indices=False,
):
    """Check that the left and right SparseArray are equal.

    Parameters
    ----------
    left : SparseArray
    right : SparseArray
    check_dtype : bool, default True
        Whether to check the data dtype is identical.
    check_kind : bool, default True
        Whether to check just the kind of the sparse index for each column.
    check_fill_value : bool, default True
        Whether to check that left.fill_value matches right.fill_value
    consolidate_block_indices : bool, default False
        Whether to consolidate contiguous blocks for sparse arrays with
        a BlockIndex. Some operations, e.g. concat, will end up with
        block indices that could be consolidated. Setting this to true will
        create a new BlockIndex for that array, with consolidated
        block indices.
    """
    _check_isinstance(left, right, pd.arrays.SparseArray)

    assert_numpy_array_equal(left.sp_values, right.sp_values, check_dtype=check_dtype)

    # SparseIndex comparison
    assert isinstance(left.sp_index, pd._libs.sparse.SparseIndex)
    assert isinstance(right.sp_index, pd._libs.sparse.SparseIndex)

    if not check_kind:
        left_index = left.sp_index.to_block_index()
        right_index = right.sp_index.to_block_index()
    else:
        left_index = left.sp_index
        right_index = right.sp_index

    if consolidate_block_indices and left.kind == "block":
        # we'll probably remove this hack...
        left_index = left_index.to_int_index().to_block_index()
        right_index = right_index.to_int_index().to_block_index()

    if not left_index.equals(right_index):
        raise_assert_detail(
            "SparseArray.index", "index are not equal", left_index, right_index
        )
    else:
        # indices are equal; nothing more to check
        pass

    if check_fill_value:
        assert_attr_equal("fill_value", left, right)
    if check_dtype:
        assert_attr_equal("dtype", left, right)
    assert_numpy_array_equal(left.to_dense(), right.to_dense(), check_dtype=check_dtype)


# -----------------------------------------------------------------------------
# Others


def assert_contains_all(iterable, dic):
    for k in iterable:
        assert k in dic, f"Did not contain item: {repr(k)}"


def assert_copy(iter1, iter2, **eql_kwargs):
    """
    iter1, iter2: iterables that produce elements
    comparable with assert_almost_equal

    Checks that the elements are equal, but not
    the same object. (Does not check that items
    in sequences are also not the same object.)
    """
    for elem1, elem2 in zip(iter1, iter2):
        assert_almost_equal(elem1, elem2, **eql_kwargs)
        msg = (
            f"Expected object {repr(type(elem1))} and object {repr(type(elem2))} to be "
            "different objects, but they were the same object."
        )
        assert elem1 is not elem2, msg


def getCols(k):
    return string.ascii_uppercase[:k]


# make index
def makeStringIndex(k=10, name=None):
    return Index(rands_array(nchars=10, size=k), name=name)


def makeUnicodeIndex(k=10, name=None):
    return Index(randu_array(nchars=10, size=k), name=name)


def makeCategoricalIndex(k=10, n=3, name=None, **kwargs):
    """make a length k CategoricalIndex with n categories"""
    x = rands_array(nchars=4, size=n)
    return CategoricalIndex(
        Categorical.from_codes(np.arange(k) % n, categories=x), name=name, **kwargs
    )


def makeIntervalIndex(k=10, name=None, **kwargs):
    """make a length k IntervalIndex"""
    x = np.linspace(0, 100, num=(k + 1))
    return IntervalIndex.from_breaks(x, name=name, **kwargs)


def makeBoolIndex(k=10, name=None):
    if k == 1:
        return Index([True], name=name)
    elif k == 2:
        return Index([False, True], name=name)
    return Index([False, True] + [False] * (k - 2), name=name)


def makeIntIndex(k=10, name=None):
    return Index(list(range(k)), name=name)


def makeUIntIndex(k=10, name=None):
    return Index([2 ** 63 + i for i in range(k)], name=name)


def makeRangeIndex(k=10, name=None, **kwargs):
    return RangeIndex(0, k, 1, name=name, **kwargs)


def makeFloatIndex(k=10, name=None):
    values = sorted(np.random.random_sample(k)) - np.random.random_sample(1)
    return Index(values * (10 ** np.random.randint(0, 9)), name=name)


def makeDateIndex(k=10, freq="B", name=None, **kwargs):
    dt = datetime(2000, 1, 1)
    dr = bdate_range(dt, periods=k, freq=freq, name=name)
    return DatetimeIndex(dr, name=name, **kwargs)


def makeTimedeltaIndex(k=10, freq="D", name=None, **kwargs):
    return pd.timedelta_range(start="1 day", periods=k, freq=freq, name=name, **kwargs)


def makePeriodIndex(k=10, name=None, **kwargs):
    dt = datetime(2000, 1, 1)
    dr = pd.period_range(start=dt, periods=k, freq="B", name=name, **kwargs)
    return dr


def makeMultiIndex(k=10, names=None, **kwargs):
    return MultiIndex.from_product((("foo", "bar"), (1, 2)), names=names, **kwargs)


_names = [
    "Alice",
    "Bob",
    "Charlie",
    "Dan",
    "Edith",
    "Frank",
    "George",
    "Hannah",
    "Ingrid",
    "Jerry",
    "Kevin",
    "Laura",
    "Michael",
    "Norbert",
    "Oliver",
    "Patricia",
    "Quinn",
    "Ray",
    "Sarah",
    "Tim",
    "Ursula",
    "Victor",
    "Wendy",
    "Xavier",
    "Yvonne",
    "Zelda",
]


def _make_timeseries(start="2000-01-01", end="2000-12-31", freq="1D", seed=None):
    """
    Make a DataFrame with a DatetimeIndex.

    Parameters
    ----------
    start : str or Timestamp, default "2000-01-01"
        The start of the index. Passed to date_range with `freq`.
    end : str or Timestamp, default "2000-12-31"
        The end of the index. Passed to date_range with `freq`.
    freq : str or Freq
        The frequency to use for the DatetimeIndex
    seed : int, optional
        The random state seed.

    Returns
    -------
    DataFrame with the following columns:

        * name : object dtype with string names
        * id : int dtype with Poisson-distributed values
        * x, y : float dtype

    Examples
    --------
    >>> _make_timeseries()
                  id    name         x         y
    timestamp
    2000-01-01   982   Frank  0.031261  0.986727
    2000-01-02  1025   Edith -0.086358 -0.032920
    2000-01-03   982   Edith  0.473177  0.298654
    2000-01-04  1009   Sarah  0.534344 -0.750377
    2000-01-05   963   Zelda -0.271573  0.054424
    ...          ...     ...       ...       ...
    2000-12-27   980  Ingrid -0.132333 -0.422195
    2000-12-28   972   Frank -0.376007 -0.298687
    2000-12-29  1009  Ursula -0.865047 -0.503133
    2000-12-30  1000  Hannah -0.063757 -0.507336
    2000-12-31   972     Tim -0.869120  0.531685
    """
    index = pd.date_range(start=start, end=end, freq=freq, name="timestamp")
    n = len(index)
    state = np.random.RandomState(seed)
    columns = {
        "name": state.choice(_names, size=n),
        "id": state.poisson(1000, size=n),
        "x": state.rand(n) * 2 - 1,
        "y": state.rand(n) * 2 - 1,
    }
    df = pd.DataFrame(columns, index=index, columns=sorted(columns))
    if df.index[-1] == end:
        df = df.iloc[:-1]
    return df


def all_index_generator(k=10):
    """Generator which can be iterated over to get instances of all the various
    index classes.

    Parameters
    ----------
    k: length of each of the index instances
    """
    all_make_index_funcs = [
        makeIntIndex,
        makeFloatIndex,
        makeStringIndex,
        makeUnicodeIndex,
        makeDateIndex,
        makePeriodIndex,
        makeTimedeltaIndex,
        makeBoolIndex,
        makeRangeIndex,
        makeIntervalIndex,
        makeCategoricalIndex,
    ]
    for make_index_func in all_make_index_funcs:
        yield make_index_func(k=k)


def index_subclass_makers_generator():
    make_index_funcs = [
        makeDateIndex,
        makePeriodIndex,
        makeTimedeltaIndex,
        makeRangeIndex,
        makeIntervalIndex,
        makeCategoricalIndex,
        makeMultiIndex,
    ]
    for make_index_func in make_index_funcs:
        yield make_index_func


def all_timeseries_index_generator(k=10):
    """Generator which can be iterated over to get instances of all the classes
    which represent time-series.

    Parameters
    ----------
    k: length of each of the index instances
    """
    make_index_funcs = [makeDateIndex, makePeriodIndex, makeTimedeltaIndex]
    for make_index_func in make_index_funcs:
        yield make_index_func(k=k)


# make series
def makeFloatSeries(name=None):
    index = makeStringIndex(N)
    return Series(randn(N), index=index, name=name)


def makeStringSeries(name=None):
    index = makeStringIndex(N)
    return Series(randn(N), index=index, name=name)


def makeObjectSeries(name=None):
    data = makeStringIndex(N)
    data = Index(data, dtype=object)
    index = makeStringIndex(N)
    return Series(data, index=index, name=name)


def getSeriesData():
    index = makeStringIndex(N)
    return {c: Series(randn(N), index=index) for c in getCols(K)}


def makeTimeSeries(nper=None, freq="B", name=None):
    if nper is None:
        nper = N
    return Series(randn(nper), index=makeDateIndex(nper, freq=freq), name=name)


def makePeriodSeries(nper=None, name=None):
    if nper is None:
        nper = N
    return Series(randn(nper), index=makePeriodIndex(nper), name=name)


def getTimeSeriesData(nper=None, freq="B"):
    return {c: makeTimeSeries(nper, freq) for c in getCols(K)}


def getPeriodData(nper=None):
    return {c: makePeriodSeries(nper) for c in getCols(K)}


# make frame
def makeTimeDataFrame(nper=None, freq="B"):
    data = getTimeSeriesData(nper, freq)
    return DataFrame(data)


def makeDataFrame():
    data = getSeriesData()
    return DataFrame(data)


def getMixedTypeDict():
    index = Index(["a", "b", "c", "d", "e"])

    data = {
        "A": [0.0, 1.0, 2.0, 3.0, 4.0],
        "B": [0.0, 1.0, 0.0, 1.0, 0.0],
        "C": ["foo1", "foo2", "foo3", "foo4", "foo5"],
        "D": bdate_range("1/1/2009", periods=5),
    }

    return index, data


def makeMixedDataFrame():
    return DataFrame(getMixedTypeDict()[1])


def makePeriodFrame(nper=None):
    data = getPeriodData(nper)
    return DataFrame(data)


def makeCustomIndex(
    nentries, nlevels, prefix="#", names=False, ndupe_l=None, idx_type=None
):
    """
    Create an index/multiindex with given dimensions, levels, names, etc.

    nentries - number of entries in index
    nlevels - number of levels (> 1 produces multiindex)
    prefix - a string prefix for labels
    names - (Optional), bool or list of strings. If True will use default
        names, if False will use no names, if a list is given, the name of
        each level in the index will be taken from the list.
    ndupe_l - (Optional), list of ints, the number of rows for which the
        label will be repeated at the corresponding level, you can specify
        just the first few, the rest will use the default ndupe_l of 1.
        len(ndupe_l) <= nlevels.
    idx_type - "i"/"f"/"s"/"u"/"dt"/"p"/"td".
        If idx_type is not None, `nlevels` must be 1.
        "i"/"f" creates an integer/float index,
        "s"/"u" creates a string/unicode index,
        "dt" creates a datetime index,
        "p" creates a period index,
        "td" creates a timedelta index.

        If unspecified, string labels will be generated.
    """
    if ndupe_l is None:
        ndupe_l = [1] * nlevels
    assert is_sequence(ndupe_l) and len(ndupe_l) <= nlevels
    assert names is None or names is False or names is True or len(names) == nlevels
    assert idx_type is None or (
        idx_type in ("i", "f", "s", "u", "dt", "p", "td") and nlevels == 1
    )

    if names is True:
        # build default names
        names = [prefix + str(i) for i in range(nlevels)]
    if names is False:
        # pass None to index constructor for no name
        names = None

    # make singleton case uniform
    if isinstance(names, str) and nlevels == 1:
        names = [names]

    # specific 1D index type requested?
    idx_func = dict(
        i=makeIntIndex,
        f=makeFloatIndex,
        s=makeStringIndex,
        u=makeUnicodeIndex,
        dt=makeDateIndex,
        td=makeTimedeltaIndex,
        p=makePeriodIndex,
    ).get(idx_type)
    if idx_func:
        idx = idx_func(nentries)
        # but we need to fill in the name
        if names:
            idx.name = names[0]
        return idx
    elif idx_type is not None:
        raise ValueError(
            f"{repr(idx_type)} is not a legal value for `idx_type`, "
            "use 'i'/'f'/'s'/'u'/'dt'/'p'/'td'."
        )

    if len(ndupe_l) < nlevels:
        ndupe_l.extend([1] * (nlevels - len(ndupe_l)))
    assert len(ndupe_l) == nlevels

    assert all(x > 0 for x in ndupe_l)

    tuples = []
    for i in range(nlevels):

        def keyfunc(x):
            import re

            numeric_tuple = re.sub(r"[^\d_]_?", "", x).split("_")
            return [int(num) for num in numeric_tuple]

        # build a list of lists to create the index from
        div_factor = nentries // ndupe_l[i] + 1
        cnt = Counter()
        for j in range(div_factor):
            label = f"{prefix}_l{i}_g{j}"
            cnt[label] = ndupe_l[i]
        # cute Counter trick
        result = sorted(cnt.elements(), key=keyfunc)[:nentries]
        tuples.append(result)

    tuples = list(zip(*tuples))

    # convert tuples to index
    if nentries == 1:
        # we have a single level of tuples, i.e. a regular Index
        index = Index(tuples[0], name=names[0])
    elif nlevels == 1:
        name = None if names is None else names[0]
        index = Index((x[0] for x in tuples), name=name)
    else:
        index = MultiIndex.from_tuples(tuples, names=names)
    return index
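

# Illustrative usage (sketch): a two-level MultiIndex with the first-level
# labels duplicated twice; labels follow the "{prefix}_l{level}_g{group}" form
# built above.
#
#     mi = makeCustomIndex(6, nlevels=2, ndupe_l=[2])
#     assert isinstance(mi, MultiIndex)
#     assert len(mi) == 6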


def makeCustomDataframe(
    nrows,
    ncols,
    c_idx_names=True,
    r_idx_names=True,
    c_idx_nlevels=1,
    r_idx_nlevels=1,
    data_gen_f=None,
    c_ndupe_l=None,
    r_ndupe_l=None,
    dtype=None,
    c_idx_type=None,
    r_idx_type=None,
):
    """
    nrows, ncols - number of data rows/cols
    c_idx_names, r_idx_names - False/True/list of strings, yields no names,
        default names, or uses the provided names for the levels of the
        corresponding index. You can provide a single string when
        c_idx_nlevels == 1.
    c_idx_nlevels - number of levels in columns index. > 1 will yield MultiIndex
    r_idx_nlevels - number of levels in rows index. > 1 will yield MultiIndex
    data_gen_f - a function f(row, col) which returns the data value
        at that position, the default generator used yields values of the form
        "RxCy" based on position.
    c_ndupe_l, r_ndupe_l - list of integers, determines the number
        of duplicates for each label at a given level of the corresponding
        index. The default `None` value produces a multiplicity of 1 across
        all levels, i.e. a unique index. Will accept a partial list of length
        N < idx_nlevels, for just the first N levels. If ndupe doesn't divide
        nrows/ncols, the last label might have lower multiplicity.
    dtype - passed to the DataFrame constructor as is, in case you wish to
        have more control in conjunction with a custom `data_gen_f`
    r_idx_type, c_idx_type - "i"/"f"/"s"/"u"/"dt"/"p"/"td".
        If idx_type is not None, `idx_nlevels` must be 1.
        "i"/"f" creates an integer/float index,
        "s"/"u" creates a string/unicode index,
        "dt" creates a datetime index,
        "p" creates a period index,
        "td" creates a timedelta index.

        If unspecified, string labels will be generated.

    Examples:

    # 5 rows, 3 columns, default names on both, single index on both axes
    >> makeCustomDataframe(5, 3)

    # make the data a random int between 1 and 100
    >> makeCustomDataframe(5, 3, data_gen_f=lambda r, c: randint(1, 100))

    # 2-level multiindex on rows with each label duplicated
    # twice on first level, default names on both axes, single
    # index on both axes
    >> a = makeCustomDataframe(5, 3, r_idx_nlevels=2, r_ndupe_l=[2])

    # DatetimeIndex on rows, index with unicode labels on columns,
    # no names on either axis
    >> a = makeCustomDataframe(5, 3, c_idx_names=False, r_idx_names=False,
                               r_idx_type="dt", c_idx_type="u")

    # 4-level multiindex on rows with names provided, 2-level multiindex
    # on columns with default labels and default names.
    >> a = makeCustomDataframe(5, 3, r_idx_nlevels=4,
                               r_idx_names=["FEE", "FI", "FO", "FAM"],
                               c_idx_nlevels=2)

    >> a = makeCustomDataframe(5, 3, r_idx_nlevels=2, c_idx_nlevels=4)
    """
    assert c_idx_nlevels > 0
    assert r_idx_nlevels > 0
    assert r_idx_type is None or (
        r_idx_type in ("i", "f", "s", "u", "dt", "p", "td") and r_idx_nlevels == 1
    )
    assert c_idx_type is None or (
        c_idx_type in ("i", "f", "s", "u", "dt", "p", "td") and c_idx_nlevels == 1
    )

    columns = makeCustomIndex(
        ncols,
        nlevels=c_idx_nlevels,
        prefix="C",
        names=c_idx_names,
        ndupe_l=c_ndupe_l,
        idx_type=c_idx_type,
    )
    index = makeCustomIndex(
        nrows,
        nlevels=r_idx_nlevels,
        prefix="R",
        names=r_idx_names,
        ndupe_l=r_ndupe_l,
        idx_type=r_idx_type,
    )

    # by default, generate data based on location
    if data_gen_f is None:
        data_gen_f = lambda r, c: f"R{r}C{c}"

    data = [[data_gen_f(r, c) for c in range(ncols)] for r in range(nrows)]

    return DataFrame(data, index, columns, dtype=dtype)
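

# Illustrative usage (sketch): by default, cell values encode their position,
# which makes assertion failures easy to read.
#
#     df = makeCustomDataframe(5, 3)
#     assert df.iloc[0, 0] == "R0C0" and df.iloc[4, 2] == "R4C2"
#     assert df.index.name == "R0" and df.columns.name == "C0"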


def _create_missing_idx(nrows, ncols, density, random_state=None):
    if random_state is None:
        random_state = np.random
    else:
        random_state = np.random.RandomState(random_state)

    # below is cribbed from scipy.sparse
    size = int(np.round((1 - density) * nrows * ncols))
    # generate a few more to ensure unique values
    min_rows = 5
    fac = 1.02
    extra_size = min(size + min_rows, fac * size)

    def _gen_unique_rand(rng, _extra_size):
        ind = rng.rand(int(_extra_size))
        return np.unique(np.floor(ind * nrows * ncols))[:size]

    ind = _gen_unique_rand(random_state, extra_size)
    while ind.size < size:
        extra_size *= 1.05
        ind = _gen_unique_rand(random_state, extra_size)

    j = np.floor(ind * 1.0 / nrows).astype(int)
    i = (ind - j * nrows).astype(int)
    return i.tolist(), j.tolist()


def makeMissingCustomDataframe(
    nrows,
    ncols,
    density=0.9,
    random_state=None,
    c_idx_names=True,
    r_idx_names=True,
    c_idx_nlevels=1,
    r_idx_nlevels=1,
    data_gen_f=None,
    c_ndupe_l=None,
    r_ndupe_l=None,
    dtype=None,
    c_idx_type=None,
    r_idx_type=None,
):
    """
    Parameters
    ----------
    density : float, optional
        Float in (0, 1) that gives the percentage of non-missing numbers in
        the DataFrame.
    random_state : {np.random.RandomState, int}, optional
        Random number generator or random seed.

    See makeCustomDataframe for descriptions of the rest of the parameters.
    """
    df = makeCustomDataframe(
        nrows,
        ncols,
        c_idx_names=c_idx_names,
        r_idx_names=r_idx_names,
        c_idx_nlevels=c_idx_nlevels,
        r_idx_nlevels=r_idx_nlevels,
        data_gen_f=data_gen_f,
        c_ndupe_l=c_ndupe_l,
        r_ndupe_l=r_ndupe_l,
        dtype=dtype,
        c_idx_type=c_idx_type,
        r_idx_type=r_idx_type,
    )

    i, j = _create_missing_idx(nrows, ncols, density, random_state)
    df.values[i, j] = np.nan
    return df


def makeMissingDataframe(density=0.9, random_state=None):
    df = makeDataFrame()
    i, j = _create_missing_idx(*df.shape, density=density, random_state=random_state)
    df.values[i, j] = np.nan
    return df
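

# Illustrative usage (sketch): with density=0.75, roughly a quarter of the
# N x K cells are set to NaN (the exact count comes from _create_missing_idx).
#
#     df = makeMissingDataframe(density=0.75, random_state=42)
#     n_missing = int(df.isna().sum().sum())
#     assert n_missing == int(np.round((1 - 0.75) * N * K))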


def optional_args(decorator):
    """
    Allows a decorator to take optional positional and keyword arguments.
    Assumes that taking a single, callable, positional argument means that
    it is decorating a function, i.e. something like this::

        @my_decorator
        def function(): pass

    Calls decorator with decorator(f, *args, **kwargs)
    """

    @wraps(decorator)
    def wrapper(*args, **kwargs):
        def dec(f):
            return decorator(f, *args, **kwargs)

        is_decorating = not kwargs and len(args) == 1 and callable(args[0])
        if is_decorating:
            f = args[0]
            args = []
            return dec(f)
        else:
            return dec

    return wrapper
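

# Illustrative usage (sketch): a decorator written for optional_args takes the
# function first, then any optional arguments, and works both bare and called.
#
#     @optional_args
#     def tag(f, label="default"):
#         f.label = label
#         return f
#
#     @tag                       # bare: tag(func)
#     def a(): pass
#
#     @tag(label="custom")       # with kwargs: tag(label=...)(func)
#     def b(): pass
#
#     assert a.label == "default" and b.label == "custom"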


# skip tests on exceptions with this message
_network_error_messages = (
    # 'urlopen error timed out',
    # 'timeout: timed out',
    # 'socket.timeout: timed out',
    "timed out",
    "Server Hangup",
    "HTTP Error 503: Service Unavailable",
    "502: Proxy Error",
    "HTTP Error 502: internal error",
    "HTTP Error 502",
    "HTTP Error 503",
    "HTTP Error 403",
    "HTTP Error 400",
    "Temporary failure in name resolution",
    "Name or service not known",
    "Connection refused",
    "certificate verify",
)

# or this e.errno/e.reason.errno
_network_errno_vals = (
    101,  # Network is unreachable
    111,  # Connection refused
    110,  # Connection timed out
    104,  # Connection reset Error
    54,  # Connection reset by peer
    60,  # urllib.error.URLError: [Errno 60] Connection timed out
)

# Both of the above shouldn't mask real issues such as 404s
# or refused connections (changed DNS).
# But some tests (test_data yahoo) contact incredibly flaky
# servers.

# and conditionally raise on exception types in _get_default_network_errors


def _get_default_network_errors():
    # Lazy import for http.client because it imports many things from the stdlib
    import http.client

    return (IOError, http.client.HTTPException, TimeoutError)


def can_connect(url, error_classes=None):
    """
    Try to connect to the given url. Return True if it succeeds, False if
    one of the given error classes is raised.

    Parameters
    ----------
    url : str
        The URL to try to connect to
    error_classes : tuple of Exception, optional
        The exception classes that count as a failed connection. Defaults to
        the classes returned by _get_default_network_errors().

    Returns
    -------
    connectable : bool
        Return True if no IOError (unable to connect) or URLError (bad url)
        was raised
    """
    if error_classes is None:
        error_classes = _get_default_network_errors()

    try:
        with urlopen(url):
            pass
    except error_classes:
        return False
    else:
        return True
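

# Illustrative usage (sketch): probe connectivity before exercising a
# network-dependent code path (the URL here is just a placeholder).
#
#     if can_connect("http://www.example.com"):
#         ...  # run the network-dependent check
#     else:
#         ...  # skip or fall back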


@optional_args
def network(
    t,
    url="http://www.google.com",
    raise_on_error=_RAISE_NETWORK_ERROR_DEFAULT,
    check_before_test=False,
    error_classes=None,
    skip_errnos=_network_errno_vals,
    _skip_on_messages=_network_error_messages,
):
    """
    Label a test as requiring network connection and, if an error is
    encountered, only raise if it does not find a network connection.

    In comparison to ``network``, this assumes an added contract to your test:
    you must assert that, under normal conditions, your test will ONLY fail if
    it does not have network connectivity.

    You can call this in 3 ways: as a standard decorator, with keyword
    arguments, or with a positional argument that is the url to check.

    Parameters
    ----------
    t : callable
        The test requiring network connectivity.
    url : str
        The url to test via ``pandas.io.common.urlopen`` to check
        for connectivity. Defaults to 'http://www.google.com'.
    raise_on_error : bool
        If True, never catches errors.
    check_before_test : bool
        If True, checks connectivity before running the test case.
    error_classes : tuple or Exception
        error classes to ignore. If not in ``error_classes``, raises the error.
        defaults to IOError. Be careful about changing the error classes here.
    skip_errnos : iterable of int
        Any exception that has .errno or .reason.errno set to one
        of these values will be skipped with an appropriate
        message.
    _skip_on_messages : iterable of string
        any exception e for which one of the strings is
        a substring of str(e) will be skipped with an appropriate
        message. Intended to suppress errors where an errno isn't available.

    Notes
    -----
    * ``raise_on_error`` supersedes ``check_before_test``

    Returns
    -------
    t : callable
        The decorated test ``t``, with checks for connectivity errors.

    Examples
    --------
    Tests decorated with @network will fail if it's possible to make a network
    connection to another URL (defaults to google.com)::

        >>> from pandas._testing import network
        >>> from pandas.io.common import urlopen
        >>> @network
        ... def test_network():
        ...     with urlopen("rabbit://bonanza.com"):
        ...         pass
        Traceback (most recent call last):
            ...
        URLError: <urlopen error unknown url type: rabbit>

    You can specify alternative URLs::

        >>> @network("http://www.yahoo.com")
        ... def test_something_with_yahoo():
        ...     raise IOError("Failure Message")
        >>> test_something_with_yahoo()
        Traceback (most recent call last):
            ...
        IOError: Failure Message

    If you set check_before_test, it will check the url first and not run the
    test on failure::

        >>> @network("failing://url.blaher", check_before_test=True)
        ... def test_something():
        ...     print("I ran!")
        ...     raise ValueError("Failure")
        >>> test_something()
        Traceback (most recent call last):
            ...

    Errors not related to networking will always be raised.
    """
    from pytest import skip

    if error_classes is None:
        error_classes = _get_default_network_errors()

    t.network = True

    @wraps(t)
    def wrapper(*args, **kwargs):
        if check_before_test and not raise_on_error:
            if not can_connect(url, error_classes):
                skip()
        try:
            return t(*args, **kwargs)
        except Exception as err:
            errno = getattr(err, "errno", None)
            if not errno and hasattr(err, "reason"):
                # some exceptions (e.g. URLError) wrap the errno in ``reason``
                errno = getattr(err.reason, "errno", None)

            if errno in skip_errnos:
                skip(f"Skipping test due to known errno and error {err}")

            e_str = str(err)

            if any(m.lower() in e_str.lower() for m in _skip_on_messages):
                skip(
                    f"Skipping test because exception message is known and error {err}"
                )

            if not isinstance(err, error_classes):
                raise

            if raise_on_error or can_connect(url, error_classes):
                raise
            else:
                skip(f"Skipping test due to lack of connectivity and error {err}")

    return wrapper


with_connectivity_check = network


@contextmanager
def assert_produces_warning(
    expected_warning=Warning,
    filter_level="always",
    clear=None,
    check_stacklevel=True,
    raise_on_extra_warnings=True,
):
    """
    Context manager for running code expected to either raise a specific
    warning, or not raise any warnings. Verifies that the code raises the
    expected warning, and that it does not raise any other unexpected
    warnings. It is basically a wrapper around ``warnings.catch_warnings``.

    Parameters
    ----------
    expected_warning : {Warning, False, None}, default Warning
        The category of warning expected. ``Warning`` is the base
        class for all warnings. To check that no warning is returned,
        specify ``False`` or ``None``.
    filter_level : str or None, default "always"
        Specifies whether warnings are ignored, displayed, or turned
        into errors.
        Valid values are:

        * "error" - turns matching warnings into exceptions
        * "ignore" - discard the warning
        * "always" - always emit a warning
        * "default" - print the warning the first time it is generated
          from each location
        * "module" - print the warning the first time it is generated
          from each module
        * "once" - print the warning the first time it is generated

    clear : str, default None
        If not ``None`` then remove any previously raised warnings from
        the ``__warningregistry__`` to ensure that no warning messages are
        suppressed by this context manager. If ``None`` is specified,
        the ``__warningregistry__`` keeps track of which warnings have been
        shown, and does not show them again.
    check_stacklevel : bool, default True
        If True, displays the line that called the function containing
        the warning to show where the function is called. Otherwise, the
        line that implements the function is displayed.
    raise_on_extra_warnings : bool, default True
        Whether extra warnings not of the type `expected_warning` should
        cause the test to fail.

    Examples
    --------
    >>> import warnings
    >>> with assert_produces_warning():
    ...     warnings.warn(UserWarning())
    ...
    >>> with assert_produces_warning(False):
    ...     warnings.warn(RuntimeWarning())
    ...
    Traceback (most recent call last):
        ...
    AssertionError: Caused unexpected warning(s): ['RuntimeWarning'].
    >>> with assert_produces_warning(UserWarning):
    ...     warnings.warn(RuntimeWarning())
    Traceback (most recent call last):
        ...
    AssertionError: Did not see expected warning of class 'UserWarning'.

    .. warning:: This is *not* thread-safe.
    """
    __tracebackhide__ = True

    with warnings.catch_warnings(record=True) as w:

        if clear is not None:
            # make sure that we are clearing these warnings
            # if they have happened before
            # to guarantee that we will catch them
            if not is_list_like(clear):
                clear = [clear]
            for m in clear:
                try:
                    m.__warningregistry__.clear()
                except AttributeError:
                    # module may not have __warningregistry__
                    pass

        saw_warning = False
        warnings.simplefilter(filter_level)
        yield w
        extra_warnings = []

        for actual_warning in w:
            if expected_warning and issubclass(
                actual_warning.category, expected_warning
            ):
                saw_warning = True

                if check_stacklevel and issubclass(
                    actual_warning.category, (FutureWarning, DeprecationWarning)
                ):
                    from inspect import getframeinfo, stack

                    caller = getframeinfo(stack()[2][0])
                    msg = (
                        "Warning not set with correct stacklevel. "
                        f"File where warning is raised: {actual_warning.filename} != "
                        f"{caller.filename}. Warning message: {actual_warning.message}"
                    )
                    assert actual_warning.filename == caller.filename, msg
            else:
                extra_warnings.append(
                    (
                        actual_warning.category.__name__,
                        actual_warning.message,
                        actual_warning.filename,
                        actual_warning.lineno,
                    )
                )
        if expected_warning:
            msg = (
                f"Did not see expected warning of class "
                f"{repr(expected_warning.__name__)}"
            )
            assert saw_warning, msg
        if raise_on_extra_warnings and extra_warnings:
            raise AssertionError(
                f"Caused unexpected warning(s): {repr(extra_warnings)}"
            )


class RNGContext:
    """
    Context manager to set the numpy random number generator seed. Returns
    to the original value upon exiting the context manager.

    Parameters
    ----------
    seed : int
        Seed for numpy.random.seed

    Examples
    --------
    with RNGContext(42):
        np.random.randn()
    """

    def __init__(self, seed):
        self.seed = seed

    def __enter__(self):
        self.start_state = np.random.get_state()
        np.random.seed(self.seed)

    def __exit__(self, exc_type, exc_value, traceback):
        np.random.set_state(self.start_state)
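

# Illustrative usage (sketch): the same draws are produced on each run because
# the global numpy state is seeded on entry and restored on exit.
#
#     with RNGContext(42):
#         first = np.random.randn()
#     with RNGContext(42):
#         second = np.random.randn()
#     assert first == second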


@contextmanager
def with_csv_dialect(name, **kwargs):
    """
    Context manager to temporarily register a CSV dialect for parsing CSV.

    Parameters
    ----------
    name : str
        The name of the dialect.
    kwargs : mapping
        The parameters for the dialect.

    Raises
    ------
    ValueError : the name of the dialect conflicts with a builtin one.

    See Also
    --------
    csv : Python's CSV library.
    """
    import csv

    _BUILTIN_DIALECTS = {"excel", "excel-tab", "unix"}

    if name in _BUILTIN_DIALECTS:
        raise ValueError("Cannot override builtin dialect.")

    csv.register_dialect(name, **kwargs)
    yield
    csv.unregister_dialect(name)
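

# Illustrative usage (sketch): register a pipe-delimited dialect only for the
# duration of the block; read_csv accepts the dialect name via `dialect=`.
#
#     import io
#     with with_csv_dialect("pipes", delimiter="|"):
#         df = pd.read_csv(io.StringIO("a|b\n1|2"), dialect="pipes")
#     assert list(df.columns) == ["a", "b"]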


@contextmanager
def use_numexpr(use, min_elements=None):
    from pandas.core.computation import expressions as expr

    if min_elements is None:
        min_elements = expr._MIN_ELEMENTS

    olduse = expr._USE_NUMEXPR
    oldmin = expr._MIN_ELEMENTS
    expr.set_use_numexpr(use)
    expr._MIN_ELEMENTS = min_elements
    yield
    expr._MIN_ELEMENTS = oldmin
    expr.set_use_numexpr(olduse)
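

# Illustrative usage (sketch): force the numexpr-backed evaluation path even
# for small frames by lowering the element threshold inside the block
# (takes effect only if the optional numexpr package is installed).
#
#     with use_numexpr(True, min_elements=0):
#         _ = makeDataFrame() + makeDataFrame()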


def test_parallel(num_threads=2, kwargs_list=None):
    """
    Decorator to run the same function multiple times in parallel.

    Parameters
    ----------
    num_threads : int, optional
        The number of times the function is run in parallel.
    kwargs_list : list of dicts, optional
        The list of kwargs to update original
        function kwargs on different threads.

    Notes
    -----
    This decorator does not pass the return value of the decorated function.

    Original from scikit-image:

    https://github.com/scikit-image/scikit-image/pull/1519
    """
    assert num_threads > 0
    has_kwargs_list = kwargs_list is not None
    if has_kwargs_list:
        assert len(kwargs_list) == num_threads
    import threading

    def wrapper(func):
        @wraps(func)
        def inner(*args, **kwargs):
            if has_kwargs_list:
                update_kwargs = lambda i: dict(kwargs, **kwargs_list[i])
            else:
                update_kwargs = lambda i: kwargs
            threads = []
            for i in range(num_threads):
                updated_kwargs = update_kwargs(i)
                thread = threading.Thread(target=func, args=args, kwargs=updated_kwargs)
                threads.append(thread)
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        return inner

    return wrapper
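

# Illustrative usage (sketch): run a thread-safety smoke test by appending to
# a shared list from several threads at once (list.append is atomic in
# CPython).
#
#     results = []
#
#     @test_parallel(num_threads=4)
#     def append_one():
#         results.append(1)
#
#     append_one()
#     assert len(results) == 4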


class SubclassedSeries(Series):
    _metadata = ["testattr", "name"]

    @property
    def _constructor(self):
        return SubclassedSeries

    @property
    def _constructor_expanddim(self):
        return SubclassedDataFrame


class SubclassedDataFrame(DataFrame):
    _metadata = ["testattr"]

    @property
    def _constructor(self):
        return SubclassedDataFrame

    @property
    def _constructor_sliced(self):
        return SubclassedSeries


class SubclassedCategorical(Categorical):
    @property
    def _constructor(self):
        return SubclassedCategorical


@contextmanager
def set_timezone(tz: str):
    """
    Context manager for temporarily setting a timezone.

    Parameters
    ----------
    tz : str
        A string representing a valid timezone.

    Examples
    --------
    >>> from datetime import datetime
    >>> from dateutil.tz import tzlocal
    >>> tzlocal().tzname(datetime.now())
    'IST'

    >>> with set_timezone('US/Eastern'):
    ...     tzlocal().tzname(datetime.now())
    ...
    'EDT'
    """
    import os
    import time

    def setTZ(tz):
        if tz is None:
            try:
                del os.environ["TZ"]
            except KeyError:
                pass
        else:
            os.environ["TZ"] = tz
            time.tzset()

    orig_tz = os.environ.get("TZ")
    setTZ(tz)
    try:
        yield
    finally:
        setTZ(orig_tz)


def _make_skipna_wrapper(alternative, skipna_alternative=None):
    """
    Create a function for calling on an array.

    Parameters
    ----------
    alternative : function
        The function to be called on the array with no NaNs.
        Only used when 'skipna_alternative' is None.
    skipna_alternative : function
        The function to be called on the original array.

    Returns
    -------
    function
    """
    if skipna_alternative:

        def skipna_wrapper(x):
            return skipna_alternative(x.values)

    else:

        def skipna_wrapper(x):
            nona = x.dropna()
            if len(nona) == 0:
                return np.nan
            return alternative(nona)

    return skipna_wrapper


def convert_rows_list_to_csv_str(rows_list: List[str]):
    """
    Convert a list of CSV rows to a single CSV-formatted string for the
    current OS.

    This method is used for creating the expected value of to_csv().

    Parameters
    ----------
    rows_list : List[str]
        Each element represents a row of the CSV.

    Returns
    -------
    str
        Expected output of to_csv() on the current OS.
    """
    sep = os.linesep
    expected = sep.join(rows_list) + sep
    return expected
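

# Illustrative usage (sketch): build the expected to_csv() output in an
# OS-independent way (os.linesep is "\n" on POSIX, "\r\n" on Windows).
#
#     expected = convert_rows_list_to_csv_str(["a,b", "1,2"])
#     assert expected == "a,b" + os.linesep + "1,2" + os.linesep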