| This file is a merged representation of the entire codebase, combined into a single document by Repomix. | |
| <file_summary> | |
| This section contains a summary of this file. | |
| <purpose> | |
| This file contains a packed representation of the entire repository's contents. | |
| It is designed to be easily consumable by AI systems for analysis, code review, | |
| or other automated processes. | |
| </purpose> | |
| <file_format> | |
| The content is organized as follows: | |
| 1. This summary section | |
| 2. Repository information | |
| 3. Directory structure | |
| 4. Repository files, each consisting of: | |
| - File path as an attribute | |
| - Full contents of the file | |
| </file_format> | |
| <usage_guidelines> | |
| - This file should be treated as read-only. Any changes should be made to the | |
| original repository files, not this packed version. | |
| - When processing this file, use the file path to distinguish | |
| between different files in the repository. | |
| - Be aware that this file may contain sensitive information. Handle it with | |
| the same level of security as you would the original repository. | |
| </usage_guidelines> | |
| <notes> | |
| - Some files may have been excluded based on .gitignore rules and Repomix's configuration | |
| - Binary files are not included in this packed representation. Please refer to the Repository Structure section for a complete list of file paths, including binary files | |
| - Files matching patterns in .gitignore are excluded | |
| - Files matching default ignore patterns are excluded | |
| - Files are sorted by Git change count (files with more changes are at the bottom) | |
| </notes> | |
| <additional_info> | |
| </additional_info> | |
| </file_summary> | |
| <directory_structure> | |
| base/ | |
| grouping/ | |
| __init__.py | |
| base.py | |
| nb.py | |
| resampling/ | |
| __init__.py | |
| base.py | |
| nb.py | |
| __init__.py | |
| accessors.py | |
| chunking.py | |
| combining.py | |
| decorators.py | |
| flex_indexing.py | |
| indexes.py | |
| indexing.py | |
| merging.py | |
| preparing.py | |
| reshaping.py | |
| wrapping.py | |
| data/ | |
| custom/ | |
| __init__.py | |
| alpaca.py | |
| av.py | |
| bento.py | |
| binance.py | |
| ccxt.py | |
| csv.py | |
| custom.py | |
| db.py | |
| duckdb.py | |
| feather.py | |
| file.py | |
| finpy.py | |
| gbm_ohlc.py | |
| gbm.py | |
| hdf.py | |
| local.py | |
| ndl.py | |
| parquet.py | |
| polygon.py | |
| random_ohlc.py | |
| random.py | |
| remote.py | |
| sql.py | |
| synthetic.py | |
| tv.py | |
| yf.py | |
| __init__.py | |
| base.py | |
| decorators.py | |
| nb.py | |
| saver.py | |
| updater.py | |
| generic/ | |
| nb/ | |
| __init__.py | |
| apply_reduce.py | |
| base.py | |
| iter_.py | |
| patterns.py | |
| records.py | |
| rolling.py | |
| sim_range.py | |
| splitting/ | |
| __init__.py | |
| base.py | |
| decorators.py | |
| nb.py | |
| purged.py | |
| sklearn_.py | |
| __init__.py | |
| accessors.py | |
| analyzable.py | |
| decorators.py | |
| drawdowns.py | |
| enums.py | |
| plots_builder.py | |
| plotting.py | |
| price_records.py | |
| ranges.py | |
| sim_range.py | |
| stats_builder.py | |
| indicators/ | |
| custom/ | |
| __init__.py | |
| adx.py | |
| atr.py | |
| bbands.py | |
| hurst.py | |
| ma.py | |
| macd.py | |
| msd.py | |
| obv.py | |
| ols.py | |
| patsim.py | |
| pivotinfo.py | |
| rsi.py | |
| sigdet.py | |
| stoch.py | |
| supertrend.py | |
| vwap.py | |
| __init__.py | |
| configs.py | |
| enums.py | |
| expr.py | |
| factory.py | |
| nb.py | |
| talib_.py | |
| labels/ | |
| generators/ | |
| __init__.py | |
| bolb.py | |
| fixlb.py | |
| fmax.py | |
| fmean.py | |
| fmin.py | |
| fstd.py | |
| meanlb.py | |
| pivotlb.py | |
| trendlb.py | |
| __init__.py | |
| enums.py | |
| nb.py | |
| ohlcv/ | |
| __init__.py | |
| accessors.py | |
| enums.py | |
| nb.py | |
| portfolio/ | |
| nb/ | |
| __init__.py | |
| analysis.py | |
| core.py | |
| ctx_helpers.py | |
| from_order_func.py | |
| from_orders.py | |
| from_signals.py | |
| iter_.py | |
| records.py | |
| pfopt/ | |
| __init__.py | |
| base.py | |
| nb.py | |
| records.py | |
| __init__.py | |
| base.py | |
| call_seq.py | |
| chunking.py | |
| decorators.py | |
| enums.py | |
| logs.py | |
| orders.py | |
| preparing.py | |
| trades.py | |
| px/ | |
| __init__.py | |
| accessors.py | |
| decorators.py | |
| records/ | |
| __init__.py | |
| base.py | |
| chunking.py | |
| col_mapper.py | |
| decorators.py | |
| mapped_array.py | |
| nb.py | |
| registries/ | |
| __init__.py | |
| ca_registry.py | |
| ch_registry.py | |
| jit_registry.py | |
| pbar_registry.py | |
| returns/ | |
| __init__.py | |
| accessors.py | |
| enums.py | |
| nb.py | |
| qs_adapter.py | |
| signals/ | |
| generators/ | |
| __init__.py | |
| ohlcstcx.py | |
| ohlcstx.py | |
| rand.py | |
| randnx.py | |
| randx.py | |
| rprob.py | |
| rprobcx.py | |
| rprobnx.py | |
| rprobx.py | |
| stcx.py | |
| stx.py | |
| __init__.py | |
| accessors.py | |
| enums.py | |
| factory.py | |
| nb.py | |
| templates/ | |
| dark.json | |
| light.json | |
| seaborn.json | |
| utils/ | |
| knowledge/ | |
| __init__.py | |
| asset_pipelines.py | |
| base_asset_funcs.py | |
| base_assets.py | |
| chatting.py | |
| custom_asset_funcs.py | |
| custom_assets.py | |
| formatting.py | |
| __init__.py | |
| annotations.py | |
| array_.py | |
| attr_.py | |
| base.py | |
| caching.py | |
| chaining.py | |
| checks.py | |
| chunking.py | |
| colors.py | |
| config.py | |
| cutting.py | |
| datetime_.py | |
| datetime_nb.py | |
| decorators.py | |
| enum_.py | |
| eval_.py | |
| execution.py | |
| figure.py | |
| formatting.py | |
| hashing.py | |
| image_.py | |
| jitting.py | |
| magic_decorators.py | |
| mapping.py | |
| math_.py | |
| merging.py | |
| module_.py | |
| params.py | |
| parsing.py | |
| path_.py | |
| pbar.py | |
| pickling.py | |
| profiling.py | |
| random_.py | |
| requests_.py | |
| schedule_.py | |
| search_.py | |
| selection.py | |
| tagging.py | |
| telegram.py | |
| template.py | |
| warnings_.py | |
| __init__.py | |
| _dtypes.py | |
| _opt_deps.py | |
| _settings.py | |
| _typing.py | |
| _version.py | |
| accessors.py | |
| </directory_structure> | |
| <files> | |
| This section contains the contents of the repository's files. | |
| <file path="base/grouping/__init__.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Modules with classes and utilities for grouping.""" | |
| from typing import TYPE_CHECKING | |
| if TYPE_CHECKING: | |
| from vectorbtpro.base.grouping.base import * | |
| from vectorbtpro.base.grouping.nb import * | |
| </file> | |
| <file path="base/grouping/base.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Base classes and functions for grouping. | |
| Class `Grouper` stores metadata related to grouping an index. It can return, for example, | |
| the number of groups, the start indices of groups, and other information useful for reducing | |
| operations that utilize grouping. It also allows groups to be dynamically enabled, disabled, or modified, | |
| and checks whether a certain operation is permitted.""" | |
| import numpy as np | |
| import pandas as pd | |
| from pandas.core.groupby import GroupBy as PandasGroupBy | |
| from pandas.core.resample import Resampler as PandasResampler | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro._dtypes import * | |
| from vectorbtpro.base import indexes | |
| from vectorbtpro.base.grouping import nb | |
| from vectorbtpro.base.indexes import ExceptLevel | |
| from vectorbtpro.registries.jit_registry import jit_reg | |
| from vectorbtpro.utils.array_ import is_sorted | |
| from vectorbtpro.utils.config import Configured | |
| from vectorbtpro.utils.decorators import cached_method | |
| from vectorbtpro.utils.template import CustomTemplate | |
| __all__ = [ | |
| "Grouper", | |
| ] | |
| GroupByT = tp.Union[None, bool, tp.Index] | |
| GrouperT = tp.TypeVar("GrouperT", bound="Grouper") | |
| class Grouper(Configured): | |
| """Class that exposes methods to group index. | |
| `group_by` can be: | |
| * boolean (False for no grouping, True for one group), | |
| * integer (level by position), | |
| * string (level by name), | |
| * sequence of integers or strings that is shorter than `index` (multiple levels), | |
| * any other sequence that has the same length as `index` (group per index). | |
| Set `allow_enable` to False to prohibit grouping if `Grouper.group_by` is None. | |
| Set `allow_disable` to False to prohibit disabling of grouping if `Grouper.group_by` is not None. | |
| Set `allow_modify` to False to prohibit modifying groups (you can still change their labels). | |
| All properties are read-only to enable caching.""" | |
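| # A minimal usage sketch (not part of the original source); the labels and level names | |
| # below are purely illustrative: | |
| # columns = pd.MultiIndex.from_tuples([("BTC", "fast"), ("BTC", "slow"), ("ETH", "fast")], names=["symbol", "speed"]) | |
| # grouper = Grouper(index=columns, group_by="symbol")  # group by a level name | |
| # grouper.get_groups()  # array([0, 0, 1]) | |
| # grouper.get_index()  # Index(['BTC', 'ETH'], dtype='object', name='symbol') | |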
| @classmethod | |
| def group_by_to_index( | |
| cls, | |
| index: tp.Index, | |
| group_by: tp.GroupByLike, | |
| def_lvl_name: tp.Hashable = "group", | |
| ) -> GroupByT: | |
| """Convert mapper `group_by` to `pd.Index`. | |
| !!! note | |
| Index and mapper must have the same length.""" | |
| if group_by is None or group_by is False: | |
| return group_by | |
| if isinstance(group_by, CustomTemplate): | |
| group_by = group_by.substitute(context=dict(index=index), strict=True, eval_id="group_by") | |
| if group_by is True: | |
| group_by = pd.Index(["group"] * len(index), name=def_lvl_name) | |
| elif isinstance(index, pd.MultiIndex) or isinstance(group_by, (ExceptLevel, int, str)): | |
| if isinstance(group_by, ExceptLevel): | |
| except_levels = group_by.value | |
| if isinstance(except_levels, (int, str)): | |
| except_levels = [except_levels] | |
| new_group_by = [] | |
| for i, name in enumerate(index.names): | |
| if i not in except_levels and name not in except_levels: | |
| new_group_by.append(name) | |
| if len(new_group_by) == 0: | |
| group_by = pd.Index(["group"] * len(index), name=def_lvl_name) | |
| else: | |
| if len(new_group_by) == 1: | |
| new_group_by = new_group_by[0] | |
| group_by = indexes.select_levels(index, new_group_by) | |
| elif isinstance(group_by, (int, str)): | |
| group_by = indexes.select_levels(index, group_by) | |
| elif ( | |
| isinstance(group_by, (tuple, list)) | |
| and not isinstance(group_by[0], pd.Index) | |
| and len(group_by) <= len(index.names) | |
| ): | |
| try: | |
| group_by = indexes.select_levels(index, group_by) | |
| except (IndexError, KeyError): | |
| pass | |
| if not isinstance(group_by, pd.Index): | |
| if isinstance(group_by[0], pd.Index): | |
| group_by = pd.MultiIndex.from_arrays(group_by) | |
| else: | |
| group_by = pd.Index(group_by, name=def_lvl_name) | |
| if len(group_by) != len(index): | |
| raise ValueError("group_by and index must have the same length") | |
| return group_by | |
| @classmethod | |
| def group_by_to_groups_and_index( | |
| cls, | |
| index: tp.Index, | |
| group_by: tp.GroupByLike, | |
| def_lvl_name: tp.Hashable = "group", | |
| ) -> tp.Tuple[tp.Array1d, tp.Index]: | |
| """Return array of group indices pointing to the original index, and grouped index.""" | |
| if group_by is None or group_by is False: | |
| return np.arange(len(index)), index | |
| group_by = cls.group_by_to_index(index, group_by, def_lvl_name) | |
| codes, uniques = pd.factorize(group_by) | |
| if not isinstance(uniques, pd.Index): | |
| new_index = pd.Index(uniques) | |
| else: | |
| new_index = uniques | |
| if isinstance(group_by, pd.MultiIndex): | |
| new_index.names = group_by.names | |
| elif isinstance(group_by, (pd.Index, pd.Series)): | |
| new_index.name = group_by.name | |
| return codes, new_index | |
| @classmethod | |
| def iter_group_lens(cls, group_lens: tp.GroupLens) -> tp.Iterator[tp.GroupIdxs]: | |
| """Iterate over indices of each group in group lengths.""" | |
| group_end_idxs = np.cumsum(group_lens) | |
| group_start_idxs = group_end_idxs - group_lens | |
| for group in range(len(group_lens)): | |
| from_col = group_start_idxs[group] | |
| to_col = group_end_idxs[group] | |
| yield np.arange(from_col, to_col) | |
| @classmethod | |
| def iter_group_map(cls, group_map: tp.GroupMap) -> tp.Iterator[tp.GroupIdxs]: | |
| """Iterate over indices of each group in a group map.""" | |
| group_idxs, group_lens = group_map | |
| group_start = 0 | |
| group_end = 0 | |
| for group in range(len(group_lens)): | |
| group_len = group_lens[group] | |
| group_end += group_len | |
| yield group_idxs[group_start:group_end] | |
| group_start += group_len | |
| @classmethod | |
| def from_pd_group_by( | |
| cls: tp.Type[GrouperT], | |
| pd_group_by: tp.PandasGroupByLike, | |
| **kwargs, | |
| ) -> GrouperT: | |
| """Build a `Grouper` instance from a pandas `GroupBy` object. | |
| Indices are stored under `index` and group labels under `group_by`.""" | |
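| # A minimal usage sketch (not part of the original source); the series is hypothetical: | |
| # sr = pd.Series([1, 2, 3], index=["x", "y", "x"]) | |
| # grouper = Grouper.from_pd_group_by(sr.groupby(level=0)) | |
| # grouper.get_index()  # Index(['x', 'y'], dtype='object', name='group') | |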
| from vectorbtpro.base.merging import concat_arrays | |
| if not isinstance(pd_group_by, (PandasGroupBy, PandasResampler)): | |
| raise TypeError("pd_group_by must be an instance of GroupBy or Resampler") | |
| indices = list(pd_group_by.indices.values()) | |
| group_lens = np.asarray(list(map(len, indices))) | |
| groups = np.full(int(np.sum(group_lens)), 0, dtype=int_) | |
| group_start_idxs = np.cumsum(group_lens)[1:] - group_lens[1:] | |
| groups[group_start_idxs] = 1 | |
| groups = np.cumsum(groups) | |
| index = pd.Index(concat_arrays(indices)) | |
| group_by = pd.Index(list(pd_group_by.indices.keys()), name="group")[groups] | |
| return cls( | |
| index=index, | |
| group_by=group_by, | |
| **kwargs, | |
| ) | |
| def __init__( | |
| self, | |
| index: tp.Index, | |
| group_by: tp.GroupByLike = None, | |
| def_lvl_name: tp.Hashable = "group", | |
| allow_enable: bool = True, | |
| allow_disable: bool = True, | |
| allow_modify: bool = True, | |
| **kwargs, | |
| ) -> None: | |
| if not isinstance(index, pd.Index): | |
| index = pd.Index(index) | |
| if group_by is None or group_by is False: | |
| group_by = None | |
| else: | |
| group_by = self.group_by_to_index(index, group_by, def_lvl_name=def_lvl_name) | |
| self._index = index | |
| self._group_by = group_by | |
| self._def_lvl_name = def_lvl_name | |
| self._allow_enable = allow_enable | |
| self._allow_disable = allow_disable | |
| self._allow_modify = allow_modify | |
| Configured.__init__( | |
| self, | |
| index=index, | |
| group_by=group_by, | |
| def_lvl_name=def_lvl_name, | |
| allow_enable=allow_enable, | |
| allow_disable=allow_disable, | |
| allow_modify=allow_modify, | |
| **kwargs, | |
| ) | |
| @property | |
| def index(self) -> tp.Index: | |
| """Original index.""" | |
| return self._index | |
| @property | |
| def group_by(self) -> GroupByT: | |
| """Mapper for grouping.""" | |
| return self._group_by | |
| @property | |
| def def_lvl_name(self) -> tp.Hashable: | |
| """Default level name.""" | |
| return self._def_lvl_name | |
| @property | |
| def allow_enable(self) -> bool: | |
| """Whether to allow enabling grouping.""" | |
| return self._allow_enable | |
| @property | |
| def allow_disable(self) -> bool: | |
| """Whether to allow disabling grouping.""" | |
| return self._allow_disable | |
| @property | |
| def allow_modify(self) -> bool: | |
| """Whether to allow changing groups.""" | |
| return self._allow_modify | |
| def is_grouped(self, group_by: tp.GroupByLike = None) -> bool: | |
| """Check whether index are grouped.""" | |
| if group_by is False: | |
| return False | |
| if group_by is None: | |
| group_by = self.group_by | |
| return group_by is not None | |
| def is_grouping_enabled(self, group_by: tp.GroupByLike = None) -> bool: | |
| """Check whether grouping has been enabled.""" | |
| return self.group_by is None and self.is_grouped(group_by=group_by) | |
| def is_grouping_disabled(self, group_by: tp.GroupByLike = None) -> bool: | |
| """Check whether grouping has been disabled.""" | |
| return self.group_by is not None and not self.is_grouped(group_by=group_by) | |
| @cached_method(whitelist=True) | |
| def is_grouping_modified(self, group_by: tp.GroupByLike = None) -> bool: | |
| """Check whether grouping has been modified. | |
| Doesn't care if grouping labels have been changed.""" | |
| if group_by is None or (group_by is False and self.group_by is None): | |
| return False | |
| group_by = self.group_by_to_index(self.index, group_by, def_lvl_name=self.def_lvl_name) | |
| if isinstance(group_by, pd.Index) and isinstance(self.group_by, pd.Index): | |
| if not pd.Index.equals(group_by, self.group_by): | |
| groups1 = self.group_by_to_groups_and_index( | |
| self.index, | |
| group_by, | |
| def_lvl_name=self.def_lvl_name, | |
| )[0] | |
| groups2 = self.group_by_to_groups_and_index( | |
| self.index, | |
| self.group_by, | |
| def_lvl_name=self.def_lvl_name, | |
| )[0] | |
| if not np.array_equal(groups1, groups2): | |
| return True | |
| return False | |
| return True | |
| @cached_method(whitelist=True) | |
| def is_grouping_changed(self, group_by: tp.GroupByLike = None) -> bool: | |
| """Check whether grouping has been changed in any way.""" | |
| if group_by is None or (group_by is False and self.group_by is None): | |
| return False | |
| if isinstance(group_by, pd.Index) and isinstance(self.group_by, pd.Index): | |
| if pd.Index.equals(group_by, self.group_by): | |
| return False | |
| return True | |
| def is_group_count_changed(self, group_by: tp.GroupByLike = None) -> bool: | |
| """Check whether the number of groups has changed.""" | |
| if group_by is None or (group_by is False and self.group_by is None): | |
| return False | |
| if isinstance(group_by, pd.Index) and isinstance(self.group_by, pd.Index): | |
| return len(group_by) != len(self.group_by) | |
| return True | |
| def check_group_by( | |
| self, | |
| group_by: tp.GroupByLike = None, | |
| allow_enable: tp.Optional[bool] = None, | |
| allow_disable: tp.Optional[bool] = None, | |
| allow_modify: tp.Optional[bool] = None, | |
| ) -> None: | |
| """Check passed `group_by` object against restrictions.""" | |
| if allow_enable is None: | |
| allow_enable = self.allow_enable | |
| if allow_disable is None: | |
| allow_disable = self.allow_disable | |
| if allow_modify is None: | |
| allow_modify = self.allow_modify | |
| if self.is_grouping_enabled(group_by=group_by): | |
| if not allow_enable: | |
| raise ValueError("Enabling grouping is not allowed") | |
| elif self.is_grouping_disabled(group_by=group_by): | |
| if not allow_disable: | |
| raise ValueError("Disabling grouping is not allowed") | |
| elif self.is_grouping_modified(group_by=group_by): | |
| if not allow_modify: | |
| raise ValueError("Modifying groups is not allowed") | |
| def resolve_group_by(self, group_by: tp.GroupByLike = None, **kwargs) -> GroupByT: | |
| """Resolve `group_by` from either object variable or keyword argument.""" | |
| if group_by is None: | |
| group_by = self.group_by | |
| if group_by is False and self.group_by is None: | |
| group_by = None | |
| self.check_group_by(group_by=group_by, **kwargs) | |
| return self.group_by_to_index(self.index, group_by, def_lvl_name=self.def_lvl_name) | |
| @cached_method(whitelist=True) | |
| def get_groups_and_index(self, group_by: tp.GroupByLike = None, **kwargs) -> tp.Tuple[tp.Array1d, tp.Index]: | |
| """See `Grouper.group_by_to_groups_and_index`.""" | |
| group_by = self.resolve_group_by(group_by=group_by, **kwargs) | |
| return self.group_by_to_groups_and_index(self.index, group_by, def_lvl_name=self.def_lvl_name) | |
| def get_groups(self, **kwargs) -> tp.Array1d: | |
| """Return groups array.""" | |
| return self.get_groups_and_index(**kwargs)[0] | |
| def get_index(self, **kwargs) -> tp.Index: | |
| """Return grouped index.""" | |
| return self.get_groups_and_index(**kwargs)[1] | |
| get_grouped_index = get_index | |
| @property | |
| def grouped_index(self) -> tp.Index: | |
| """Grouped index.""" | |
| return self.get_grouped_index() | |
| def get_stretched_index(self, **kwargs) -> tp.Index: | |
| """Return stretched index.""" | |
| groups, index = self.get_groups_and_index(**kwargs) | |
| return index[groups] | |
| def get_group_count(self, **kwargs) -> int: | |
| """Get number of groups.""" | |
| return len(self.get_index(**kwargs)) | |
| @cached_method(whitelist=True) | |
| def is_sorted(self, group_by: tp.GroupByLike = None, **kwargs) -> bool: | |
| """Return whether groups are monolithic, sorted.""" | |
| group_by = self.resolve_group_by(group_by=group_by, **kwargs) | |
| groups = self.get_groups(group_by=group_by) | |
| return is_sorted(groups) | |
| @cached_method(whitelist=True) | |
| def get_group_lens(self, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None, **kwargs) -> tp.GroupLens: | |
| """See `vectorbtpro.base.grouping.nb.get_group_lens_nb`.""" | |
| group_by = self.resolve_group_by(group_by=group_by, **kwargs) | |
| if group_by is None or group_by is False: # no grouping | |
| return np.full(len(self.index), 1) | |
| if not self.is_sorted(group_by=group_by): | |
| raise ValueError("group_by must form monolithic, sorted groups") | |
| groups = self.get_groups(group_by=group_by) | |
| func = jit_reg.resolve_option(nb.get_group_lens_nb, jitted) | |
| return func(groups) | |
| def get_group_start_idxs(self, **kwargs) -> tp.Array1d: | |
| """Get first index of each group as an array.""" | |
| group_lens = self.get_group_lens(**kwargs) | |
| return np.cumsum(group_lens) - group_lens | |
| def get_group_end_idxs(self, **kwargs) -> tp.Array1d: | |
| """Get end index of each group as an array.""" | |
| group_lens = self.get_group_lens(**kwargs) | |
| return np.cumsum(group_lens) | |
| @cached_method(whitelist=True) | |
| def get_group_map(self, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None, **kwargs) -> tp.GroupMap: | |
| """See get_group_map_nb.""" | |
| group_by = self.resolve_group_by(group_by=group_by, **kwargs) | |
| if group_by is None or group_by is False: # no grouping | |
| return np.arange(len(self.index)), np.full(len(self.index), 1) | |
| groups, new_index = self.get_groups_and_index(group_by=group_by) | |
| func = jit_reg.resolve_option(nb.get_group_map_nb, jitted) | |
| return func(groups, len(new_index)) | |
| def iter_group_idxs(self, **kwargs) -> tp.Iterator[tp.GroupIdxs]: | |
| """Iterate over indices of each group.""" | |
| group_map = self.get_group_map(**kwargs) | |
| return self.iter_group_map(group_map) | |
| def iter_groups( | |
| self, | |
| key_as_index: bool = False, | |
| **kwargs, | |
| ) -> tp.Iterator[tp.Tuple[tp.Union[tp.Hashable, pd.Index], tp.GroupIdxs]]: | |
| """Iterate over groups and their indices.""" | |
| index = self.get_index(**kwargs) | |
| for group, group_idxs in enumerate(self.iter_group_idxs(**kwargs)): | |
| if key_as_index: | |
| yield index[[group]], group_idxs | |
| else: | |
| yield index[group], group_idxs | |
| def select_groups(self, group_idxs: tp.Array1d, jitted: tp.JittedOption = None) -> tp.Tuple[tp.Array1d, tp.Array1d]: | |
| """Select groups. | |
| Returns indices and new group array. Automatically decides whether to use group lengths or group map.""" | |
| from vectorbtpro.base.reshaping import to_1d_array | |
| if self.is_sorted(): | |
| func = jit_reg.resolve_option(nb.group_lens_select_nb, jitted) | |
| new_group_idxs, new_groups = func(self.get_group_lens(), to_1d_array(group_idxs)) # faster | |
| else: | |
| func = jit_reg.resolve_option(nb.group_map_select_nb, jitted) | |
| new_group_idxs, new_groups = func(self.get_group_map(), to_1d_array(group_idxs)) # more flexible | |
| return new_group_idxs, new_groups | |
| </file> | |
| <file path="base/grouping/nb.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Numba-compiled functions for grouping.""" | |
| import numpy as np | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro._dtypes import * | |
| from vectorbtpro.registries.jit_registry import register_jitted | |
| __all__ = [] | |
| GroupByT = tp.Union[None, bool, tp.Index] | |
| @register_jitted(cache=True) | |
| def get_group_lens_nb(groups: tp.Array1d) -> tp.GroupLens: | |
| """Return the count per group. | |
| !!! note | |
| Columns must form monolithic, sorted groups. For unsorted groups, use `get_group_map_nb`.""" | |
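| # For example (illustrative, not part of the original source): | |
| # get_group_lens_nb(np.array([0, 0, 1, 1, 1, 2]))  # array([2, 3, 1]) | |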
| result = np.empty(groups.shape[0], dtype=int_) | |
| j = 0 | |
| last_group = -1 | |
| group_len = 0 | |
| for i in range(groups.shape[0]): | |
| cur_group = groups[i] | |
| if cur_group < last_group: | |
| raise ValueError("Columns must form monolithic, sorted groups") | |
| if cur_group != last_group: | |
| if last_group != -1: | |
| # Process previous group | |
| result[j] = group_len | |
| j += 1 | |
| group_len = 0 | |
| last_group = cur_group | |
| group_len += 1 | |
| if i == groups.shape[0] - 1: | |
| # Process last group | |
| result[j] = group_len | |
| j += 1 | |
| group_len = 0 | |
| return result[:j] | |
| @register_jitted(cache=True) | |
| def get_group_map_nb(groups: tp.Array1d, n_groups: int) -> tp.GroupMap: | |
| """Build the map between groups and indices. | |
| Returns an array with indices segmented by group and an array with group lengths. | |
| Works well for unsorted group arrays.""" | |
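| # For example (illustrative, not part of the original source): | |
| # get_group_map_nb(np.array([1, 0, 1, 0]), 2)  # (array([1, 3, 0, 2]), array([2, 2])) | |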
| group_lens_out = np.full(n_groups, 0, dtype=int_) | |
| for g in range(groups.shape[0]): | |
| group = groups[g] | |
| group_lens_out[group] += 1 | |
| group_start_idxs = np.cumsum(group_lens_out) - group_lens_out | |
| group_idxs_out = np.empty((groups.shape[0],), dtype=int_) | |
| group_i = np.full(n_groups, 0, dtype=int_) | |
| for g in range(groups.shape[0]): | |
| group = groups[g] | |
| group_idxs_out[group_start_idxs[group] + group_i[group]] = g | |
| group_i[group] += 1 | |
| return group_idxs_out, group_lens_out | |
| @register_jitted(cache=True) | |
| def group_lens_select_nb(group_lens: tp.GroupLens, new_groups: tp.Array1d) -> tp.Tuple[tp.Array1d, tp.Array1d]: | |
| """Perform indexing on a sorted array using group lengths. | |
| Returns indices of elements corresponding to groups in `new_groups` and a new group array.""" | |
| group_end_idxs = np.cumsum(group_lens) | |
| group_start_idxs = group_end_idxs - group_lens | |
| n_values = np.sum(group_lens[new_groups]) | |
| indices_out = np.empty(n_values, dtype=int_) | |
| group_arr_out = np.empty(n_values, dtype=int_) | |
| j = 0 | |
| for c in range(new_groups.shape[0]): | |
| from_r = group_start_idxs[new_groups[c]] | |
| to_r = group_end_idxs[new_groups[c]] | |
| if from_r == to_r: | |
| continue | |
| rang = np.arange(from_r, to_r) | |
| indices_out[j : j + rang.shape[0]] = rang | |
| group_arr_out[j : j + rang.shape[0]] = c | |
| j += rang.shape[0] | |
| return indices_out, group_arr_out | |
| @register_jitted(cache=True) | |
| def group_map_select_nb(group_map: tp.GroupMap, new_groups: tp.Array1d) -> tp.Tuple[tp.Array1d, tp.Array1d]: | |
| """Perform indexing using group map.""" | |
| group_idxs, group_lens = group_map | |
| group_start_idxs = np.cumsum(group_lens) - group_lens | |
| total_count = np.sum(group_lens[new_groups]) | |
| indices_out = np.empty(total_count, dtype=int_) | |
| group_arr_out = np.empty(total_count, dtype=int_) | |
| j = 0 | |
| for new_group_i in range(len(new_groups)): | |
| new_group = new_groups[new_group_i] | |
| group_len = group_lens[new_group] | |
| if group_len == 0: | |
| continue | |
| group_start_idx = group_start_idxs[new_group] | |
| idxs = group_idxs[group_start_idx : group_start_idx + group_len] | |
| indices_out[j : j + group_len] = idxs | |
| group_arr_out[j : j + group_len] = new_group_i | |
| j += group_len | |
| return indices_out, group_arr_out | |
| @register_jitted(cache=True) | |
| def group_by_evenly_nb(n: int, n_splits: int) -> tp.Array1d: | |
| """Get `group_by` from evenly splitting a space of values.""" | |
| out = np.empty(n, dtype=int_) | |
| for i in range(n): | |
| out[i] = i * n_splits // n + n_splits // (2 * n) | |
| return out | |
| </file> | |
| <file path="base/resampling/__init__.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Modules with classes and utilities for resampling.""" | |
| from typing import TYPE_CHECKING | |
| if TYPE_CHECKING: | |
| from vectorbtpro.base.resampling.base import * | |
| from vectorbtpro.base.resampling.nb import * | |
| </file> | |
| <file path="base/resampling/base.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Base classes and functions for resampling.""" | |
| import numpy as np | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.base.indexes import repeat_index | |
| from vectorbtpro.base.resampling import nb | |
| from vectorbtpro.registries.jit_registry import jit_reg | |
| from vectorbtpro.utils import checks, datetime_ as dt | |
| from vectorbtpro.utils.config import Configured | |
| from vectorbtpro.utils.decorators import cached_property, hybrid_method | |
| from vectorbtpro.utils.warnings_ import warn | |
| __all__ = [ | |
| "Resampler", | |
| ] | |
| ResamplerT = tp.TypeVar("ResamplerT", bound="Resampler") | |
| class Resampler(Configured): | |
| """Class that exposes methods to resample index. | |
| Args: | |
| source_index (index_like): Index being resampled. | |
| target_index (index_like): Index resulting from resampling. | |
| source_freq (frequency_like or bool): Frequency or date offset of the source index. | |
| Set to False to force-set the frequency to None. | |
| target_freq (frequency_like or bool): Frequency or date offset of the target index. | |
| Set to False to force-set the frequency to None. | |
| silence_warnings (bool): Whether to silence all warnings.""" | |
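| # A minimal usage sketch (not part of the original source); the index values are hypothetical: | |
| # source_index = pd.date_range("2020-01-01", periods=4, freq="D") | |
| # target_index = pd.date_range("2020-01-01", periods=2, freq="2D") | |
| # resampler = Resampler(source_index, target_index) | |
| # resampler.map_to_target_index(return_index=False)  # array([0, 0, 1, 1]) | |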
| def __init__( | |
| self, | |
| source_index: tp.IndexLike, | |
| target_index: tp.IndexLike, | |
| source_freq: tp.Union[None, bool, tp.FrequencyLike] = None, | |
| target_freq: tp.Union[None, bool, tp.FrequencyLike] = None, | |
| silence_warnings: tp.Optional[bool] = None, | |
| **kwargs, | |
| ) -> None: | |
| source_index = dt.prepare_dt_index(source_index) | |
| target_index = dt.prepare_dt_index(target_index) | |
| infer_source_freq = True | |
| if isinstance(source_freq, bool): | |
| if not source_freq: | |
| infer_source_freq = False | |
| source_freq = None | |
| infer_target_freq = True | |
| if isinstance(target_freq, bool): | |
| if not target_freq: | |
| infer_target_freq = False | |
| target_freq = None | |
| if infer_source_freq: | |
| source_freq = dt.infer_index_freq(source_index, freq=source_freq) | |
| if infer_target_freq: | |
| target_freq = dt.infer_index_freq(target_index, freq=target_freq) | |
| self._source_index = source_index | |
| self._target_index = target_index | |
| self._source_freq = source_freq | |
| self._target_freq = target_freq | |
| self._silence_warnings = silence_warnings | |
| Configured.__init__( | |
| self, | |
| source_index=source_index, | |
| target_index=target_index, | |
| source_freq=source_freq, | |
| target_freq=target_freq, | |
| silence_warnings=silence_warnings, | |
| **kwargs, | |
| ) | |
| @classmethod | |
| def from_pd_resampler( | |
| cls: tp.Type[ResamplerT], | |
| pd_resampler: tp.PandasResampler, | |
| source_freq: tp.Optional[tp.FrequencyLike] = None, | |
| silence_warnings: bool = True, | |
| ) -> ResamplerT: | |
| """Build `Resampler` from | |
| [pandas.core.resample.Resampler](https://pandas.pydata.org/docs/reference/resampling.html). | |
| """ | |
| target_index = pd_resampler.count().index | |
| return cls( | |
| source_index=pd_resampler.obj.index, | |
| target_index=target_index, | |
| source_freq=source_freq, | |
| target_freq=None, | |
| silence_warnings=silence_warnings, | |
| ) | |
| @classmethod | |
| def from_pd_resample( | |
| cls: tp.Type[ResamplerT], | |
| source_index: tp.IndexLike, | |
| *args, | |
| source_freq: tp.Optional[tp.FrequencyLike] = None, | |
| silence_warnings: bool = True, | |
| **kwargs, | |
| ) -> ResamplerT: | |
| """Build `Resampler` from | |
| [pandas.DataFrame.resample](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.resample.html). | |
| """ | |
| pd_resampler = pd.Series(index=source_index, dtype=object).resample(*args, **kwargs) | |
| return cls.from_pd_resampler(pd_resampler, source_freq=source_freq, silence_warnings=silence_warnings) | |
| @classmethod | |
| def from_date_range( | |
| cls: tp.Type[ResamplerT], | |
| source_index: tp.IndexLike, | |
| *args, | |
| source_freq: tp.Optional[tp.FrequencyLike] = None, | |
| silence_warnings: tp.Optional[bool] = None, | |
| **kwargs, | |
| ) -> ResamplerT: | |
| """Build `Resampler` from `vectorbtpro.utils.datetime_.date_range`.""" | |
| target_index = dt.date_range(*args, **kwargs) | |
| return cls( | |
| source_index=source_index, | |
| target_index=target_index, | |
| source_freq=source_freq, | |
| target_freq=None, | |
| silence_warnings=silence_warnings, | |
| ) | |
| @property | |
| def source_index(self) -> tp.Index: | |
| """Index being resampled.""" | |
| return self._source_index | |
| @property | |
| def target_index(self) -> tp.Index: | |
| """Index resulted from resampling.""" | |
| return self._target_index | |
| @property | |
| def source_freq(self) -> tp.AnyPandasFrequency: | |
| """Frequency or date offset of the source index.""" | |
| return self._source_freq | |
| @property | |
| def target_freq(self) -> tp.AnyPandasFrequency: | |
| """Frequency or date offset of the target index.""" | |
| return self._target_freq | |
| @property | |
| def silence_warnings(self) -> bool: | |
| """Frequency or date offset of the target index.""" | |
| from vectorbtpro._settings import settings | |
| resampling_cfg = settings["resampling"] | |
| silence_warnings = self._silence_warnings | |
| if silence_warnings is None: | |
| silence_warnings = resampling_cfg["silence_warnings"] | |
| return silence_warnings | |
| def get_np_source_freq(self, silence_warnings: tp.Optional[bool] = None) -> tp.AnyPandasFrequency: | |
| """Frequency or date offset of the source index in NumPy format.""" | |
| if silence_warnings is None: | |
| silence_warnings = self.silence_warnings | |
| warned = False | |
| source_freq = self.source_freq | |
| if source_freq is not None: | |
| if not isinstance(source_freq, (int, float)): | |
| try: | |
| source_freq = dt.to_timedelta64(source_freq) | |
| except ValueError as e: | |
| if not silence_warnings: | |
| warn(f"Cannot convert {source_freq} to np.timedelta64. Setting to None.") | |
| warned = True | |
| source_freq = None | |
| if source_freq is None: | |
| if not warned and not silence_warnings: | |
| warn("Using right bound of source index without frequency. Set source frequency.") | |
| return source_freq | |
| def get_np_target_freq(self, silence_warnings: tp.Optional[bool] = None) -> tp.AnyPandasFrequency: | |
| """Frequency or date offset of the target index in NumPy format.""" | |
| if silence_warnings is None: | |
| silence_warnings = self.silence_warnings | |
| warned = False | |
| target_freq = self.target_freq | |
| if target_freq is not None: | |
| if not isinstance(target_freq, (int, float)): | |
| try: | |
| target_freq = dt.to_timedelta64(target_freq) | |
| except ValueError as e: | |
| if not silence_warnings: | |
| warn(f"Cannot convert {target_freq} to np.timedelta64. Setting to None.") | |
| warned = True | |
| target_freq = None | |
| if target_freq is None: | |
| if not warned and not silence_warnings: | |
| warn("Using right bound of target index without frequency. Set target frequency.") | |
| return target_freq | |
| @classmethod | |
| def get_lbound_index(cls, index: tp.Index, freq: tp.AnyPandasFrequency = None) -> tp.Index: | |
| """Get the left bound of a datetime index. | |
| If `freq` is None, calculates the leftmost bound.""" | |
| index = dt.prepare_dt_index(index) | |
| checks.assert_instance_of(index, pd.DatetimeIndex) | |
| if freq is not None: | |
| return index.shift(-1, freq=freq) + pd.Timedelta(1, "ns") | |
| min_ts = pd.DatetimeIndex([pd.Timestamp.min.tz_localize(index.tz)]) | |
| return (index[:-1] + pd.Timedelta(1, "ns")).append(min_ts) | |
| @classmethod | |
| def get_rbound_index(cls, index: tp.Index, freq: tp.AnyPandasFrequency = None) -> tp.Index: | |
| """Get the right bound of a datetime index. | |
| If `freq` is None, calculates the rightmost bound.""" | |
| index = dt.prepare_dt_index(index) | |
| checks.assert_instance_of(index, pd.DatetimeIndex) | |
| if freq is not None: | |
| return index.shift(1, freq=freq) - pd.Timedelta(1, "ns") | |
| max_ts = pd.DatetimeIndex([pd.Timestamp.max.tz_localize(index.tz)]) | |
| return (index[1:] - pd.Timedelta(1, "ns")).append(max_ts) | |
| @cached_property | |
| def source_lbound_index(self) -> tp.Index: | |
| """Get the left bound of the source datetime index.""" | |
| return self.get_lbound_index(self.source_index, freq=self.source_freq) | |
| @cached_property | |
| def source_rbound_index(self) -> tp.Index: | |
| """Get the right bound of the source datetime index.""" | |
| return self.get_rbound_index(self.source_index, freq=self.source_freq) | |
| @cached_property | |
| def target_lbound_index(self) -> tp.Index: | |
| """Get the left bound of the target datetime index.""" | |
| return self.get_lbound_index(self.target_index, freq=self.target_freq) | |
| @cached_property | |
| def target_rbound_index(self) -> tp.Index: | |
| """Get the right bound of the target datetime index.""" | |
| return self.get_rbound_index(self.target_index, freq=self.target_freq) | |
| def map_to_target_index( | |
| self, | |
| before: bool = False, | |
| raise_missing: bool = True, | |
| return_index: bool = True, | |
| jitted: tp.JittedOption = None, | |
| silence_warnings: tp.Optional[bool] = None, | |
| ) -> tp.Union[tp.Array1d, tp.Index]: | |
| """See `vectorbtpro.base.resampling.nb.map_to_target_index_nb`.""" | |
| target_freq = self.get_np_target_freq(silence_warnings=silence_warnings) | |
| func = jit_reg.resolve_option(nb.map_to_target_index_nb, jitted) | |
| mapped_arr = func( | |
| self.source_index.values, | |
| self.target_index.values, | |
| target_freq=target_freq, | |
| before=before, | |
| raise_missing=raise_missing, | |
| ) | |
| if return_index: | |
| nan_mask = mapped_arr == -1 | |
| if nan_mask.any(): | |
| mapped_index = self.source_index.to_series().copy() | |
| mapped_index[nan_mask] = np.nan | |
| mapped_index[~nan_mask] = self.target_index[mapped_arr] | |
| mapped_index = pd.Index(mapped_index) | |
| else: | |
| mapped_index = self.target_index[mapped_arr] | |
| return mapped_index | |
| return mapped_arr | |
| def index_difference( | |
| self, | |
| reverse: bool = False, | |
| return_index: bool = True, | |
| jitted: tp.JittedOption = None, | |
| ) -> tp.Union[tp.Array1d, tp.Index]: | |
| """See `vectorbtpro.base.resampling.nb.index_difference_nb`.""" | |
| func = jit_reg.resolve_option(nb.index_difference_nb, jitted) | |
| if reverse: | |
| mapped_arr = func(self.target_index.values, self.source_index.values) | |
| else: | |
| mapped_arr = func(self.source_index.values, self.target_index.values) | |
| if return_index: | |
| return self.target_index[mapped_arr] | |
| return mapped_arr | |
| def map_index_to_source_ranges( | |
| self, | |
| before: bool = False, | |
| jitted: tp.JittedOption = None, | |
| silence_warnings: tp.Optional[bool] = None, | |
| ) -> tp.Tuple[tp.Array1d, tp.Array1d]: | |
| """See `vectorbtpro.base.resampling.nb.map_index_to_source_ranges_nb`. | |
| If `Resampler.target_freq` is a date offset, sets it to None and gives a warning. | |
| Raises another warning if `target_freq` is None.""" | |
| target_freq = self.get_np_target_freq(silence_warnings=silence_warnings) | |
| func = jit_reg.resolve_option(nb.map_index_to_source_ranges_nb, jitted) | |
| return func( | |
| self.source_index.values, | |
| self.target_index.values, | |
| target_freq=target_freq, | |
| before=before, | |
| ) | |
| @hybrid_method | |
| def map_bounds_to_source_ranges( | |
| cls_or_self, | |
| source_index: tp.Optional[tp.IndexLike] = None, | |
| target_lbound_index: tp.Optional[tp.IndexLike] = None, | |
| target_rbound_index: tp.Optional[tp.IndexLike] = None, | |
| closed_lbound: bool = True, | |
| closed_rbound: bool = False, | |
| skip_not_found: bool = False, | |
| jitted: tp.JittedOption = None, | |
| ) -> tp.Tuple[tp.Array1d, tp.Array1d]: | |
| """See `vectorbtpro.base.resampling.nb.map_bounds_to_source_ranges_nb`. | |
| Either `target_lbound_index` or `target_rbound_index` must be set. | |
| Set `target_lbound_index` and `target_rbound_index` to 'pandas' to use | |
| `Resampler.get_lbound_index` and `Resampler.get_rbound_index` respectively. | |
| Also, both allow providing a single datetime string and will automatically broadcast | |
| to the `Resampler.target_index`.""" | |
| if not isinstance(cls_or_self, type): | |
| if target_lbound_index is None and target_rbound_index is None: | |
| raise ValueError("Either target_lbound_index or target_rbound_index must be set") | |
| if target_lbound_index is not None: | |
| if isinstance(target_lbound_index, str) and target_lbound_index.lower() == "pandas": | |
| target_lbound_index = cls_or_self.target_lbound_index | |
| else: | |
| target_lbound_index = dt.prepare_dt_index(target_lbound_index) | |
| target_rbound_index = cls_or_self.target_index | |
| if target_rbound_index is not None: | |
| target_lbound_index = cls_or_self.target_index | |
| if isinstance(target_rbound_index, str) and target_rbound_index.lower() == "pandas": | |
| target_rbound_index = cls_or_self.target_rbound_index | |
| else: | |
| target_rbound_index = dt.prepare_dt_index(target_rbound_index) | |
| if len(target_lbound_index) == 1 and len(target_rbound_index) > 1: | |
| target_lbound_index = repeat_index(target_lbound_index, len(target_rbound_index)) | |
| elif len(target_lbound_index) > 1 and len(target_rbound_index) == 1: | |
| target_rbound_index = repeat_index(target_rbound_index, len(target_lbound_index)) | |
| else: | |
| source_index = dt.prepare_dt_index(source_index) | |
| target_lbound_index = dt.prepare_dt_index(target_lbound_index) | |
| target_rbound_index = dt.prepare_dt_index(target_rbound_index) | |
| checks.assert_len_equal(target_rbound_index, target_lbound_index) | |
| func = jit_reg.resolve_option(nb.map_bounds_to_source_ranges_nb, jitted) | |
| return func( | |
| source_index.values, | |
| target_lbound_index.values, | |
| target_rbound_index.values, | |
| closed_lbound=closed_lbound, | |
| closed_rbound=closed_rbound, | |
| skip_not_found=skip_not_found, | |
| ) | |
| def resample_source_mask( | |
| self, | |
| source_mask: tp.ArrayLike, | |
| jitted: tp.JittedOption = None, | |
| silence_warnings: tp.Optional[bool] = None, | |
| ) -> tp.Array1d: | |
| """See `vectorbtpro.base.resampling.nb.resample_source_mask_nb`.""" | |
| from vectorbtpro.base.reshaping import broadcast_array_to | |
| if silence_warnings is None: | |
| silence_warnings = self.silence_warnings | |
| source_mask = broadcast_array_to(source_mask, len(self.source_index)) | |
| source_freq = self.get_np_source_freq(silence_warnings=silence_warnings) | |
| target_freq = self.get_np_target_freq(silence_warnings=silence_warnings) | |
| func = jit_reg.resolve_option(nb.resample_source_mask_nb, jitted) | |
| return func( | |
| source_mask, | |
| self.source_index.values, | |
| self.target_index.values, | |
| source_freq, | |
| target_freq, | |
| ) | |
| def last_before_target_index(self, incl_source: bool = True, jitted: tp.JittedOption = None) -> tp.Array1d: | |
| """See `vectorbtpro.base.resampling.nb.last_before_target_index_nb`.""" | |
| func = jit_reg.resolve_option(nb.last_before_target_index_nb, jitted) | |
| return func(self.source_index.values, self.target_index.values, incl_source=incl_source) | |
| </file> | |
| <file path="base/resampling/nb.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Numba-compiled functions for resampling.""" | |
| import numpy as np | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro._dtypes import * | |
| from vectorbtpro.registries.jit_registry import register_jitted | |
| from vectorbtpro.utils.datetime_nb import d_td | |
| __all__ = [] | |
| @register_jitted(cache=True) | |
| def date_range_nb( | |
| start: np.datetime64, | |
| end: np.datetime64, | |
| freq: np.timedelta64 = d_td, | |
| incl_left: bool = True, | |
| incl_right: bool = True, | |
| ) -> tp.Array1d: | |
| """Generate a datetime index with nanosecond precision from a date range. | |
| Inspired by [pandas.date_range](https://pandas.pydata.org/docs/reference/api/pandas.date_range.html).""" | |
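| # For example (illustrative, not part of the original source): | |
| # date_range_nb(np.datetime64("2020-01-01"), np.datetime64("2020-01-03"), np.timedelta64(1, "D")) | |
| # -> the three timestamps 2020-01-01, 2020-01-02, 2020-01-03 as datetime64[ns] | |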
| values_len = int(np.floor((end - start) / freq)) + 1 | |
| values = np.empty(values_len, dtype="datetime64[ns]") | |
| for i in range(values_len): | |
| values[i] = start + i * freq | |
| if start == end: | |
| if not incl_left and not incl_right: | |
| values = values[1:-1] | |
| else: | |
| if not incl_left or not incl_right: | |
| if not incl_left and len(values) and values[0] == start: | |
| values = values[1:] | |
| if not incl_right and len(values) and values[-1] == end: | |
| values = values[:-1] | |
| return values | |
| @register_jitted(cache=True) | |
| def map_to_target_index_nb( | |
| source_index: tp.Array1d, | |
| target_index: tp.Array1d, | |
| target_freq: tp.Optional[tp.Scalar] = None, | |
| before: bool = False, | |
| raise_missing: bool = True, | |
| ) -> tp.Array1d: | |
| """Get the index of each from `source_index` in `target_index`. | |
| If `before` is True, applied on elements that come before and including that index. | |
| Otherwise, applied on elements that come after and including that index. | |
| If `raise_missing` is True, will throw an error if an index cannot be mapped. | |
| Otherwise, the element for that index becomes -1.""" | |
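| # For example (illustrative, not part of the original source), using plain integers as timestamps: | |
| # map_to_target_index_nb(np.array([0, 1, 2, 3]), np.array([0, 2]))  # array([0, 0, 1, 1]) | |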
| out = np.empty(len(source_index), dtype=int_) | |
| from_j = 0 | |
| for i in range(len(source_index)): | |
| if i > 0 and source_index[i] < source_index[i - 1]: | |
| raise ValueError("Source index must be increasing") | |
| if i > 0 and source_index[i] == source_index[i - 1]: | |
| out[i] = out[i - 1] | |
| found = False | |
| for j in range(from_j, len(target_index)): | |
| if j > 0 and target_index[j] <= target_index[j - 1]: | |
| raise ValueError("Target index must be strictly increasing") | |
| if target_freq is None: | |
| if before and source_index[i] <= target_index[j]: | |
| if j == 0 or target_index[j - 1] < source_index[i]: | |
| out[i] = from_j = j | |
| found = True | |
| break | |
| if not before and target_index[j] <= source_index[i]: | |
| if j == len(target_index) - 1 or source_index[i] < target_index[j + 1]: | |
| out[i] = from_j = j | |
| found = True | |
| break | |
| else: | |
| if before and target_index[j] - target_freq < source_index[i] <= target_index[j]: | |
| out[i] = from_j = j | |
| found = True | |
| break | |
| if not before and target_index[j] <= source_index[i] < target_index[j] + target_freq: | |
| out[i] = from_j = j | |
| found = True | |
| break | |
| if not found: | |
| if raise_missing: | |
| raise ValueError("Resampling failed: cannot map some source indices") | |
| out[i] = -1 | |
| return out | |
| @register_jitted(cache=True) | |
| def index_difference_nb( | |
| source_index: tp.Array1d, | |
| target_index: tp.Array1d, | |
| ) -> tp.Array1d: | |
| """Get the elements in `source_index` not present in `target_index`.""" | |
| out = np.empty(len(source_index), dtype=int_) | |
| from_j = 0 | |
| k = 0 | |
| for i in range(len(source_index)): | |
| if i > 0 and source_index[i] <= source_index[i - 1]: | |
| raise ValueError("Array index must be strictly increasing") | |
| found = False | |
| for j in range(from_j, len(target_index)): | |
| if j > 0 and target_index[j] <= target_index[j - 1]: | |
| raise ValueError("Target index must be strictly increasing") | |
| if source_index[i] < target_index[j]: | |
| break | |
| if source_index[i] == target_index[j]: | |
| from_j = j | |
| found = True | |
| break | |
| from_j = j | |
| if not found: | |
| out[k] = i | |
| k += 1 | |
| return out[:k] | |
| @register_jitted(cache=True) | |
| def map_index_to_source_ranges_nb( | |
| source_index: tp.Array1d, | |
| target_index: tp.Array1d, | |
| target_freq: tp.Optional[tp.Scalar] = None, | |
| before: bool = False, | |
| ) -> tp.Tuple[tp.Array1d, tp.Array1d]: | |
| """Get the source bounds that correspond to each target index. | |
| If `target_freq` is not None, the right bound is limited by the frequency in `target_freq`. | |
| Otherwise, the right bound is the next index in `target_index`. | |
| Returns two arrays, where the first contains the absolute start index (inclusive) and | |
| the second contains the absolute end index (exclusive) of each range. | |
| If an element cannot be mapped, the start and end of the range becomes -1. | |
| !!! note | |
| Both index arrays must be increasing. Repeating values are allowed.""" | |
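| # For example (illustrative, not part of the original source), using plain integers as timestamps: | |
| # map_index_to_source_ranges_nb(np.array([0, 1, 2, 3]), np.array([0, 2])) | |
| # -> (array([0, 2]), array([2, 4])) | |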
| range_starts_out = np.empty(len(target_index), dtype=int_) | |
| range_ends_out = np.empty(len(target_index), dtype=int_) | |
| to_j = 0 | |
| for i in range(len(target_index)): | |
| if i > 0 and target_index[i] < target_index[i - 1]: | |
| raise ValueError("Target index must be increasing") | |
| from_j = -1 | |
| for j in range(to_j, len(source_index)): | |
| if j > 0 and source_index[j] < source_index[j - 1]: | |
| raise ValueError("Array index must be increasing") | |
| found = False | |
| if target_freq is None: | |
| if before: | |
| if i == 0 and source_index[j] <= target_index[i]: | |
| found = True | |
| elif i > 0 and target_index[i - 1] < source_index[j] <= target_index[i]: | |
| found = True | |
| elif source_index[j] > target_index[i]: | |
| break | |
| else: | |
| if i == len(target_index) - 1 and target_index[i] <= source_index[j]: | |
| found = True | |
| elif i < len(target_index) - 1 and target_index[i] <= source_index[j] < target_index[i + 1]: | |
| found = True | |
| elif i < len(target_index) - 1 and source_index[j] >= target_index[i + 1]: | |
| break | |
| else: | |
| if before: | |
| if target_index[i] - target_freq < source_index[j] <= target_index[i]: | |
| found = True | |
| elif source_index[j] > target_index[i]: | |
| break | |
| else: | |
| if target_index[i] <= source_index[j] < target_index[i] + target_freq: | |
| found = True | |
| elif source_index[j] >= target_index[i] + target_freq: | |
| break | |
| if found: | |
| if from_j == -1: | |
| from_j = j | |
| to_j = j + 1 | |
| if from_j == -1: | |
| range_starts_out[i] = -1 | |
| range_ends_out[i] = -1 | |
| else: | |
| range_starts_out[i] = from_j | |
| range_ends_out[i] = to_j | |
| return range_starts_out, range_ends_out | |
| @register_jitted(cache=True) | |
| def map_bounds_to_source_ranges_nb( | |
| source_index: tp.Array1d, | |
| target_lbound_index: tp.Array1d, | |
| target_rbound_index: tp.Array1d, | |
| closed_lbound: bool = True, | |
| closed_rbound: bool = False, | |
| skip_not_found: bool = False, | |
| ) -> tp.Tuple[tp.Array1d, tp.Array1d]: | |
| """Get the source bounds that correspond to the target bounds. | |
| Returns two arrays, where the first contains the absolute start index (inclusive) and | |
| the second contains the absolute end index (exclusive) of each range. | |
| If an element cannot be mapped, the start and end of the range becomes -1. | |
| !!! note | |
| Both index arrays must be increasing. Repeating values are allowed.""" | |
| range_starts_out = np.empty(len(target_lbound_index), dtype=int_) | |
| range_ends_out = np.empty(len(target_lbound_index), dtype=int_) | |
| k = 0 | |
| to_j = 0 | |
| for i in range(len(target_lbound_index)): | |
| if i > 0 and target_lbound_index[i] < target_lbound_index[i - 1]: | |
| raise ValueError("Target left-bound index must be increasing") | |
| if i > 0 and target_rbound_index[i] < target_rbound_index[i - 1]: | |
| raise ValueError("Target right-bound index must be increasing") | |
| from_j = -1 | |
| for j in range(len(source_index)): | |
| if j > 0 and source_index[j] < source_index[j - 1]: | |
| raise ValueError("Array index must be increasing") | |
| found = False | |
| if closed_lbound and closed_rbound: | |
| if target_lbound_index[i] <= source_index[j] <= target_rbound_index[i]: | |
| found = True | |
| elif source_index[j] > target_rbound_index[i]: | |
| break | |
| elif closed_lbound: | |
| if target_lbound_index[i] <= source_index[j] < target_rbound_index[i]: | |
| found = True | |
| elif source_index[j] >= target_rbound_index[i]: | |
| break | |
| elif closed_rbound: | |
| if target_lbound_index[i] < source_index[j] <= target_rbound_index[i]: | |
| found = True | |
| elif source_index[j] > target_rbound_index[i]: | |
| break | |
| else: | |
| if target_lbound_index[i] < source_index[j] < target_rbound_index[i]: | |
| found = True | |
| elif source_index[j] >= target_rbound_index[i]: | |
| break | |
| if found: | |
| if from_j == -1: | |
| from_j = j | |
| to_j = j + 1 | |
| if skip_not_found: | |
| if from_j != -1: | |
| range_starts_out[k] = from_j | |
| range_ends_out[k] = to_j | |
| k += 1 | |
| else: | |
| if from_j == -1: | |
| range_starts_out[i] = -1 | |
| range_ends_out[i] = -1 | |
| else: | |
| range_starts_out[i] = from_j | |
| range_ends_out[i] = to_j | |
| if skip_not_found: | |
| return range_starts_out[:k], range_ends_out[:k] | |
| return range_starts_out, range_ends_out | |
| @register_jitted(cache=True) | |
| def resample_source_mask_nb( | |
| source_mask: tp.Array1d, | |
| source_index: tp.Array1d, | |
| target_index: tp.Array1d, | |
| source_freq: tp.Optional[tp.Scalar] = None, | |
| target_freq: tp.Optional[tp.Scalar] = None, | |
| ) -> tp.Array1d: | |
| """Resample a source mask to the target index. | |
| An output element becomes True only if the target bar is fully contained in the source bar. | |
| The source bar is represented by an uninterrupted sequence of True values in the source mask.""" | |
| out = np.full(len(target_index), False, dtype=np.bool_) | |
| from_j = 0 | |
| for i in range(len(target_index)): | |
| if i > 0 and target_index[i] < target_index[i - 1]: | |
| raise ValueError("Target index must be increasing") | |
| target_lbound = target_index[i] | |
| if target_freq is None: | |
| if i + 1 < len(target_index): | |
| target_rbound = target_index[i + 1] | |
| else: | |
| target_rbound = None | |
| else: | |
| target_rbound = target_index[i] + target_freq | |
| found_start = False | |
| for j in range(from_j, len(source_index)): | |
| if j > 0 and source_index[j] < source_index[j - 1]: | |
| raise ValueError("Source index must be increasing") | |
| source_lbound = source_index[j] | |
| if source_freq is None: | |
| if j + 1 < len(source_index): | |
| source_rbound = source_index[j + 1] | |
| else: | |
| source_rbound = None | |
| else: | |
| source_rbound = source_index[j] + source_freq | |
| if target_rbound is not None and target_rbound <= source_lbound: | |
| break | |
| if found_start or ( | |
| target_lbound >= source_lbound and (source_rbound is None or target_lbound < source_rbound) | |
| ): | |
| if not found_start: | |
| from_j = j | |
| found_start = True | |
| if not source_mask[j]: | |
| break | |
| if source_rbound is None or (target_rbound is not None and target_rbound <= source_rbound): | |
| out[i] = True | |
| break | |
| return out | |
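| # A minimal usage sketch, assuming integer indexes with explicit source and target frequencies: | |
| # | |
| #     import numpy as np | |
| #     source_index = np.array([0, 1, 2, 3], dtype=np.int64)  # e.g. four 1-unit source bars | |
| #     target_index = np.array([0, 2], dtype=np.int64)  # e.g. two 2-unit target bars | |
| #     source_mask = np.array([True, True, True, False]) | |
| #     resample_source_mask_nb(source_mask, source_index, target_index, 1, 2) | |
| #     # -> array([ True, False]): only the first target bar is fully covered by True source bars | |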
| @register_jitted(cache=True) | |
| def last_before_target_index_nb( | |
| source_index: tp.Array1d, | |
| target_index: tp.Array1d, | |
| incl_source: bool = True, | |
| incl_target: bool = False, | |
| ) -> tp.Array1d: | |
| """For each source index, find the position of the last source index between the original | |
| source index and the corresponding target index.""" | |
| out = np.empty(len(source_index), dtype=int_) | |
| last_j = -1 | |
| for i in range(len(source_index)): | |
| if i > 0 and source_index[i] < source_index[i - 1]: | |
| raise ValueError("Source index must be increasing") | |
| if i > 0 and target_index[i] < target_index[i - 1]: | |
| raise ValueError("Target index must be increasing") | |
| if source_index[i] > target_index[i]: | |
| raise ValueError("Target index must be equal to or greater than source index") | |
| if last_j == -1: | |
| from_i = i + 1 | |
| else: | |
| from_i = last_j | |
| if incl_source: | |
| last_j = i | |
| else: | |
| last_j = -1 | |
| for j in range(from_i, len(source_index)): | |
| if source_index[j] < target_index[i]: | |
| last_j = j | |
| elif incl_target and source_index[j] == target_index[i]: | |
| last_j = j | |
| else: | |
| break | |
| out[i] = last_j | |
| return out | |
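| # A minimal usage sketch, assuming integer source and target indexes of equal length: | |
| # | |
| #     import numpy as np | |
| #     source_index = np.array([0, 1, 2, 3], dtype=np.int64) | |
| #     target_index = np.array([2, 2, 3, 4], dtype=np.int64) | |
| #     last_before_target_index_nb(source_index, target_index) | |
| #     # -> array([1, 1, 2, 3]): position of the last source element before each target element | |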
| </file> | |
| <file path="base/__init__.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Modules with base classes and utilities for pandas objects, such as broadcasting.""" | |
| from typing import TYPE_CHECKING | |
| if TYPE_CHECKING: | |
| from vectorbtpro.base.grouping import * | |
| from vectorbtpro.base.resampling import * | |
| from vectorbtpro.base.accessors import * | |
| from vectorbtpro.base.chunking import * | |
| from vectorbtpro.base.combining import * | |
| from vectorbtpro.base.decorators import * | |
| from vectorbtpro.base.flex_indexing import * | |
| from vectorbtpro.base.indexes import * | |
| from vectorbtpro.base.indexing import * | |
| from vectorbtpro.base.merging import * | |
| from vectorbtpro.base.preparing import * | |
| from vectorbtpro.base.reshaping import * | |
| from vectorbtpro.base.wrapping import * | |
| </file> | |
| <file path="base/accessors.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Custom Pandas accessors for base operations with Pandas objects.""" | |
| import ast | |
| import inspect | |
| import numpy as np | |
| import pandas as pd | |
| from pandas.api.types import is_scalar | |
| from pandas.core.groupby import GroupBy as PandasGroupBy | |
| from pandas.core.resample import Resampler as PandasResampler | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.base import combining, reshaping, indexes | |
| from vectorbtpro.base.grouping.base import Grouper | |
| from vectorbtpro.base.indexes import IndexApplier | |
| from vectorbtpro.base.indexing import ( | |
| point_idxr_defaults, | |
| range_idxr_defaults, | |
| get_index_points, | |
| get_index_ranges, | |
| ) | |
| from vectorbtpro.base.resampling.base import Resampler | |
| from vectorbtpro.base.wrapping import ArrayWrapper, Wrapping | |
| from vectorbtpro.utils import checks, datetime_ as dt | |
| from vectorbtpro.utils.chunking import ChunkMeta, iter_chunk_meta, get_chunk_meta_key, ArraySelector, ArraySlicer | |
| from vectorbtpro.utils.config import merge_dicts, resolve_dict, Configured | |
| from vectorbtpro.utils.decorators import hybrid_property, hybrid_method | |
| from vectorbtpro.utils.execution import Task, execute | |
| from vectorbtpro.utils.eval_ import evaluate | |
| from vectorbtpro.utils.magic_decorators import attach_binary_magic_methods, attach_unary_magic_methods | |
| from vectorbtpro.utils.parsing import get_context_vars, get_func_arg_names | |
| from vectorbtpro.utils.template import substitute_templates | |
| from vectorbtpro.utils.warnings_ import warn | |
| if tp.TYPE_CHECKING: | |
| from vectorbtpro.data.base import Data as DataT | |
| else: | |
| DataT = "Data" | |
| if tp.TYPE_CHECKING: | |
| from vectorbtpro.generic.splitting.base import Splitter as SplitterT | |
| else: | |
| SplitterT = "Splitter" | |
| __all__ = ["BaseIDXAccessor", "BaseAccessor", "BaseSRAccessor", "BaseDFAccessor"] | |
| BaseIDXAccessorT = tp.TypeVar("BaseIDXAccessorT", bound="BaseIDXAccessor") | |
| class BaseIDXAccessor(Configured, IndexApplier): | |
| """Accessor on top of Index. | |
| Accessible via `pd.Index.vbt` and all child accessors.""" | |
| def __init__(self, obj: tp.Index, freq: tp.Optional[tp.FrequencyLike] = None, **kwargs) -> None: | |
| checks.assert_instance_of(obj, pd.Index) | |
| Configured.__init__(self, obj=obj, freq=freq, **kwargs) | |
| self._obj = obj | |
| self._freq = freq | |
| @property | |
| def obj(self) -> tp.Index: | |
| """Pandas object.""" | |
| return self._obj | |
| def get(self) -> tp.Index: | |
| """Get `IDXAccessor.obj`.""" | |
| return self.obj | |
| # ############# Index ############# # | |
| def to_ns(self) -> tp.Array1d: | |
| """Convert index to an 64-bit integer array. | |
| Timestamps will be converted to nanoseconds.""" | |
| return dt.to_ns(self.obj) | |
| def to_period(self, freq: tp.FrequencyLike, shift: bool = False) -> pd.PeriodIndex: | |
| """Convert index to period.""" | |
| index = self.obj | |
| if isinstance(index, pd.DatetimeIndex): | |
| index = index.tz_localize(None).to_period(freq) | |
| if shift: | |
| index = index.shift() | |
| if not isinstance(index, pd.PeriodIndex): | |
| raise TypeError(f"Cannot convert index of type {type(index)} to period") | |
| return index | |
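| # A minimal usage sketch, assuming a daily datetime index converted to monthly periods: | |
| # | |
| #     import pandas as pd | |
| #     import vectorbtpro as vbt  # registers the .vbt accessors | |
| #     index = pd.date_range("2020-01-01", periods=5, freq="D") | |
| #     index.vbt.to_period("M")  # -> PeriodIndex(['2020-01', '2020-01', ...], freq='M') | |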
| def to_period_ts(self, *args, **kwargs) -> pd.DatetimeIndex: | |
| """Convert index to period and then to timestamp.""" | |
| new_index = self.to_period(*args, **kwargs).to_timestamp() | |
| if self.obj.tz is not None: | |
| new_index = new_index.tz_localize(self.obj.tz) | |
| return new_index | |
| def to_period_ns(self, *args, **kwargs) -> tp.Array1d: | |
| """Convert index to period and then to an 64-bit integer array. | |
| Timestamps will be converted to nanoseconds.""" | |
| return dt.to_ns(self.to_period_ts(*args, **kwargs)) | |
| @classmethod | |
| def from_values(cls, *args, **kwargs) -> tp.Index: | |
| """See `vectorbtpro.base.indexes.index_from_values`.""" | |
| return indexes.index_from_values(*args, **kwargs) | |
| def repeat(self, *args, **kwargs) -> tp.Index: | |
| """See `vectorbtpro.base.indexes.repeat_index`.""" | |
| return indexes.repeat_index(self.obj, *args, **kwargs) | |
| def tile(self, *args, **kwargs) -> tp.Index: | |
| """See `vectorbtpro.base.indexes.tile_index`.""" | |
| return indexes.tile_index(self.obj, *args, **kwargs) | |
| @hybrid_method | |
| def stack( | |
| cls_or_self, | |
| *others: tp.Union[tp.IndexLike, "BaseIDXAccessor"], | |
| on_top: bool = False, | |
| **kwargs, | |
| ) -> tp.Index: | |
| """See `vectorbtpro.base.indexes.stack_indexes`. | |
| Set `on_top` to True to stack the second index on top of this one.""" | |
| others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others)) | |
| if isinstance(cls_or_self, type): | |
| objs = others | |
| else: | |
| if on_top: | |
| objs = (*others, cls_or_self.obj) | |
| else: | |
| objs = (cls_or_self.obj, *others) | |
| return indexes.stack_indexes(*objs, **kwargs) | |
| @hybrid_method | |
| def combine( | |
| cls_or_self, | |
| *others: tp.Union[tp.IndexLike, "BaseIDXAccessor"], | |
| on_top: bool = False, | |
| **kwargs, | |
| ) -> tp.Index: | |
| """See `vectorbtpro.base.indexes.combine_indexes`. | |
| Set `on_top` to True to stack the second index on top of this one.""" | |
| others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others)) | |
| if isinstance(cls_or_self, type): | |
| objs = others | |
| else: | |
| if on_top: | |
| objs = (*others, cls_or_self.obj) | |
| else: | |
| objs = (cls_or_self.obj, *others) | |
| return indexes.combine_indexes(*objs, **kwargs) | |
| @hybrid_method | |
| def concat(cls_or_self, *others: tp.Union[tp.IndexLike, "BaseIDXAccessor"], **kwargs) -> tp.Index: | |
| """See `vectorbtpro.base.indexes.concat_indexes`.""" | |
| others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others)) | |
| if isinstance(cls_or_self, type): | |
| objs = others | |
| else: | |
| objs = (cls_or_self.obj, *others) | |
| return indexes.concat_indexes(*objs, **kwargs) | |
| def apply_to_index( | |
| self: BaseIDXAccessorT, | |
| apply_func: tp.Callable, | |
| *args, | |
| **kwargs, | |
| ) -> tp.Index: | |
| return self.replace(obj=apply_func(self.obj, *args, **kwargs)).obj | |
| def align_to(self, *args, **kwargs) -> tp.IndexSlice: | |
| """See `vectorbtpro.base.indexes.align_index_to`.""" | |
| return indexes.align_index_to(self.obj, *args, **kwargs) | |
| @hybrid_method | |
| def align( | |
| cls_or_self, | |
| *others: tp.Union[tp.IndexLike, "BaseIDXAccessor"], | |
| **kwargs, | |
| ) -> tp.Tuple[tp.IndexSlice, ...]: | |
| """See `vectorbtpro.base.indexes.align_indexes`.""" | |
| others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others)) | |
| if isinstance(cls_or_self, type): | |
| objs = others | |
| else: | |
| objs = (cls_or_self.obj, *others) | |
| return indexes.align_indexes(*objs, **kwargs) | |
| def cross_with(self, *args, **kwargs) -> tp.Tuple[tp.IndexSlice, tp.IndexSlice]: | |
| """See `vectorbtpro.base.indexes.cross_index_with`.""" | |
| return indexes.cross_index_with(self.obj, *args, **kwargs) | |
| @hybrid_method | |
| def cross( | |
| cls_or_self, | |
| *others: tp.Union[tp.IndexLike, "BaseIDXAccessor"], | |
| **kwargs, | |
| ) -> tp.Tuple[tp.IndexSlice, ...]: | |
| """See `vectorbtpro.base.indexes.cross_indexes`.""" | |
| others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others)) | |
| if isinstance(cls_or_self, type): | |
| objs = others | |
| else: | |
| objs = (cls_or_self.obj, *others) | |
| return indexes.cross_indexes(*objs, **kwargs) | |
| x = cross | |
| def find_first_occurrence(self, *args, **kwargs) -> int: | |
| """See `vectorbtpro.base.indexes.find_first_occurrence`.""" | |
| return indexes.find_first_occurrence(self.obj, *args, **kwargs) | |
| # ############# Frequency ############# # | |
| @hybrid_method | |
| def get_freq( | |
| cls_or_self, | |
| index: tp.Optional[tp.Index] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| **kwargs, | |
| ) -> tp.Union[None, float, tp.PandasFrequency]: | |
| """Index frequency as `pd.Timedelta` or None if it cannot be converted.""" | |
| from vectorbtpro._settings import settings | |
| wrapping_cfg = settings["wrapping"] | |
| if not isinstance(cls_or_self, type): | |
| if index is None: | |
| index = cls_or_self.obj | |
| if freq is None: | |
| freq = cls_or_self._freq | |
| else: | |
| checks.assert_not_none(index, arg_name="index") | |
| if freq is None: | |
| freq = wrapping_cfg["freq"] | |
| try: | |
| return dt.infer_index_freq(index, freq=freq, **kwargs) | |
| except Exception as e: | |
| return None | |
| @property | |
| def freq(self) -> tp.Optional[tp.PandasFrequency]: | |
| """`BaseIDXAccessor.get_freq` with date offsets and integer frequencies not allowed.""" | |
| return self.get_freq(allow_offset=True, allow_numeric=False) | |
| @property | |
| def ns_freq(self) -> tp.Optional[int]: | |
| """Convert frequency to a 64-bit integer. | |
| Timedelta will be converted to nanoseconds.""" | |
| freq = self.get_freq(allow_offset=False, allow_numeric=True) | |
| if freq is not None: | |
| freq = dt.to_ns(dt.to_timedelta64(freq)) | |
| return freq | |
| @property | |
| def any_freq(self) -> tp.Union[None, float, tp.PandasFrequency]: | |
| """Index frequency of any type.""" | |
| return self.get_freq() | |
| @hybrid_method | |
| def get_periods(cls_or_self, index: tp.Optional[tp.Index] = None) -> int: | |
| """Get the number of periods in the index, without taking into account its datetime-like properties.""" | |
| if not isinstance(cls_or_self, type): | |
| if index is None: | |
| index = cls_or_self.obj | |
| else: | |
| checks.assert_not_none(index, arg_name="index") | |
| return len(index) | |
| @property | |
| def periods(self) -> int: | |
| """`BaseIDXAccessor.get_periods` with default arguments.""" | |
| return len(self.obj) | |
| @hybrid_method | |
| def get_dt_periods( | |
| cls_or_self, | |
| index: tp.Optional[tp.Index] = None, | |
| freq: tp.Optional[tp.PandasFrequency] = None, | |
| ) -> float: | |
| """Get the number of periods in the index, taking into account its datetime-like properties.""" | |
| from vectorbtpro._settings import settings | |
| wrapping_cfg = settings["wrapping"] | |
| if not isinstance(cls_or_self, type): | |
| if index is None: | |
| index = cls_or_self.obj | |
| else: | |
| checks.assert_not_none(index, arg_name="index") | |
| if isinstance(index, pd.DatetimeIndex): | |
| freq = cls_or_self.get_freq(index=index, freq=freq, allow_offset=True, allow_numeric=False) | |
| if freq is not None: | |
| if not isinstance(freq, pd.Timedelta): | |
| freq = dt.to_timedelta(freq, approximate=True) | |
| return (index[-1] - index[0]) / freq + 1 | |
| if not wrapping_cfg["silence_warnings"]: | |
| warn( | |
| "Couldn't parse the frequency of index. Pass it as `freq` or " | |
| "define it globally under `settings.wrapping`." | |
| ) | |
| if checks.is_number(index[0]) and checks.is_number(index[-1]): | |
| freq = cls_or_self.get_freq(index=index, freq=freq, allow_offset=False, allow_numeric=True) | |
| if checks.is_number(freq): | |
| return (index[-1] - index[0]) / freq + 1 | |
| return index[-1] - index[0] + 1 | |
| if not wrapping_cfg["silence_warnings"]: | |
| warn("Index is neither datetime-like nor integer") | |
| return cls_or_self.get_periods(index=index) | |
| @property | |
| def dt_periods(self) -> float: | |
| """`BaseIDXAccessor.get_dt_periods` with default arguments.""" | |
| return self.get_dt_periods() | |
| def arr_to_timedelta( | |
| self, | |
| a: tp.MaybeArray, | |
| to_pd: bool = False, | |
| silence_warnings: tp.Optional[bool] = None, | |
| ) -> tp.Union[pd.Index, tp.MaybeArray]: | |
| """Convert array to duration using `BaseIDXAccessor.freq`.""" | |
| from vectorbtpro._settings import settings | |
| wrapping_cfg = settings["wrapping"] | |
| if silence_warnings is None: | |
| silence_warnings = wrapping_cfg["silence_warnings"] | |
| freq = self.freq | |
| if freq is None: | |
| if not silence_warnings: | |
| warn( | |
| "Couldn't parse the frequency of index. Pass it as `freq` or " | |
| "define it globally under `settings.wrapping`." | |
| ) | |
| return a | |
| if not isinstance(freq, pd.Timedelta): | |
| freq = dt.to_timedelta(freq, approximate=True) | |
| if to_pd: | |
| out = pd.to_timedelta(a * freq) | |
| else: | |
| out = a * freq | |
| return out | |
| # ############# Grouping ############# # | |
| def get_grouper(self, by: tp.AnyGroupByLike, groupby_kwargs: tp.KwargsLike = None, **kwargs) -> Grouper: | |
| """Get an index grouper of type `vectorbtpro.base.grouping.base.Grouper`. | |
| Argument `by` can be a grouper itself, an instance of Pandas `GroupBy`, | |
| an instance of Pandas `Resampler`, but also any supported input to any of them | |
| such as a frequency or an array of indices. | |
| Keyword arguments `groupby_kwargs` are passed to the Pandas methods `groupby` and `resample`, | |
| while `**kwargs` are passed to initialize `vectorbtpro.base.grouping.base.Grouper`.""" | |
| if groupby_kwargs is None: | |
| groupby_kwargs = {} | |
| if isinstance(by, Grouper): | |
| if len(kwargs) > 0: | |
| return by.replace(**kwargs) | |
| return by | |
| if isinstance(by, (PandasGroupBy, PandasResampler)): | |
| return Grouper.from_pd_group_by(by, **kwargs) | |
| try: | |
| return Grouper(index=self.obj, group_by=by, **kwargs) | |
| except Exception as e: | |
| pass | |
| if isinstance(self.obj, pd.DatetimeIndex): | |
| try: | |
| return Grouper(index=self.obj, group_by=self.to_period(dt.to_freq(by)), **kwargs) | |
| except Exception as e: | |
| pass | |
| try: | |
| pd_group_by = pd.Series(index=self.obj, dtype=object).resample(dt.to_freq(by), **groupby_kwargs) | |
| return Grouper.from_pd_group_by(pd_group_by, **kwargs) | |
| except Exception as e: | |
| pass | |
| pd_group_by = pd.Series(index=self.obj, dtype=object).groupby(by, axis=0, **groupby_kwargs) | |
| return Grouper.from_pd_group_by(pd_group_by, **kwargs) | |
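| # A minimal usage sketch, assuming a daily datetime index and a frequency-like `by`: | |
| # | |
| #     import pandas as pd | |
| #     import vectorbtpro as vbt  # registers the .vbt accessors | |
| #     index = pd.date_range("2020-01-01", periods=10, freq="D") | |
| #     grouper = index.vbt.get_grouper("W")  # Grouper that groups index positions by week | |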
| def get_resampler( | |
| self, | |
| rule: tp.AnyRuleLike, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| resample_kwargs: tp.KwargsLike = None, | |
| return_pd_resampler: bool = False, | |
| silence_warnings: tp.Optional[bool] = None, | |
| ) -> tp.Union[Resampler, tp.PandasResampler]: | |
| """Get an index resampler of type `vectorbtpro.base.resampling.base.Resampler`.""" | |
| if checks.is_frequency_like(rule): | |
| try: | |
| rule = dt.to_freq(rule) | |
| is_td = True | |
| except Exception as e: | |
| is_td = False | |
| if is_td: | |
| resample_kwargs = merge_dicts( | |
| dict(closed="left", label="left"), | |
| resample_kwargs, | |
| ) | |
| rule = pd.Series(index=self.obj, dtype=object).resample(rule, **resolve_dict(resample_kwargs)) | |
| if isinstance(rule, PandasResampler): | |
| if return_pd_resampler: | |
| return rule | |
| if silence_warnings is None: | |
| silence_warnings = True | |
| rule = Resampler.from_pd_resampler(rule, source_freq=self.freq, silence_warnings=silence_warnings) | |
| if return_pd_resampler: | |
| raise TypeError("Cannot convert Resampler to Pandas Resampler") | |
| if checks.is_dt_like(rule) or checks.is_iterable(rule): | |
| rule = dt.prepare_dt_index(rule) | |
| rule = Resampler( | |
| source_index=self.obj, | |
| target_index=rule, | |
| source_freq=self.freq, | |
| target_freq=freq, | |
| silence_warnings=silence_warnings, | |
| ) | |
| if isinstance(rule, Resampler): | |
| if freq is not None: | |
| rule = rule.replace(target_freq=freq) | |
| return rule | |
| raise ValueError(f"Cannot build Resampler from {rule}") | |
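| # A minimal usage sketch, assuming a daily datetime index resampled to a weekly frequency: | |
| # | |
| #     import pandas as pd | |
| #     import vectorbtpro as vbt  # registers the .vbt accessors | |
| #     index = pd.date_range("2020-01-01", periods=10, freq="D") | |
| #     resampler = index.vbt.get_resampler("W")  # Resampler mapping the daily index to weekly bars | |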
| # ############# Points and ranges ############# # | |
| def get_points(self, *args, **kwargs) -> tp.Array1d: | |
| """See `vectorbtpro.base.indexing.get_index_points`.""" | |
| return get_index_points(self.obj, *args, **kwargs) | |
| def get_ranges(self, *args, **kwargs) -> tp.Tuple[tp.Array1d, tp.Array1d]: | |
| """See `vectorbtpro.base.indexing.get_index_ranges`.""" | |
| return get_index_ranges(self.obj, self.any_freq, *args, **kwargs) | |
| # ############# Splitting ############# # | |
| def split(self, *args, splitter_cls: tp.Optional[tp.Type[SplitterT]] = None, **kwargs) -> tp.Any: | |
| """Split using `vectorbtpro.generic.splitting.base.Splitter.split_and_take`. | |
| !!! note | |
| Splits Pandas object, not accessor!""" | |
| from vectorbtpro.generic.splitting.base import Splitter | |
| if splitter_cls is None: | |
| splitter_cls = Splitter | |
| return splitter_cls.split_and_take(self.obj, self.obj, *args, **kwargs) | |
| def split_apply( | |
| self, | |
| apply_func: tp.Callable, | |
| *args, | |
| splitter_cls: tp.Optional[tp.Type[SplitterT]] = None, | |
| **kwargs, | |
| ) -> tp.Any: | |
| """Split using `vectorbtpro.generic.splitting.base.Splitter.split_and_apply`. | |
| !!! note | |
| Splits Pandas object, not accessor!""" | |
| from vectorbtpro.generic.splitting.base import Splitter, Takeable | |
| if splitter_cls is None: | |
| splitter_cls = Splitter | |
| return splitter_cls.split_and_apply(self.obj, apply_func, Takeable(self.obj), *args, **kwargs) | |
| # ############# Chunking ############# # | |
| def chunk( | |
| self: BaseIDXAccessorT, | |
| min_size: tp.Optional[int] = None, | |
| n_chunks: tp.Union[None, int, str] = None, | |
| chunk_len: tp.Union[None, int, str] = None, | |
| chunk_meta: tp.Optional[tp.Iterable[ChunkMeta]] = None, | |
| select: bool = False, | |
| return_chunk_meta: bool = False, | |
| ) -> tp.Iterator[tp.Union[tp.Index, tp.Tuple[ChunkMeta, tp.Index]]]: | |
| """Chunk this instance. | |
| For arguments related to chunking meta, see `vectorbtpro.utils.chunking.iter_chunk_meta`. | |
| !!! note | |
| Splits Pandas object, not accessor!""" | |
| if chunk_meta is None: | |
| chunk_meta = iter_chunk_meta( | |
| size=len(self.obj), | |
| min_size=min_size, | |
| n_chunks=n_chunks, | |
| chunk_len=chunk_len | |
| ) | |
| for _chunk_meta in chunk_meta: | |
| if select: | |
| array_taker = ArraySelector() | |
| else: | |
| array_taker = ArraySlicer() | |
| if return_chunk_meta: | |
| yield _chunk_meta, array_taker.take(self.obj, _chunk_meta) | |
| else: | |
| yield array_taker.take(self.obj, _chunk_meta) | |
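| # A minimal usage sketch, assuming an index split into two equally sized chunks: | |
| # | |
| #     import pandas as pd | |
| #     import vectorbtpro as vbt  # registers the .vbt accessors | |
| #     index = pd.RangeIndex(10) | |
| #     chunks = list(index.vbt.chunk(n_chunks=2))  # two index chunks of five elements each | |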
| def chunk_apply( | |
| self: BaseIDXAccessorT, | |
| apply_func: tp.Union[str, tp.Callable], | |
| *args, | |
| chunk_kwargs: tp.KwargsLike = None, | |
| execute_kwargs: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> tp.MergeableResults: | |
| """Chunk this instance and apply a function to each chunk. | |
| If `apply_func` is a string, it is treated as the name of a method of this class. | |
| For arguments related to chunking, see `BaseIDXAccessor.chunk`. | |
| !!! note | |
| Splits Pandas object, not accessor!""" | |
| if isinstance(apply_func, str): | |
| apply_func = getattr(type(self), apply_func) | |
| if chunk_kwargs is None: | |
| chunk_arg_names = set(get_func_arg_names(self.chunk)) | |
| chunk_kwargs = {} | |
| for k in list(kwargs.keys()): | |
| if k in chunk_arg_names: | |
| chunk_kwargs[k] = kwargs.pop(k) | |
| if execute_kwargs is None: | |
| execute_kwargs = {} | |
| chunks = self.chunk(return_chunk_meta=True, **chunk_kwargs) | |
| tasks = [] | |
| keys = [] | |
| for _chunk_meta, chunk in chunks: | |
| tasks.append(Task(apply_func, chunk, *args, **kwargs)) | |
| keys.append(get_chunk_meta_key(_chunk_meta)) | |
| keys = pd.Index(keys, name="chunk_indices") | |
| return execute(tasks, size=len(tasks), keys=keys, **execute_kwargs) | |
| BaseAccessorT = tp.TypeVar("BaseAccessorT", bound="BaseAccessor") | |
| @attach_binary_magic_methods(lambda self, other, np_func: self.combine(other, combine_func=np_func)) | |
| @attach_unary_magic_methods(lambda self, np_func: self.apply(apply_func=np_func)) | |
| class BaseAccessor(Wrapping): | |
| """Accessor on top of Series and DataFrames. | |
| Accessible via `pd.Series.vbt` and `pd.DataFrame.vbt`, and all child accessors. | |
| Series is just a DataFrame with one column, hence to avoid defining methods exclusively for 1-dim data, | |
| we will convert any Series to a DataFrame and perform matrix computation on it. Afterwards, | |
| by using `BaseAccessor.wrapper`, we will convert the 2-dim output back to a Series. | |
| `**kwargs` will be passed to `vectorbtpro.base.wrapping.ArrayWrapper`. | |
| !!! note | |
| When using magic methods, ensure that `.vbt` is called on the operand on the left | |
| if the other operand is an array. | |
| Accessors do not utilize caching. | |
| Grouping is only supported by the methods that accept the `group_by` argument. | |
| Usage: | |
| * Build a symmetric matrix: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> # vectorbtpro.base.accessors.BaseAccessor.make_symmetric | |
| >>> pd.Series([1, 2, 3]).vbt.make_symmetric() | |
| 0 1 2 | |
| 0 1.0 2.0 3.0 | |
| 1 2.0 NaN NaN | |
| 2 3.0 NaN NaN | |
| ``` | |
| * Broadcast pandas objects: | |
| ```pycon | |
| >>> sr = pd.Series([1]) | |
| >>> df = pd.DataFrame([1, 2, 3]) | |
| >>> vbt.base.reshaping.broadcast_to(sr, df) | |
| 0 | |
| 0 1 | |
| 1 1 | |
| 2 1 | |
| >>> sr.vbt.broadcast_to(df) | |
| 0 | |
| 0 1 | |
| 1 1 | |
| 2 1 | |
| ``` | |
| * Many methods such as `BaseAccessor.broadcast` are both class and instance methods: | |
| ```pycon | |
| >>> from vectorbtpro.base.accessors import BaseAccessor | |
| >>> # Same as sr.vbt.broadcast(df) | |
| >>> new_sr, new_df = BaseAccessor.broadcast(sr, df) | |
| >>> new_sr | |
| 0 | |
| 0 1 | |
| 1 1 | |
| 2 1 | |
| >>> new_df | |
| 0 | |
| 0 1 | |
| 1 2 | |
| 2 3 | |
| ``` | |
| * Instead of explicitly importing `BaseAccessor` or any other accessor, we can use `pd_acc` instead: | |
| ```pycon | |
| >>> new_sr, new_df = vbt.pd_acc.broadcast(sr, df) | |
| >>> new_sr | |
| 0 | |
| 0 1 | |
| 1 1 | |
| 2 1 | |
| >>> new_df | |
| 0 | |
| 0 1 | |
| 1 2 | |
| 2 3 | |
| ``` | |
| * `BaseAccessor` implements arithmetic (such as `+`), comparison (such as `>`) and | |
| logical operators (such as `&`) by forwarding the operation to `BaseAccessor.combine`: | |
| ```pycon | |
| >>> sr.vbt + df | |
| 0 | |
| 0 2 | |
| 1 3 | |
| 2 4 | |
| ``` | |
| Many interesting use cases can be implemented this way. | |
| * For example, let's compare an array with 3 different thresholds: | |
| ```pycon | |
| >>> df.vbt > vbt.Param(np.arange(3), name='threshold') | |
| threshold 0 1 2 | |
| a2 b2 c2 a2 b2 c2 a2 b2 c2 | |
| x2 True True True False True True False False True | |
| y2 True True True True True True True True True | |
| z2 True True True True True True True True True | |
| ``` | |
| """ | |
| @classmethod | |
| def resolve_row_stack_kwargs( | |
| cls: tp.Type[BaseAccessorT], | |
| *objs: tp.MaybeTuple[BaseAccessorT], | |
| **kwargs, | |
| ) -> tp.Kwargs: | |
| """Resolve keyword arguments for initializing `BaseAccessor` after stacking along rows.""" | |
| if "obj" not in kwargs: | |
| kwargs["obj"] = kwargs["wrapper"].row_stack_arrs( | |
| *[obj.obj for obj in objs], | |
| group_by=False, | |
| wrap=False, | |
| ) | |
| return kwargs | |
| @classmethod | |
| def resolve_column_stack_kwargs( | |
| cls: tp.Type[BaseAccessorT], | |
| *objs: tp.MaybeTuple[BaseAccessorT], | |
| reindex_kwargs: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> tp.Kwargs: | |
| """Resolve keyword arguments for initializing `BaseAccessor` after stacking along columns.""" | |
| if "obj" not in kwargs: | |
| kwargs["obj"] = kwargs["wrapper"].column_stack_arrs( | |
| *[obj.obj for obj in objs], | |
| reindex_kwargs=reindex_kwargs, | |
| group_by=False, | |
| wrap=False, | |
| ) | |
| return kwargs | |
| @hybrid_method | |
| def row_stack( | |
| cls_or_self: tp.MaybeType[BaseAccessorT], | |
| *objs: tp.MaybeTuple[BaseAccessorT], | |
| wrapper_kwargs: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> BaseAccessorT: | |
| """Stack multiple `BaseAccessor` instances along rows. | |
| Uses `vectorbtpro.base.wrapping.ArrayWrapper.row_stack` to stack the wrappers.""" | |
| if not isinstance(cls_or_self, type): | |
| objs = (cls_or_self, *objs) | |
| cls = type(cls_or_self) | |
| else: | |
| cls = cls_or_self | |
| if len(objs) == 1: | |
| objs = objs[0] | |
| objs = list(objs) | |
| for obj in objs: | |
| if not checks.is_instance_of(obj, BaseAccessor): | |
| raise TypeError("Each object to be merged must be an instance of BaseAccessor") | |
| if wrapper_kwargs is None: | |
| wrapper_kwargs = {} | |
| if "wrapper" in kwargs and kwargs["wrapper"] is not None: | |
| wrapper = kwargs["wrapper"] | |
| if len(wrapper_kwargs) > 0: | |
| wrapper = wrapper.replace(**wrapper_kwargs) | |
| else: | |
| wrapper = ArrayWrapper.row_stack(*[obj.wrapper for obj in objs], **wrapper_kwargs) | |
| kwargs["wrapper"] = wrapper | |
| kwargs = cls.resolve_row_stack_kwargs(*objs, **kwargs) | |
| kwargs = cls.resolve_stack_kwargs(*objs, **kwargs) | |
| if kwargs["wrapper"].ndim == 1: | |
| return cls.sr_accessor_cls(**kwargs) | |
| return cls.df_accessor_cls(**kwargs) | |
| @hybrid_method | |
| def column_stack( | |
| cls_or_self: tp.MaybeType[BaseAccessorT], | |
| *objs: tp.MaybeTuple[BaseAccessorT], | |
| wrapper_kwargs: tp.KwargsLike = None, | |
| reindex_kwargs: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> BaseAccessorT: | |
| """Stack multiple `BaseAccessor` instances along columns. | |
| Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers.""" | |
| if not isinstance(cls_or_self, type): | |
| objs = (cls_or_self, *objs) | |
| cls = type(cls_or_self) | |
| else: | |
| cls = cls_or_self | |
| if len(objs) == 1: | |
| objs = objs[0] | |
| objs = list(objs) | |
| for obj in objs: | |
| if not checks.is_instance_of(obj, BaseAccessor): | |
| raise TypeError("Each object to be merged must be an instance of BaseAccessor") | |
| if wrapper_kwargs is None: | |
| wrapper_kwargs = {} | |
| if "wrapper" in kwargs and kwargs["wrapper"] is not None: | |
| wrapper = kwargs["wrapper"] | |
| if len(wrapper_kwargs) > 0: | |
| wrapper = wrapper.replace(**wrapper_kwargs) | |
| else: | |
| wrapper = ArrayWrapper.column_stack(*[obj.wrapper for obj in objs], **wrapper_kwargs) | |
| kwargs["wrapper"] = wrapper | |
| kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs) | |
| kwargs = cls.resolve_stack_kwargs(*objs, **kwargs) | |
| return cls.df_accessor_cls(**kwargs) | |
| def __init__( | |
| self, | |
| wrapper: tp.Union[ArrayWrapper, tp.ArrayLike], | |
| obj: tp.Optional[tp.ArrayLike] = None, | |
| **kwargs, | |
| ) -> None: | |
| if len(kwargs) > 0: | |
| wrapper_kwargs, kwargs = ArrayWrapper.extract_init_kwargs(**kwargs) | |
| else: | |
| wrapper_kwargs, kwargs = {}, {} | |
| if not isinstance(wrapper, ArrayWrapper): | |
| if obj is not None: | |
| raise ValueError("Must either provide wrapper and object, or only object") | |
| wrapper, obj = ArrayWrapper.from_obj(wrapper, **wrapper_kwargs), wrapper | |
| else: | |
| if obj is None: | |
| raise ValueError("Must either provide wrapper and object, or only object") | |
| if len(wrapper_kwargs) > 0: | |
| wrapper = wrapper.replace(**wrapper_kwargs) | |
| Wrapping.__init__(self, wrapper, obj=obj, **kwargs) | |
| self._obj = obj | |
| def __call__(self: BaseAccessorT, **kwargs) -> BaseAccessorT: | |
| """Allows passing arguments to the initializer.""" | |
| return self.replace(**kwargs) | |
| @hybrid_property | |
| def sr_accessor_cls(cls_or_self) -> tp.Type["BaseSRAccessor"]: | |
| """Accessor class for `pd.Series`.""" | |
| return BaseSRAccessor | |
| @hybrid_property | |
| def df_accessor_cls(cls_or_self) -> tp.Type["BaseDFAccessor"]: | |
| """Accessor class for `pd.DataFrame`.""" | |
| return BaseDFAccessor | |
| def indexing_func(self: BaseAccessorT, *args, wrapper_meta: tp.DictLike = None, **kwargs) -> BaseAccessorT: | |
| """Perform indexing on `BaseAccessor`.""" | |
| if wrapper_meta is None: | |
| wrapper_meta = self.wrapper.indexing_func_meta(*args, **kwargs) | |
| new_obj = ArrayWrapper.select_from_flex_array( | |
| self._obj, | |
| row_idxs=wrapper_meta["row_idxs"], | |
| col_idxs=wrapper_meta["col_idxs"], | |
| rows_changed=wrapper_meta["rows_changed"], | |
| columns_changed=wrapper_meta["columns_changed"], | |
| ) | |
| if checks.is_series(new_obj): | |
| return self.replace(cls_=self.sr_accessor_cls, wrapper=wrapper_meta["new_wrapper"], obj=new_obj) | |
| return self.replace(cls_=self.df_accessor_cls, wrapper=wrapper_meta["new_wrapper"], obj=new_obj) | |
| def indexing_setter_func(self, pd_indexing_setter_func: tp.Callable, **kwargs) -> None: | |
| """Perform indexing setter on `BaseAccessor`.""" | |
| pd_indexing_setter_func(self._obj) | |
| @property | |
| def obj(self) -> tp.SeriesFrame: | |
| """Pandas object.""" | |
| if isinstance(self._obj, (pd.Series, pd.DataFrame)): | |
| if self._obj.shape == self.wrapper.shape: | |
| if self._obj.index is self.wrapper.index: | |
| if isinstance(self._obj, pd.Series) and self._obj.name == self.wrapper.name: | |
| return self._obj | |
| if isinstance(self._obj, pd.DataFrame) and self._obj.columns is self.wrapper.columns: | |
| return self._obj | |
| return self.wrapper.wrap(self._obj, group_by=False) | |
| def get(self, key: tp.Optional[tp.Hashable] = None, default: tp.Optional[tp.Any] = None) -> tp.SeriesFrame: | |
| """Get `BaseAccessor.obj`.""" | |
| if key is None: | |
| return self.obj | |
| return self.obj.get(key, default=default) | |
| @property | |
| def unwrapped(self) -> tp.SeriesFrame: | |
| return self.obj | |
| @hybrid_method | |
| def should_wrap(cls_or_self) -> bool: | |
| return False | |
| @hybrid_property | |
| def ndim(cls_or_self) -> tp.Optional[int]: | |
| """Number of dimensions in the object. | |
| 1 -> Series, 2 -> DataFrame.""" | |
| if isinstance(cls_or_self, type): | |
| return None | |
| return cls_or_self.obj.ndim | |
| @hybrid_method | |
| def is_series(cls_or_self) -> bool: | |
| """Whether the object is a Series.""" | |
| if isinstance(cls_or_self, type): | |
| raise NotImplementedError | |
| return isinstance(cls_or_self.obj, pd.Series) | |
| @hybrid_method | |
| def is_frame(cls_or_self) -> bool: | |
| """Whether the object is a DataFrame.""" | |
| if isinstance(cls_or_self, type): | |
| raise NotImplementedError | |
| return isinstance(cls_or_self.obj, pd.DataFrame) | |
| @classmethod | |
| def resolve_shape(cls, shape: tp.ShapeLike) -> tp.Shape: | |
| """Resolve shape.""" | |
| shape_2d = reshaping.to_2d_shape(shape) | |
| try: | |
| if cls.is_series() and shape_2d[1] > 1: | |
| raise ValueError("Use DataFrame accessor") | |
| except NotImplementedError: | |
| pass | |
| return shape_2d | |
| # ############# Creation ############# # | |
| @classmethod | |
| def empty(cls, shape: tp.Shape, fill_value: tp.Scalar = np.nan, **kwargs) -> tp.SeriesFrame: | |
| """Generate an empty Series/DataFrame of shape `shape` and fill with `fill_value`.""" | |
| if not isinstance(shape, tuple) or (isinstance(shape, tuple) and len(shape) == 1): | |
| return pd.Series(np.full(shape, fill_value), **kwargs) | |
| return pd.DataFrame(np.full(shape, fill_value), **kwargs) | |
| @classmethod | |
| def empty_like(cls, other: tp.SeriesFrame, fill_value: tp.Scalar = np.nan, **kwargs) -> tp.SeriesFrame: | |
| """Generate an empty Series/DataFrame like `other` and fill with `fill_value`.""" | |
| if checks.is_series(other): | |
| return cls.empty(other.shape, fill_value=fill_value, index=other.index, name=other.name, **kwargs) | |
| return cls.empty(other.shape, fill_value=fill_value, index=other.index, columns=other.columns, **kwargs) | |
| # ############# Indexes ############# # | |
| def apply_to_index( | |
| self: BaseAccessorT, | |
| *args, | |
| wrap: bool = False, | |
| **kwargs, | |
| ) -> tp.Union[BaseAccessorT, tp.SeriesFrame]: | |
| """See `vectorbtpro.base.wrapping.Wrapping.apply_to_index`. | |
| !!! note | |
| If `wrap` is False, returns Pandas object, not accessor!""" | |
| result = Wrapping.apply_to_index(self, *args, **kwargs) | |
| if wrap: | |
| return result | |
| return result.obj | |
| # ############# Setting ############# # | |
| def set( | |
| self, | |
| value_or_func: tp.Union[tp.MaybeArray, tp.Callable], | |
| *args, | |
| inplace: bool = False, | |
| columns: tp.Optional[tp.MaybeSequence[tp.Hashable]] = None, | |
| template_context: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> tp.Optional[tp.SeriesFrame]: | |
| """Set value at each index point using `vectorbtpro.base.indexing.get_index_points`. | |
| If `value_or_func` is a function, selects all keyword arguments that were not passed | |
| to the `get_index_points` method, substitutes any templates, and passes everything to the function. | |
| As context, it uses `kwargs`, `template_context`, and various variables such as `i` (iteration index), | |
| `index_point` (absolute position in the index), `wrapper`, and `obj`.""" | |
| if inplace: | |
| obj = self.obj | |
| else: | |
| obj = self.obj.copy() | |
| index_points = get_index_points(self.wrapper.index, **kwargs) | |
| if callable(value_or_func): | |
| func_kwargs = {k: v for k, v in kwargs.items() if k not in point_idxr_defaults} | |
| template_context = merge_dicts(kwargs, template_context) | |
| else: | |
| func_kwargs = None | |
| if callable(value_or_func): | |
| for i in range(len(index_points)): | |
| _template_context = merge_dicts( | |
| dict( | |
| i=i, | |
| index_point=index_points[i], | |
| index_points=index_points, | |
| wrapper=self.wrapper, | |
| obj=self.obj, | |
| columns=columns, | |
| args=args, | |
| kwargs=kwargs, | |
| ), | |
| template_context, | |
| ) | |
| _func_args = substitute_templates(args, _template_context, eval_id="func_args") | |
| _func_kwargs = substitute_templates(func_kwargs, _template_context, eval_id="func_kwargs") | |
| v = value_or_func(*_func_args, **_func_kwargs) | |
| if self.is_series() or columns is None: | |
| obj.iloc[index_points[i]] = v | |
| elif is_scalar(columns): | |
| obj.iloc[index_points[i], obj.columns.get_indexer([columns])[0]] = v | |
| else: | |
| obj.iloc[index_points[i], obj.columns.get_indexer(columns)] = v | |
| elif checks.is_sequence(value_or_func) and not is_scalar(value_or_func): | |
| if self.is_series(): | |
| obj.iloc[index_points] = reshaping.to_1d_array(value_or_func) | |
| elif columns is None: | |
| obj.iloc[index_points] = reshaping.to_2d_array(value_or_func) | |
| elif is_scalar(columns): | |
| obj.iloc[index_points, obj.columns.get_indexer([columns])[0]] = reshaping.to_1d_array(value_or_func) | |
| else: | |
| obj.iloc[index_points, obj.columns.get_indexer(columns)] = reshaping.to_2d_array(value_or_func) | |
| else: | |
| if self.is_series() or columns is None: | |
| obj.iloc[index_points] = value_or_func | |
| elif is_scalar(columns): | |
| obj.iloc[index_points, obj.columns.get_indexer([columns])[0]] = value_or_func | |
| else: | |
| obj.iloc[index_points, obj.columns.get_indexer(columns)] = value_or_func | |
| if inplace: | |
| return None | |
| return obj | |
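| # A minimal usage sketch, assuming `every` is one of the point-selection keywords | |
| # accepted by `get_index_points`: | |
| # | |
| #     import pandas as pd | |
| #     import vectorbtpro as vbt  # registers the .vbt accessors | |
| #     sr = pd.Series(0.0, index=pd.date_range("2020-01-01", periods=10, freq="D")) | |
| #     new_sr = sr.vbt.set(1.0, every="3D")  # writes 1.0 at points spaced three days apart | |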
| def set_between( | |
| self, | |
| value_or_func: tp.Union[tp.MaybeArray, tp.Callable], | |
| *args, | |
| inplace: bool = False, | |
| columns: tp.Optional[tp.MaybeSequence[tp.Hashable]] = None, | |
| template_context: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> tp.Optional[tp.SeriesFrame]: | |
| """Set value at each index range using `vectorbtpro.base.indexing.get_index_ranges`. | |
| If `value_or_func` is a function, selects all keyword arguments that were not passed | |
| to the `get_index_ranges` method, substitutes any templates, and passes everything to the function. | |
| As context, it uses `kwargs`, `template_context`, and various variables such as `i` (iteration index), | |
| `index_slice` (absolute slice of the index), `wrapper`, and `obj`.""" | |
| if inplace: | |
| obj = self.obj | |
| else: | |
| obj = self.obj.copy() | |
| index_ranges = get_index_ranges(self.wrapper.index, **kwargs) | |
| if callable(value_or_func): | |
| func_kwargs = {k: v for k, v in kwargs.items() if k not in range_idxr_defaults} | |
| template_context = merge_dicts(kwargs, template_context) | |
| else: | |
| func_kwargs = None | |
| for i in range(len(index_ranges[0])): | |
| if callable(value_or_func): | |
| _template_context = merge_dicts( | |
| dict( | |
| i=i, | |
| index_slice=slice(index_ranges[0][i], index_ranges[1][i]), | |
| range_starts=index_ranges[0], | |
| range_ends=index_ranges[1], | |
| wrapper=self.wrapper, | |
| obj=self.obj, | |
| columns=columns, | |
| args=args, | |
| kwargs=kwargs, | |
| ), | |
| template_context, | |
| ) | |
| _func_args = substitute_templates(args, _template_context, eval_id="func_args") | |
| _func_kwargs = substitute_templates(func_kwargs, _template_context, eval_id="func_kwargs") | |
| v = value_or_func(*_func_args, **_func_kwargs) | |
| elif checks.is_sequence(value_or_func) and not isinstance(value_or_func, str): | |
| v = value_or_func[i] | |
| else: | |
| v = value_or_func | |
| if self.is_series() or columns is None: | |
| obj.iloc[index_ranges[0][i] : index_ranges[1][i]] = v | |
| elif is_scalar(columns): | |
| obj.iloc[index_ranges[0][i] : index_ranges[1][i], obj.columns.get_indexer([columns])[0]] = v | |
| else: | |
| obj.iloc[index_ranges[0][i] : index_ranges[1][i], obj.columns.get_indexer(columns)] = v | |
| if inplace: | |
| return None | |
| return obj | |
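| # A minimal usage sketch, assuming `start` and `end` are among the range-selection keywords | |
| # accepted by `get_index_ranges`: | |
| # | |
| #     import pandas as pd | |
| #     import vectorbtpro as vbt  # registers the .vbt accessors | |
| #     sr = pd.Series(0.0, index=pd.date_range("2020-01-01", periods=10, freq="D")) | |
| #     new_sr = sr.vbt.set_between(1.0, start="2020-01-03", end="2020-01-06")  # fills the range with 1.0 | |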
| # ############# Reshaping ############# # | |
| def to_1d_array(self) -> tp.Array1d: | |
| """See `vectorbtpro.base.reshaping.to_1d` with `raw` set to True.""" | |
| return reshaping.to_1d_array(self.obj) | |
| def to_2d_array(self) -> tp.Array2d: | |
| """See `vectorbtpro.base.reshaping.to_2d` with `raw` set to True.""" | |
| return reshaping.to_2d_array(self.obj) | |
| def tile( | |
| self, | |
| n: int, | |
| keys: tp.Optional[tp.IndexLike] = None, | |
| axis: int = 1, | |
| wrap_kwargs: tp.KwargsLike = None, | |
| ) -> tp.SeriesFrame: | |
| """See `vectorbtpro.base.reshaping.tile`. | |
| Set `axis` to 1 for columns and 0 for index. | |
| Use `keys` as the outermost level.""" | |
| tiled = reshaping.tile(self.obj, n, axis=axis) | |
| if keys is not None: | |
| if axis == 1: | |
| new_columns = indexes.combine_indexes([keys, self.wrapper.columns]) | |
| return ArrayWrapper.from_obj(tiled).wrap( | |
| tiled.values, | |
| **merge_dicts(dict(columns=new_columns), wrap_kwargs), | |
| ) | |
| else: | |
| new_index = indexes.combine_indexes([keys, self.wrapper.index]) | |
| return ArrayWrapper.from_obj(tiled).wrap( | |
| tiled.values, | |
| **merge_dicts(dict(index=new_index), wrap_kwargs), | |
| ) | |
| return tiled | |
| def repeat( | |
| self, | |
| n: int, | |
| keys: tp.Optional[tp.IndexLike] = None, | |
| axis: int = 1, | |
| wrap_kwargs: tp.KwargsLike = None, | |
| ) -> tp.SeriesFrame: | |
| """See `vectorbtpro.base.reshaping.repeat`. | |
| Set `axis` to 1 for columns and 0 for index. | |
| Use `keys` as the outermost level.""" | |
| repeated = reshaping.repeat(self.obj, n, axis=axis) | |
| if keys is not None: | |
| if axis == 1: | |
| new_columns = indexes.combine_indexes([self.wrapper.columns, keys]) | |
| return ArrayWrapper.from_obj(repeated).wrap( | |
| repeated.values, | |
| **merge_dicts(dict(columns=new_columns), wrap_kwargs), | |
| ) | |
| else: | |
| new_index = indexes.combine_indexes([self.wrapper.index, keys]) | |
| return ArrayWrapper.from_obj(repeated).wrap( | |
| repeated.values, | |
| **merge_dicts(dict(index=new_index), wrap_kwargs), | |
| ) | |
| return repeated | |
| def align_to(self, other: tp.SeriesFrame, wrap_kwargs: tp.KwargsLike = None, **kwargs) -> tp.SeriesFrame: | |
| """Align to `other` on their axes using `vectorbtpro.base.indexes.align_index_to`. | |
| Usage: | |
| ```pycon | |
| >>> df1 = pd.DataFrame( | |
| ... [[1, 2], [3, 4]], | |
| ... index=['x', 'y'], | |
| ... columns=['a', 'b'] | |
| ... ) | |
| >>> df1 | |
| a b | |
| x 1 2 | |
| y 3 4 | |
| >>> df2 = pd.DataFrame( | |
| ... [[5, 6, 7, 8], [9, 10, 11, 12]], | |
| ... index=['x', 'y'], | |
| ... columns=pd.MultiIndex.from_arrays([[1, 1, 2, 2], ['a', 'b', 'a', 'b']]) | |
| ... ) | |
| >>> df2 | |
| 1 2 | |
| a b a b | |
| x 5 6 7 8 | |
| y 9 10 11 12 | |
| >>> df1.vbt.align_to(df2) | |
| 1 2 | |
| a b a b | |
| x 1 2 1 2 | |
| y 3 4 3 4 | |
| ``` | |
| """ | |
| checks.assert_instance_of(other, (pd.Series, pd.DataFrame)) | |
| obj = reshaping.to_2d(self.obj) | |
| other = reshaping.to_2d(other) | |
| aligned_index = indexes.align_index_to(obj.index, other.index, **kwargs) | |
| aligned_columns = indexes.align_index_to(obj.columns, other.columns, **kwargs) | |
| obj = obj.iloc[aligned_index, aligned_columns] | |
| return self.wrapper.wrap( | |
| obj.values, | |
| group_by=False, | |
| **merge_dicts(dict(index=other.index, columns=other.columns), wrap_kwargs), | |
| ) | |
| @hybrid_method | |
| def align( | |
| cls_or_self, | |
| *others: tp.Union[tp.SeriesFrame, "BaseAccessor"], | |
| **kwargs, | |
| ) -> tp.Tuple[tp.SeriesFrame, ...]: | |
| """Align objects using `vectorbtpro.base.indexes.align_indexes`.""" | |
| others = tuple(map(lambda x: x.obj if isinstance(x, BaseAccessor) else x, others)) | |
| if isinstance(cls_or_self, type): | |
| objs = others | |
| else: | |
| objs = (cls_or_self.obj, *others) | |
| objs_2d = list(map(reshaping.to_2d, objs)) | |
| index_slices, new_index = indexes.align_indexes( | |
| *map(lambda x: x.index, objs_2d), | |
| return_new_index=True, | |
| **kwargs, | |
| ) | |
| column_slices, new_columns = indexes.align_indexes( | |
| *map(lambda x: x.columns, objs_2d), | |
| return_new_index=True, | |
| **kwargs, | |
| ) | |
| new_objs = [] | |
| for i in range(len(objs_2d)): | |
| new_obj = objs_2d[i].iloc[index_slices[i], column_slices[i]].copy(deep=False) | |
| if objs[i].ndim == 1 and new_obj.shape[1] == 1: | |
| new_obj = new_obj.iloc[:, 0].rename(objs[i].name) | |
| new_obj.index = new_index | |
| new_obj.columns = new_columns | |
| new_objs.append(new_obj) | |
| return tuple(new_objs) | |
| def cross_with(self, other: tp.SeriesFrame, wrap_kwargs: tp.KwargsLike = None) -> tp.SeriesFrame: | |
| """Align to `other` on their axes using `vectorbtpro.base.indexes.cross_index_with`. | |
| Usage: | |
| ```pycon | |
| >>> df1 = pd.DataFrame( | |
| ... [[1, 2, 3, 4], [5, 6, 7, 8]], | |
| ... index=['x', 'y'], | |
| ... columns=pd.MultiIndex.from_arrays([[1, 1, 2, 2], ['a', 'b', 'a', 'b']]) | |
| ... ) | |
| >>> df1 | |
| 1 2 | |
| a b a b | |
| x 1 2 3 4 | |
| y 5 6 7 8 | |
| >>> df2 = pd.DataFrame( | |
| ... [[9, 10, 11, 12], [13, 14, 15, 16]], | |
| ... index=['x', 'y'], | |
| ... columns=pd.MultiIndex.from_arrays([[3, 3, 4, 4], ['a', 'b', 'a', 'b']]) | |
| ... ) | |
| >>> df2 | |
| 3 4 | |
| a b a b | |
| x 9 10 11 12 | |
| y 13 14 15 16 | |
| >>> df1.vbt.cross_with(df2) | |
| 1 2 | |
| 3 4 3 4 | |
| a b a b a b a b | |
| x 1 2 1 2 3 4 3 4 | |
| y 5 6 5 6 7 8 7 8 | |
| ``` | |
| """ | |
| checks.assert_instance_of(other, (pd.Series, pd.DataFrame)) | |
| obj = reshaping.to_2d(self.obj) | |
| other = reshaping.to_2d(other) | |
| index_slices, new_index = indexes.cross_index_with( | |
| obj.index, | |
| other.index, | |
| return_new_index=True, | |
| ) | |
| column_slices, new_columns = indexes.cross_index_with( | |
| obj.columns, | |
| other.columns, | |
| return_new_index=True, | |
| ) | |
| obj = obj.iloc[index_slices[0], column_slices[0]] | |
| return self.wrapper.wrap( | |
| obj.values, | |
| group_by=False, | |
| **merge_dicts(dict(index=new_index, columns=new_columns), wrap_kwargs), | |
| ) | |
| @hybrid_method | |
| def cross(cls_or_self, *others: tp.Union[tp.SeriesFrame, "BaseAccessor"]) -> tp.Tuple[tp.SeriesFrame, ...]: | |
| """Align objects using `vectorbtpro.base.indexes.cross_indexes`.""" | |
| others = tuple(map(lambda x: x.obj if isinstance(x, BaseAccessor) else x, others)) | |
| if isinstance(cls_or_self, type): | |
| objs = others | |
| else: | |
| objs = (cls_or_self.obj, *others) | |
| objs_2d = list(map(reshaping.to_2d, objs)) | |
| index_slices, new_index = indexes.cross_indexes( | |
| *map(lambda x: x.index, objs_2d), | |
| return_new_index=True, | |
| ) | |
| column_slices, new_columns = indexes.cross_indexes( | |
| *map(lambda x: x.columns, objs_2d), | |
| return_new_index=True, | |
| ) | |
| new_objs = [] | |
| for i in range(len(objs_2d)): | |
| new_obj = objs_2d[i].iloc[index_slices[i], column_slices[i]].copy(deep=False) | |
| if objs[i].ndim == 1 and new_obj.shape[1] == 1: | |
| new_obj = new_obj.iloc[:, 0].rename(objs[i].name) | |
| new_obj.index = new_index | |
| new_obj.columns = new_columns | |
| new_objs.append(new_obj) | |
| return tuple(new_objs) | |
| x = cross | |
| @hybrid_method | |
| def broadcast(cls_or_self, *others: tp.Union[tp.ArrayLike, "BaseAccessor"], **kwargs) -> tp.Any: | |
| """See `vectorbtpro.base.reshaping.broadcast`.""" | |
| others = tuple(map(lambda x: x.obj if isinstance(x, BaseAccessor) else x, others)) | |
| if isinstance(cls_or_self, type): | |
| objs = others | |
| else: | |
| objs = (cls_or_self.obj, *others) | |
| return reshaping.broadcast(*objs, **kwargs) | |
| def broadcast_to(self, other: tp.Union[tp.ArrayLike, "BaseAccessor"], **kwargs) -> tp.Any: | |
| """See `vectorbtpro.base.reshaping.broadcast_to`.""" | |
| if isinstance(other, BaseAccessor): | |
| other = other.obj | |
| return reshaping.broadcast_to(self.obj, other, **kwargs) | |
| @hybrid_method | |
| def broadcast_combs(cls_or_self, *others: tp.Union[tp.ArrayLike, "BaseAccessor"], **kwargs) -> tp.Any: | |
| """See `vectorbtpro.base.reshaping.broadcast_combs`.""" | |
| others = tuple(map(lambda x: x.obj if isinstance(x, BaseAccessor) else x, others)) | |
| if isinstance(cls_or_self, type): | |
| objs = others | |
| else: | |
| objs = (cls_or_self.obj, *others) | |
| return reshaping.broadcast_combs(*objs, **kwargs) | |
| def make_symmetric(self, *args, **kwargs) -> tp.Frame: | |
| """See `vectorbtpro.base.reshaping.make_symmetric`.""" | |
| return reshaping.make_symmetric(self.obj, *args, **kwargs) | |
| def unstack_to_array(self, *args, **kwargs) -> tp.Array: | |
| """See `vectorbtpro.base.reshaping.unstack_to_array`.""" | |
| return reshaping.unstack_to_array(self.obj, *args, **kwargs) | |
| def unstack_to_df(self, *args, **kwargs) -> tp.Frame: | |
| """See `vectorbtpro.base.reshaping.unstack_to_df`.""" | |
| return reshaping.unstack_to_df(self.obj, *args, **kwargs) | |
| def to_dict(self, *args, **kwargs) -> tp.Mapping: | |
| """See `vectorbtpro.base.reshaping.to_dict`.""" | |
| return reshaping.to_dict(self.obj, *args, **kwargs) | |
| # ############# Conversion ############# # | |
| def to_data( | |
| self, | |
| data_cls: tp.Optional[tp.Type[DataT]] = None, | |
| columns_are_symbols: bool = True, | |
| **kwargs, | |
| ) -> DataT: | |
| """Convert to a `vectorbtpro.data.base.Data` instance.""" | |
| if data_cls is None: | |
| from vectorbtpro.data.base import Data | |
| data_cls = Data | |
| return data_cls.from_data(self.obj, columns_are_symbols=columns_are_symbols, **kwargs) | |
| # ############# Combining ############# # | |
| def apply( | |
| self, | |
| apply_func: tp.Callable, | |
| *args, | |
| keep_pd: bool = False, | |
| to_2d: bool = False, | |
| broadcast_named_args: tp.KwargsLike = None, | |
| broadcast_kwargs: tp.KwargsLike = None, | |
| template_context: tp.KwargsLike = None, | |
| wrap_kwargs: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> tp.SeriesFrame: | |
| """Apply a function `apply_func`. | |
| Set `keep_pd` to True to keep inputs as pandas objects, otherwise convert to NumPy arrays. | |
| Set `to_2d` to True to reshape inputs to 2-dim arrays, otherwise keep as-is. | |
| `*args` and `**kwargs` are passed to `apply_func`. | |
| !!! note | |
| The resulting array must have the same shape as the original array. | |
| Usage: | |
| * Using instance method: | |
| ```pycon | |
| >>> sr = pd.Series([1, 2], index=['x', 'y']) | |
| >>> sr.vbt.apply(lambda x: x ** 2) | |
| x 1 | |
| y 4 | |
| dtype: int64 | |
| ``` | |
| * Using class method, templates, and broadcasting: | |
| ```pycon | |
| >>> sr.vbt.apply( | |
| ... lambda x, y: x + y, | |
| ... vbt.Rep('y'), | |
| ... broadcast_named_args=dict( | |
| ... y=pd.DataFrame([[3, 4]], columns=['a', 'b']) | |
| ... ) | |
| ... ) | |
| a b | |
| x 4 5 | |
| y 5 6 | |
| ``` | |
| """ | |
| if broadcast_named_args is None: | |
| broadcast_named_args = {} | |
| if broadcast_kwargs is None: | |
| broadcast_kwargs = {} | |
| if template_context is None: | |
| template_context = {} | |
| broadcast_named_args = {"obj": self.obj, **broadcast_named_args} | |
| if len(broadcast_named_args) > 1: | |
| broadcast_named_args, wrapper = reshaping.broadcast( | |
| broadcast_named_args, | |
| return_wrapper=True, | |
| **broadcast_kwargs, | |
| ) | |
| else: | |
| wrapper = self.wrapper | |
| if to_2d: | |
| broadcast_named_args = {k: reshaping.to_2d(v, raw=not keep_pd) for k, v in broadcast_named_args.items()} | |
| elif not keep_pd: | |
| broadcast_named_args = {k: np.asarray(v) for k, v in broadcast_named_args.items()} | |
| template_context = merge_dicts(broadcast_named_args, template_context) | |
| args = substitute_templates(args, template_context, eval_id="args") | |
| kwargs = substitute_templates(kwargs, template_context, eval_id="kwargs") | |
| out = apply_func(broadcast_named_args["obj"], *args, **kwargs) | |
| return wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) | |
| @hybrid_method | |
| def concat( | |
| cls_or_self, | |
| *others: tp.ArrayLike, | |
| broadcast_kwargs: tp.KwargsLike = None, | |
| keys: tp.Optional[tp.IndexLike] = None, | |
| ) -> tp.Frame: | |
| """Concatenate with `others` along columns. | |
| Usage: | |
| ```pycon | |
| >>> sr = pd.Series([1, 2], index=['x', 'y']) | |
| >>> df = pd.DataFrame([[3, 4], [5, 6]], index=['x', 'y'], columns=['a', 'b']) | |
| >>> sr.vbt.concat(df, keys=['c', 'd']) | |
| c d | |
| a b a b | |
| x 1 1 3 4 | |
| y 2 2 5 6 | |
| ``` | |
| """ | |
| others = tuple(map(lambda x: x.obj if isinstance(x, BaseAccessor) else x, others)) | |
| if isinstance(cls_or_self, type): | |
| objs = others | |
| else: | |
| objs = (cls_or_self.obj,) + others | |
| if broadcast_kwargs is None: | |
| broadcast_kwargs = {} | |
| broadcasted = reshaping.broadcast(*objs, **broadcast_kwargs) | |
| broadcasted = tuple(map(reshaping.to_2d, broadcasted)) | |
| out = pd.concat(broadcasted, axis=1, keys=keys) | |
| if not isinstance(out.columns, pd.MultiIndex) and np.all(out.columns == 0): | |
| out.columns = pd.RangeIndex(start=0, stop=len(out.columns), step=1) | |
| return out | |
| def apply_and_concat( | |
| self, | |
| ntimes: int, | |
| apply_func: tp.Callable, | |
| *args, | |
| keep_pd: bool = False, | |
| to_2d: bool = False, | |
| keys: tp.Optional[tp.IndexLike] = None, | |
| broadcast_named_args: tp.KwargsLike = None, | |
| broadcast_kwargs: tp.KwargsLike = None, | |
| template_context: tp.KwargsLike = None, | |
| wrap_kwargs: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> tp.MaybeTuple[tp.Frame]: | |
| """Apply `apply_func` `ntimes` times and concatenate the results along columns. | |
| See `vectorbtpro.base.combining.apply_and_concat`. | |
| `ntimes` is the number of times to call `apply_func`, while `n_outputs` is the number of outputs to expect. | |
| `*args` and `**kwargs` are passed to `vectorbtpro.base.combining.apply_and_concat`. | |
| !!! note | |
| The resulting arrays to be concatenated must have the same shape as the broadcast input arrays. | |
| Usage: | |
| * Using instance method: | |
| ```pycon | |
| >>> df = pd.DataFrame([[3, 4], [5, 6]], index=['x', 'y'], columns=['a', 'b']) | |
| >>> df.vbt.apply_and_concat( | |
| ... 3, | |
| ... lambda i, a, b: a * b[i], | |
| ... [1, 2, 3], | |
| ... keys=['c', 'd', 'e'] | |
| ... ) | |
| c d e | |
| a b a b a b | |
| x 3 4 6 8 9 12 | |
| y 5 6 10 12 15 18 | |
| ``` | |
| * Using class method, templates, and broadcasting: | |
| ```pycon | |
| >>> sr = pd.Series([1, 2, 3], index=['x', 'y', 'z']) | |
| >>> sr.vbt.apply_and_concat( | |
| ... 3, | |
| ... lambda i, a, b: a * b + i, | |
| ... vbt.Rep('df'), | |
| ... broadcast_named_args=dict( | |
| ... df=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c']) | |
| ... ) | |
| ... ) | |
| apply_idx 0 1 2 | |
| a b c a b c a b c | |
| x 1 2 3 2 3 4 3 4 5 | |
| y 2 4 6 3 5 7 4 6 8 | |
| z 3 6 9 4 7 10 5 8 11 | |
| ``` | |
| * To change the execution engine or specify other engine-related arguments, use `execute_kwargs`: | |
| ```pycon | |
| >>> import time | |
| >>> def apply_func(i, a): | |
| ... time.sleep(1) | |
| ... return a | |
| >>> sr = pd.Series([1, 2, 3]) | |
| >>> %timeit sr.vbt.apply_and_concat(3, apply_func) | |
| 3.02 s ± 3.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) | |
| >>> %timeit sr.vbt.apply_and_concat(3, apply_func, execute_kwargs=dict(engine='dask')) | |
| 1.02 s ± 927 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) | |
| ``` | |
| """ | |
| if broadcast_named_args is None: | |
| broadcast_named_args = {} | |
| if broadcast_kwargs is None: | |
| broadcast_kwargs = {} | |
| if template_context is None: | |
| template_context = {} | |
| broadcast_named_args = {"obj": self.obj, **broadcast_named_args} | |
| if len(broadcast_named_args) > 1: | |
| broadcast_named_args, wrapper = reshaping.broadcast( | |
| broadcast_named_args, | |
| return_wrapper=True, | |
| **broadcast_kwargs, | |
| ) | |
| else: | |
| wrapper = self.wrapper | |
| if to_2d: | |
| broadcast_named_args = {k: reshaping.to_2d(v, raw=not keep_pd) for k, v in broadcast_named_args.items()} | |
| elif not keep_pd: | |
| broadcast_named_args = {k: np.asarray(v) for k, v in broadcast_named_args.items()} | |
| template_context = merge_dicts(broadcast_named_args, dict(ntimes=ntimes), template_context) | |
| args = substitute_templates(args, template_context, eval_id="args") | |
| kwargs = substitute_templates(kwargs, template_context, eval_id="kwargs") | |
| out = combining.apply_and_concat(ntimes, apply_func, broadcast_named_args["obj"], *args, **kwargs) | |
| if keys is not None: | |
| new_columns = indexes.combine_indexes([keys, wrapper.columns]) | |
| else: | |
| top_columns = pd.Index(np.arange(ntimes), name="apply_idx") | |
| new_columns = indexes.combine_indexes([top_columns, wrapper.columns]) | |
| if out is None: | |
| return None | |
| wrap_kwargs = merge_dicts(dict(columns=new_columns), wrap_kwargs) | |
| if isinstance(out, list): | |
| return tuple(map(lambda x: wrapper.wrap(x, group_by=False, **wrap_kwargs), out)) | |
| return wrapper.wrap(out, group_by=False, **wrap_kwargs) | |
| @hybrid_method | |
| def combine( | |
| cls_or_self, | |
| obj: tp.MaybeTupleList[tp.Union[tp.ArrayLike, "BaseAccessor"]], | |
| combine_func: tp.Callable, | |
| *args, | |
| allow_multiple: bool = True, | |
| keep_pd: bool = False, | |
| to_2d: bool = False, | |
| concat: tp.Optional[bool] = None, | |
| keys: tp.Optional[tp.IndexLike] = None, | |
| broadcast_named_args: tp.KwargsLike = None, | |
| broadcast_kwargs: tp.KwargsLike = None, | |
| template_context: tp.KwargsLike = None, | |
| wrap_kwargs: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> tp.SeriesFrame: | |
| """Combine with `other` using `combine_func`. | |
| Args: | |
| obj (array_like): Object(s) to combine this array with. | |
| combine_func (callable): Function to combine two arrays. | |
| Can be Numba-compiled. | |
| *args: Variable arguments passed to `combine_func`. | |
| allow_multiple (bool): Whether a tuple/list/Index will be considered as multiple objects in `obj`. | |
| Takes effect only when using the instance method. | |
| keep_pd (bool): Whether to keep inputs as pandas objects, otherwise convert to NumPy arrays. | |
| to_2d (bool): Whether to reshape inputs to 2-dim arrays, otherwise keep as-is. | |
| concat (bool): Whether to concatenate the results along the column axis. | |
| Otherwise, pairwise combine into a Series/DataFrame of the same shape. | |
| If True, see `vectorbtpro.base.combining.combine_and_concat`. | |
| If False, see `vectorbtpro.base.combining.combine_multiple`. | |
| If None, becomes True if there are multiple objects to combine. | |
| Can only concatenate using the instance method. | |
| keys (index_like): Outermost column level. | |
| broadcast_named_args (dict): Dictionary with arguments to broadcast against each other. | |
| broadcast_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.reshaping.broadcast`. | |
| template_context (dict): Context used to substitute templates in `args` and `kwargs`. | |
| wrap_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.wrapping.ArrayWrapper.wrap`. | |
| **kwargs: Keyword arguments passed to `combine_func`. | |
| !!! note | |
| If `combine_func` is Numba-compiled, will broadcast using `WRITEABLE` and `C_CONTIGUOUS` | |
| flags, which can lead to an expensive computation overhead if passed objects are large and | |
| have different shape/memory order. You must also ensure that all objects have the same data type. | |
| Also remember to bring each argument in `*args` to a Numba-compatible format. | |
| Usage: | |
| * Using instance method: | |
| ```pycon | |
| >>> sr = pd.Series([1, 2], index=['x', 'y']) | |
| >>> df = pd.DataFrame([[3, 4], [5, 6]], index=['x', 'y'], columns=['a', 'b']) | |
| >>> # using instance method | |
| >>> sr.vbt.combine(df, np.add) | |
| a b | |
| x 4 5 | |
| y 7 8 | |
| >>> sr.vbt.combine([df, df * 2], np.add, concat=False) | |
| a b | |
| x 10 13 | |
| y 17 20 | |
| >>> sr.vbt.combine([df, df * 2], np.add) | |
| combine_idx 0 1 | |
| a b a b | |
| x 4 5 7 9 | |
| y 7 8 12 14 | |
| >>> sr.vbt.combine([df, df * 2], np.add, keys=['c', 'd']) | |
| c d | |
| a b a b | |
| x 4 5 7 9 | |
| y 7 8 12 14 | |
| >>> sr.vbt.combine(vbt.Param([1, 2], name='param'), np.add) | |
| param 1 2 | |
| x 2 3 | |
| y 3 4 | |
| >>> # using class method | |
| >>> vbt.pd_acc.combine([sr, df, df * 2], np.add, concat=False) | |
| a b | |
| x 10 13 | |
| y 17 20 | |
| ``` | |
| * Using class method, templates, and broadcasting: | |
| ```pycon | |
| >>> sr = pd.Series([1, 2, 3], index=['x', 'y', 'z']) | |
| >>> sr.vbt.combine( | |
| ... [1, 2, 3], | |
| ... lambda x, y, z: x + y + z, | |
| ... vbt.Rep('df'), | |
| ... broadcast_named_args=dict( | |
| ... df=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c']) | |
| ... ) | |
| ... ) | |
| combine_idx 0 1 2 | |
| a b c a b c a b c | |
| x 3 4 5 4 5 6 5 6 7 | |
| y 4 5 6 5 6 7 6 7 8 | |
| z 5 6 7 6 7 8 7 8 9 | |
| ``` | |
| * To change the execution engine or specify other engine-related arguments, use `execute_kwargs`: | |
| ```pycon | |
| >>> import time | |
| >>> def combine_func(a, b): | |
| ... time.sleep(1) | |
| ... return a + b | |
| >>> sr = pd.Series([1, 2, 3]) | |
| >>> %timeit sr.vbt.combine([1, 1, 1], combine_func) | |
| 3.01 s ± 2.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) | |
| >>> %timeit sr.vbt.combine([1, 1, 1], combine_func, execute_kwargs=dict(engine='dask')) | |
| 1.02 s ± 2.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) | |
| ``` | |
| """ | |
| from vectorbtpro.indicators.factory import IndicatorBase | |
| if broadcast_named_args is None: | |
| broadcast_named_args = {} | |
| if broadcast_kwargs is None: | |
| broadcast_kwargs = {} | |
| if template_context is None: | |
| template_context = {} | |
| if isinstance(cls_or_self, type): | |
| objs = obj | |
| else: | |
| if allow_multiple and isinstance(obj, (tuple, list)): | |
| objs = obj | |
| if concat is None: | |
| concat = True | |
| else: | |
| objs = (obj,) | |
| new_objs = [] | |
| for obj in objs: | |
| if isinstance(obj, BaseAccessor): | |
| obj = obj.obj | |
| elif isinstance(obj, IndicatorBase): | |
| obj = obj.main_output | |
| new_objs.append(obj) | |
| objs = tuple(new_objs) | |
| if not isinstance(cls_or_self, type): | |
| objs = (cls_or_self.obj,) + objs | |
| if checks.is_numba_func(combine_func): | |
| # Numba requires writeable arrays and in the same order | |
| broadcast_kwargs = merge_dicts(dict(require_kwargs=dict(requirements=["W", "C"])), broadcast_kwargs) | |
| # Broadcast and substitute templates | |
| broadcast_named_args = {**{"obj_" + str(i): obj for i, obj in enumerate(objs)}, **broadcast_named_args} | |
| broadcast_named_args, wrapper = reshaping.broadcast( | |
| broadcast_named_args, | |
| return_wrapper=True, | |
| **broadcast_kwargs, | |
| ) | |
| if to_2d: | |
| broadcast_named_args = {k: reshaping.to_2d(v, raw=not keep_pd) for k, v in broadcast_named_args.items()} | |
| elif not keep_pd: | |
| broadcast_named_args = {k: np.asarray(v) for k, v in broadcast_named_args.items()} | |
| template_context = merge_dicts(broadcast_named_args, template_context) | |
| args = substitute_templates(args, template_context, eval_id="args") | |
| kwargs = substitute_templates(kwargs, template_context, eval_id="kwargs") | |
| inputs = [broadcast_named_args["obj_" + str(i)] for i in range(len(objs))] | |
| if concat is None: | |
| concat = len(inputs) > 2 | |
| if concat: | |
| # Concat the results horizontally | |
| if isinstance(cls_or_self, type): | |
| raise TypeError("Use instance method to concatenate") | |
| out = combining.combine_and_concat(inputs[0], inputs[1:], combine_func, *args, **kwargs) | |
| if keys is not None: | |
| new_columns = indexes.combine_indexes([keys, wrapper.columns]) | |
| else: | |
| top_columns = pd.Index(np.arange(len(objs) - 1), name="combine_idx") | |
| new_columns = indexes.combine_indexes([top_columns, wrapper.columns]) | |
| return wrapper.wrap(out, **merge_dicts(dict(columns=new_columns, force_2d=True), wrap_kwargs)) | |
| else: | |
| # Combine arguments pairwise into one object | |
| out = combining.combine_multiple(inputs, combine_func, *args, **kwargs) | |
| return wrapper.wrap(out, **resolve_dict(wrap_kwargs)) | |
| @classmethod | |
| def eval( | |
| cls, | |
| expr: str, | |
| frames_back: int = 1, | |
| use_numexpr: bool = False, | |
| numexpr_kwargs: tp.KwargsLike = None, | |
| local_dict: tp.Optional[tp.Mapping] = None, | |
| global_dict: tp.Optional[tp.Mapping] = None, | |
| broadcast_kwargs: tp.KwargsLike = None, | |
| wrap_kwargs: tp.KwargsLike = None, | |
| ): | |
| """Evaluate a simple array expression element-wise using NumExpr or NumPy. | |
| If NumExpr is enabled, only one-line statements are supported. Otherwise, uses | |
| `vectorbtpro.utils.eval_.evaluate`. | |
| !!! note | |
| All required variables will broadcast against each other prior to the evaluation. | |
| Usage: | |
| ```pycon | |
| >>> sr = pd.Series([1, 2, 3], index=['x', 'y', 'z']) | |
| >>> df = pd.DataFrame([[4, 5, 6]], index=['x', 'y', 'z'], columns=['a', 'b', 'c']) | |
| >>> vbt.pd_acc.eval('sr + df') | |
| a b c | |
| x 5 6 7 | |
| y 6 7 8 | |
| z 7 8 9 | |
| ``` | |
| """ | |
| if numexpr_kwargs is None: | |
| numexpr_kwargs = {} | |
| if broadcast_kwargs is None: | |
| broadcast_kwargs = {} | |
| if wrap_kwargs is None: | |
| wrap_kwargs = {} | |
| expr = inspect.cleandoc(expr) | |
| parsed = ast.parse(expr) | |
| body_nodes = list(parsed.body) | |
| load_vars = set() | |
| store_vars = set() | |
| for body_node in body_nodes: | |
| for child_node in ast.walk(body_node): | |
| if type(child_node) is ast.Name: | |
| if isinstance(child_node.ctx, ast.Load): | |
| if child_node.id not in store_vars: | |
| load_vars.add(child_node.id) | |
| if isinstance(child_node.ctx, ast.Store): | |
| store_vars.add(child_node.id) | |
| load_vars = list(load_vars) | |
| objs = get_context_vars(load_vars, frames_back=frames_back, local_dict=local_dict, global_dict=global_dict) | |
| objs = dict(zip(load_vars, objs)) | |
| objs, wrapper = reshaping.broadcast(objs, return_wrapper=True, **broadcast_kwargs) | |
| objs = {k: np.asarray(v) for k, v in objs.items()} | |
| if use_numexpr: | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("numexpr") | |
| import numexpr | |
| out = numexpr.evaluate(expr, local_dict=objs, **numexpr_kwargs) | |
| else: | |
| out = evaluate(expr, context=objs) | |
| return wrapper.wrap(out, **wrap_kwargs) | |
| class BaseSRAccessor(BaseAccessor): | |
| """Accessor on top of Series. | |
| Accessible via `pd.Series.vbt` and all child accessors.""" | |
| def __init__( | |
| self, | |
| wrapper: tp.Union[ArrayWrapper, tp.ArrayLike], | |
| obj: tp.Optional[tp.ArrayLike] = None, | |
| _full_init: bool = True, | |
| **kwargs, | |
| ) -> None: | |
| if _full_init: | |
| if isinstance(wrapper, ArrayWrapper): | |
| if wrapper.ndim == 2: | |
| if wrapper.shape[1] == 1: | |
| wrapper = wrapper.replace(ndim=1) | |
| else: | |
| raise TypeError("Series accessors work only one one-dimensional data") | |
| BaseAccessor.__init__(self, wrapper, obj=obj, **kwargs) | |
| @hybrid_property | |
| def ndim(cls_or_self) -> int: | |
| return 1 | |
| @hybrid_method | |
| def is_series(cls_or_self) -> bool: | |
| return True | |
| @hybrid_method | |
| def is_frame(cls_or_self) -> bool: | |
| return False | |
| class BaseDFAccessor(BaseAccessor): | |
| """Accessor on top of DataFrames. | |
| Accessible via `pd.DataFrame.vbt` and all child accessors.""" | |
| def __init__( | |
| self, | |
| wrapper: tp.Union[ArrayWrapper, tp.ArrayLike], | |
| obj: tp.Optional[tp.ArrayLike] = None, | |
| _full_init: bool = True, | |
| **kwargs, | |
| ) -> None: | |
| if _full_init: | |
| if isinstance(wrapper, ArrayWrapper): | |
| if wrapper.ndim == 1: | |
| wrapper = wrapper.replace(ndim=2) | |
| BaseAccessor.__init__(self, wrapper, obj=obj, **kwargs) | |
| @hybrid_property | |
| def ndim(cls_or_self) -> int: | |
| return 2 | |
| @hybrid_method | |
| def is_series(cls_or_self) -> bool: | |
| return False | |
| @hybrid_method | |
| def is_frame(cls_or_self) -> bool: | |
| return True | |
| </file> | |
| <file path="base/chunking.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Extensions for chunking of base operations.""" | |
| import uuid | |
| import numpy as np | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.utils import checks | |
| from vectorbtpro.utils.attr_ import DefineMixin, define | |
| from vectorbtpro.utils.chunking import ( | |
| ArgGetter, | |
| ArgSizer, | |
| ArraySizer, | |
| ChunkMeta, | |
| ChunkMapper, | |
| ChunkSlicer, | |
| ShapeSlicer, | |
| ArraySelector, | |
| ArraySlicer, | |
| Chunked, | |
| ) | |
| from vectorbtpro.utils.parsing import Regex | |
| __all__ = [ | |
| "GroupLensSizer", | |
| "GroupLensSlicer", | |
| "ChunkedGroupLens", | |
| "GroupLensMapper", | |
| "GroupMapSlicer", | |
| "ChunkedGroupMap", | |
| "GroupIdxsMapper", | |
| "FlexArraySizer", | |
| "FlexArraySelector", | |
| "FlexArraySlicer", | |
| "ChunkedFlexArray", | |
| "shape_gl_slicer", | |
| "flex_1d_array_gl_slicer", | |
| "flex_array_gl_slicer", | |
| "array_gl_slicer", | |
| ] | |
| class GroupLensSizer(ArgSizer): | |
| """Class for getting the size from group lengths. | |
| Argument can be either a group map tuple or a group lengths array.""" | |
| @classmethod | |
| def get_obj_size(cls, obj: tp.Union[tp.GroupLens, tp.GroupMap], single_type: tp.Optional[type] = None) -> int: | |
| """Get size of an object.""" | |
| if single_type is not None: | |
| if checks.is_instance_of(obj, single_type): | |
| return 1 | |
| if isinstance(obj, tuple): | |
| return len(obj[1]) | |
| return len(obj) | |
| def get_size(self, ann_args: tp.AnnArgs, **kwargs) -> int: | |
| return self.get_obj_size(self.get_arg(ann_args), single_type=self.single_type) | |
| class GroupLensSlicer(ChunkSlicer): | |
| """Class for slicing multiple elements from group lengths based on the chunk range.""" | |
| def get_size(self, obj: tp.Union[tp.GroupLens, tp.GroupMap], **kwargs) -> int: | |
| return GroupLensSizer.get_obj_size(obj, single_type=self.single_type) | |
| def take(self, obj: tp.Union[tp.GroupLens, tp.GroupMap], chunk_meta: ChunkMeta, **kwargs) -> tp.GroupMap: | |
| if isinstance(obj, tuple): | |
| return obj[1][chunk_meta.start : chunk_meta.end] | |
| return obj[chunk_meta.start : chunk_meta.end] | |
| class ChunkedGroupLens(Chunked): | |
| """Class representing chunkable group lengths.""" | |
| def resolve_take_spec(self) -> tp.TakeSpec: | |
| if self.take_spec_missing: | |
| if self.select: | |
| raise ValueError("Selection is not supported") | |
| return GroupLensSlicer | |
| return self.take_spec | |
| def get_group_lens_slice(group_lens: tp.GroupLens, chunk_meta: ChunkMeta) -> slice: | |
| """Get slice of each chunk in group lengths.""" | |
| group_lens_cumsum = np.cumsum(group_lens[: chunk_meta.end]) | |
| start = group_lens_cumsum[chunk_meta.start] - group_lens[chunk_meta.start] | |
| end = group_lens_cumsum[-1] | |
| return slice(start, end) | |
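| # A minimal sketch (hypothetical values) of how `get_group_lens_slice` maps a chunk of groups | |
| # to a slice of columns: with groups spanning 2, 3, and 4 columns, a chunk covering groups 1..2 | |
| # starts after the 2 columns of group 0 and ends after column 8: | |
| # >>> import numpy as np | |
| # >>> group_lens = np.array([2, 3, 4]) | |
| # >>> get_group_lens_slice(group_lens, ChunkMeta(uuid="", idx=0, start=1, end=3, indices=None)) | |
| # slice(2, 9, None) | |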
| @define | |
| class GroupLensMapper(ChunkMapper, ArgGetter, DefineMixin): | |
| """Class for mapping chunk metadata to per-group column lengths. | |
| Argument can be either a group map tuple or a group lengths array.""" | |
| def map(self, chunk_meta: ChunkMeta, ann_args: tp.Optional[tp.AnnArgs] = None, **kwargs) -> ChunkMeta: | |
| group_lens = self.get_arg(ann_args) | |
| if isinstance(group_lens, tuple): | |
| group_lens = group_lens[1] | |
| group_lens_slice = get_group_lens_slice(group_lens, chunk_meta) | |
| return ChunkMeta( | |
| uuid=str(uuid.uuid4()), | |
| idx=chunk_meta.idx, | |
| start=group_lens_slice.start, | |
| end=group_lens_slice.stop, | |
| indices=None, | |
| ) | |
| group_lens_mapper = GroupLensMapper(arg_query=Regex(r"(group_lens|group_map)")) | |
| """Default instance of `GroupLensMapper`.""" | |
| class GroupMapSlicer(ChunkSlicer): | |
| """Class for slicing multiple elements from a group map based on the chunk range.""" | |
| def get_size(self, obj: tp.GroupMap, **kwargs) -> int: | |
| return GroupLensSizer.get_obj_size(obj, single_type=self.single_type) | |
| def take(self, obj: tp.GroupMap, chunk_meta: ChunkMeta, **kwargs) -> tp.GroupMap: | |
| group_idxs, group_lens = obj | |
| group_lens = group_lens[chunk_meta.start : chunk_meta.end] | |
| return np.arange(np.sum(group_lens)), group_lens | |
| class ChunkedGroupMap(Chunked): | |
| """Class representing a chunkable group map.""" | |
| def resolve_take_spec(self) -> tp.TakeSpec: | |
| if self.take_spec_missing: | |
| if self.select: | |
| raise ValueError("Selection is not supported") | |
| return GroupMapSlicer | |
| return self.take_spec | |
| @define | |
| class GroupIdxsMapper(ChunkMapper, ArgGetter, DefineMixin): | |
| """Class for mapping chunk metadata to per-group column indices. | |
| Argument must be a group map tuple.""" | |
| def map(self, chunk_meta: ChunkMeta, ann_args: tp.Optional[tp.AnnArgs] = None, **kwargs) -> ChunkMeta: | |
| group_map = self.get_arg(ann_args) | |
| group_idxs, group_lens = group_map | |
| group_lens_slice = get_group_lens_slice(group_lens, chunk_meta) | |
| return ChunkMeta( | |
| uuid=str(uuid.uuid4()), | |
| idx=chunk_meta.idx, | |
| start=None, | |
| end=None, | |
| indices=group_idxs[group_lens_slice], | |
| ) | |
| group_idxs_mapper = GroupIdxsMapper(arg_query="group_map") | |
| """Default instance of `GroupIdxsMapper`.""" | |
| class FlexArraySizer(ArraySizer): | |
| """Class for getting the size from the length of an axis in a flexible array.""" | |
| @classmethod | |
| def get_obj_size(cls, obj: tp.AnyArray, axis: int, single_type: tp.Optional[type] = None) -> int: | |
| """Get size of an object.""" | |
| if single_type is not None: | |
| if checks.is_instance_of(obj, single_type): | |
| return 1 | |
| obj = np.asarray(obj) | |
| if len(obj.shape) == 0: | |
| return 1 | |
| if axis is None: | |
| if len(obj.shape) == 1: | |
| axis = 0 | |
| checks.assert_not_none(axis, arg_name="axis") | |
| checks.assert_in(axis, (0, 1), arg_name="axis") | |
| if len(obj.shape) == 1: | |
| if axis == 1: | |
| return 1 | |
| return obj.shape[0] | |
| if len(obj.shape) == 2: | |
| if axis == 1: | |
| return obj.shape[1] | |
| return obj.shape[0] | |
| raise ValueError(f"FlexArraySizer supports max 2 dimensions, not {len(obj.shape)}") | |
| @define | |
| class FlexArraySelector(ArraySelector, DefineMixin): | |
| """Class for selecting one element from a NumPy array's axis flexibly based on the chunk index. | |
| The result is intended to be used together with `vectorbtpro.base.flex_indexing.flex_select_1d_nb` | |
| and `vectorbtpro.base.flex_indexing.flex_select_nb`.""" | |
| def get_size(self, obj: tp.ArrayLike, **kwargs) -> int: | |
| return FlexArraySizer.get_obj_size(obj, self.axis, single_type=self.single_type) | |
| def suggest_size(self, obj: tp.ArrayLike, **kwargs) -> tp.Optional[int]: | |
| return None | |
| def take( | |
| self, | |
| obj: tp.ArrayLike, | |
| chunk_meta: ChunkMeta, | |
| ann_args: tp.Optional[tp.AnnArgs] = None, | |
| **kwargs, | |
| ) -> tp.ArrayLike: | |
| if np.isscalar(obj): | |
| return obj | |
| obj = np.asarray(obj) | |
| if len(obj.shape) == 0: | |
| return obj | |
| axis = self.axis | |
| if axis is None: | |
| if len(obj.shape) == 1: | |
| axis = 0 | |
| checks.assert_not_none(axis, arg_name="axis") | |
| checks.assert_in(axis, (0, 1), arg_name="axis") | |
| if len(obj.shape) == 1: | |
| if axis == 1 or obj.shape[0] == 1: | |
| return obj | |
| if self.keep_dims: | |
| return obj[chunk_meta.idx : chunk_meta.idx + 1] | |
| return obj[chunk_meta.idx] | |
| if len(obj.shape) == 2: | |
| if axis == 1: | |
| if obj.shape[1] == 1: | |
| return obj | |
| if self.keep_dims: | |
| return obj[:, chunk_meta.idx : chunk_meta.idx + 1] | |
| return obj[:, chunk_meta.idx] | |
| if obj.shape[0] == 1: | |
| return obj | |
| if self.keep_dims: | |
| return obj[chunk_meta.idx : chunk_meta.idx + 1, :] | |
| return obj[chunk_meta.idx, :] | |
| raise ValueError(f"FlexArraySelector supports max 2 dimensions, not {len(obj.shape)}") | |
| @define | |
| class FlexArraySlicer(ArraySlicer, DefineMixin): | |
| """Class for selecting one element from a NumPy array's axis flexibly based on the chunk index. | |
| The result is intended to be used together with `vectorbtpro.base.flex_indexing.flex_select_1d_nb` | |
| and `vectorbtpro.base.flex_indexing.flex_select_nb`.""" | |
| def get_size(self, obj: tp.ArrayLike, **kwargs) -> int: | |
| return FlexArraySizer.get_obj_size(obj, self.axis, single_type=self.single_type) | |
| def suggest_size(self, obj: tp.ArrayLike, **kwargs) -> tp.Optional[int]: | |
| return None | |
| def take( | |
| self, | |
| obj: tp.ArrayLike, | |
| chunk_meta: ChunkMeta, | |
| ann_args: tp.Optional[tp.AnnArgs] = None, | |
| **kwargs, | |
| ) -> tp.ArrayLike: | |
| if np.isscalar(obj): | |
| return obj | |
| obj = np.asarray(obj) | |
| if len(obj.shape) == 0: | |
| return obj | |
| axis = self.axis | |
| if axis is None: | |
| if len(obj.shape) == 1: | |
| axis = 0 | |
| checks.assert_not_none(axis, arg_name="axis") | |
| checks.assert_in(axis, (0, 1), arg_name="axis") | |
| if len(obj.shape) == 1: | |
| if axis == 1 or obj.shape[0] == 1: | |
| return obj | |
| return obj[chunk_meta.start : chunk_meta.end] | |
| if len(obj.shape) == 2: | |
| if axis == 1: | |
| if obj.shape[1] == 1: | |
| return obj | |
| return obj[:, chunk_meta.start : chunk_meta.end] | |
| if obj.shape[0] == 1: | |
| return obj | |
| return obj[chunk_meta.start : chunk_meta.end, :] | |
| raise ValueError(f"FlexArraySlicer supports max 2 dimensions, not {len(obj.shape)}") | |
| class ChunkedFlexArray(Chunked): | |
| """Class representing a chunkable flexible array.""" | |
| def resolve_take_spec(self) -> tp.TakeSpec: | |
| if self.take_spec_missing: | |
| if self.select: | |
| return FlexArraySelector | |
| return FlexArraySlicer | |
| return self.take_spec | |
| shape_gl_slicer = ShapeSlicer(axis=1, mapper=group_lens_mapper) | |
| """Flexible 2-dim shape slicer along the column axis based on group lengths.""" | |
| flex_1d_array_gl_slicer = FlexArraySlicer(mapper=group_lens_mapper) | |
| """Flexible 1-dim array slicer along the column axis based on group lengths.""" | |
| flex_array_gl_slicer = FlexArraySlicer(axis=1, mapper=group_lens_mapper) | |
| """Flexible 2-dim array slicer along the column axis based on group lengths.""" | |
| array_gl_slicer = ArraySlicer(axis=1, mapper=group_lens_mapper) | |
| """2-dim array slicer along the column axis based on group lengths.""" | |
| </file> | |
| <file path="base/combining.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Functions for combining arrays. | |
| Combine functions combine two or more NumPy arrays using a custom function. The emphasis here is | |
| on stacking the results into one NumPy array - since vectorbt is all about brute-forcing | |
| large spaces of hyper-parameters, concatenating the results of each hyper-parameter combination into | |
| a single DataFrame is important. All functions are available in both Python and Numba-compiled form.""" | |
| import numpy as np | |
| from numba.typed import List | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.registries.jit_registry import jit_reg, register_jitted | |
| from vectorbtpro.utils.execution import Task, execute | |
| from vectorbtpro.utils.template import RepFunc | |
| __all__ = [] | |
| @register_jitted | |
| def custom_apply_and_concat_none_nb( | |
| indices: tp.Array1d, | |
| apply_func_nb: tp.Callable, | |
| *args, | |
| ) -> None: | |
| """Run `apply_func_nb` that returns nothing for each index. | |
| Meant for in-place outputs.""" | |
| for i in indices: | |
| apply_func_nb(i, *args) | |
| @register_jitted | |
| def apply_and_concat_none_nb( | |
| ntimes: int, | |
| apply_func_nb: tp.Callable, | |
| *args, | |
| ) -> None: | |
| """Run `apply_func_nb` that returns nothing number of times. | |
| Uses `custom_apply_and_concat_none_nb`.""" | |
| custom_apply_and_concat_none_nb(np.arange(ntimes), apply_func_nb, *args) | |
| @register_jitted | |
| def to_2d_one_nb(a: tp.Array) -> tp.Array2d: | |
| """Expand the dimensions of the array along the axis 1.""" | |
| if a.ndim > 1: | |
| return a | |
| return np.expand_dims(a, axis=1) | |
| @register_jitted | |
| def custom_apply_and_concat_one_nb( | |
| indices: tp.Array1d, | |
| apply_func_nb: tp.Callable, | |
| *args, | |
| ) -> tp.Array2d: | |
| """Run `apply_func_nb` that returns one array for each index.""" | |
| output_0 = to_2d_one_nb(apply_func_nb(indices[0], *args)) | |
| output = np.empty((output_0.shape[0], len(indices) * output_0.shape[1]), dtype=output_0.dtype) | |
| for i in range(len(indices)): | |
| if i == 0: | |
| outputs_i = output_0 | |
| else: | |
| outputs_i = to_2d_one_nb(apply_func_nb(indices[i], *args)) | |
| output[:, i * outputs_i.shape[1] : (i + 1) * outputs_i.shape[1]] = outputs_i | |
| return output | |
| @register_jitted | |
| def apply_and_concat_one_nb( | |
| ntimes: int, | |
| apply_func_nb: tp.Callable, | |
| *args, | |
| ) -> tp.Array2d: | |
| """Run `apply_func_nb` that returns one array number of times. | |
| Uses `custom_apply_and_concat_one_nb`.""" | |
| return custom_apply_and_concat_one_nb(np.arange(ntimes), apply_func_nb, *args) | |
| @register_jitted | |
| def to_2d_multiple_nb(a: tp.Iterable[tp.Array]) -> tp.List[tp.Array2d]: | |
| """Expand the dimensions of each array in `a` along axis 1.""" | |
| lst = list() | |
| for _a in a: | |
| lst.append(to_2d_one_nb(_a)) | |
| return lst | |
| @register_jitted | |
| def custom_apply_and_concat_multiple_nb( | |
| indices: tp.Array1d, | |
| apply_func_nb: tp.Callable, | |
| *args, | |
| ) -> tp.List[tp.Array2d]: | |
| """Run `apply_func_nb` that returns multiple arrays for each index.""" | |
| outputs = list() | |
| outputs_0 = to_2d_multiple_nb(apply_func_nb(indices[0], *args)) | |
| for j in range(len(outputs_0)): | |
| outputs.append( | |
| np.empty((outputs_0[j].shape[0], len(indices) * outputs_0[j].shape[1]), dtype=outputs_0[j].dtype) | |
| ) | |
| for i in range(len(indices)): | |
| if i == 0: | |
| outputs_i = outputs_0 | |
| else: | |
| outputs_i = to_2d_multiple_nb(apply_func_nb(indices[i], *args)) | |
| for j in range(len(outputs_i)): | |
| outputs[j][:, i * outputs_i[j].shape[1] : (i + 1) * outputs_i[j].shape[1]] = outputs_i[j] | |
| return outputs | |
| @register_jitted | |
| def apply_and_concat_multiple_nb( | |
| ntimes: int, | |
| apply_func_nb: tp.Callable, | |
| *args, | |
| ) -> tp.List[tp.Array2d]: | |
| """Run `apply_func_nb` that returns multiple arrays number of times. | |
| Uses `custom_apply_and_concat_multiple_nb`.""" | |
| return custom_apply_and_concat_multiple_nb(np.arange(ntimes), apply_func_nb, *args) | |
| def apply_and_concat_each( | |
| tasks: tp.TasksLike, | |
| n_outputs: tp.Optional[int] = None, | |
| execute_kwargs: tp.KwargsLike = None, | |
| ) -> tp.Union[None, tp.Array2d, tp.List[tp.Array2d]]: | |
| """Apply each function on its own set of positional and keyword arguments. | |
| Executes the tasks using `vectorbtpro.utils.execution.execute`.""" | |
| from vectorbtpro.base.merging import column_stack_arrays | |
| if execute_kwargs is None: | |
| execute_kwargs = {} | |
| out = execute(tasks, **execute_kwargs) | |
| if n_outputs is None: | |
| if out[0] is None: | |
| n_outputs = 0 | |
| elif isinstance(out[0], (tuple, list, List)): | |
| n_outputs = len(out[0]) | |
| else: | |
| n_outputs = 1 | |
| if n_outputs == 0: | |
| return None | |
| if n_outputs == 1: | |
| if isinstance(out[0], (tuple, list, List)) and len(out[0]) == 1: | |
| out = list(map(lambda x: x[0], out)) | |
| return column_stack_arrays(out) | |
| return list(map(column_stack_arrays, zip(*out))) | |
| def apply_and_concat( | |
| ntimes: int, | |
| apply_func: tp.Callable, | |
| *args, | |
| n_outputs: tp.Optional[int] = None, | |
| jitted_loop: bool = False, | |
| jitted_warmup: bool = False, | |
| execute_kwargs: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> tp.Union[None, tp.Array2d, tp.List[tp.Array2d]]: | |
| """Run `apply_func` function a number of times and concatenate the results depending upon how | |
| many array-like objects it generates. | |
| `apply_func` must accept arguments `i`, `*args`, and `**kwargs`. | |
| Set `jitted_loop` to True to use the JIT-compiled version. | |
| All jitted iteration functions are resolved using `vectorbtpro.registries.jit_registry.JITRegistry.resolve`. | |
| !!! note | |
| `n_outputs` must be set when `jitted_loop` is True. | |
| Numba doesn't support variable keyword arguments.""" | |
| if jitted_loop: | |
| if n_outputs is None: | |
| raise ValueError("Jitted iteration requires n_outputs") | |
| if n_outputs == 0: | |
| func = jit_reg.resolve(custom_apply_and_concat_none_nb) | |
| elif n_outputs == 1: | |
| func = jit_reg.resolve(custom_apply_and_concat_one_nb) | |
| else: | |
| func = jit_reg.resolve(custom_apply_and_concat_multiple_nb) | |
| if jitted_warmup: | |
| func(np.array([0]), apply_func, *args, **kwargs) | |
| def _tasks_template(chunk_meta): | |
| tasks = [] | |
| for _chunk_meta in chunk_meta: | |
| if _chunk_meta.indices is not None: | |
| chunk_indices = np.asarray(_chunk_meta.indices) | |
| else: | |
| if _chunk_meta.start is None or _chunk_meta.end is None: | |
| raise ValueError("Each chunk must have a start and an end index") | |
| chunk_indices = np.arange(_chunk_meta.start, _chunk_meta.end) | |
| tasks.append(Task(func, chunk_indices, apply_func, *args, **kwargs)) | |
| return tasks | |
| tasks = RepFunc(_tasks_template) | |
| else: | |
| tasks = [(apply_func, (i, *args), kwargs) for i in range(ntimes)] | |
| if execute_kwargs is None: | |
| execute_kwargs = {} | |
| execute_kwargs["size"] = ntimes | |
| return apply_and_concat_each( | |
| tasks, | |
| n_outputs=n_outputs, | |
| execute_kwargs=execute_kwargs, | |
| ) | |
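| # A minimal sketch (illustrative inputs) of the plain, non-jitted loop: each call produces one | |
| # column-like result, and the results are stacked along columns: | |
| # >>> import numpy as np | |
| # >>> apply_and_concat(3, lambda i, a: a * (i + 1), np.array([1, 2])) | |
| # array([[1, 2, 3], | |
| #        [2, 4, 6]]) | |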
| @register_jitted | |
| def select_and_combine_nb( | |
| i: int, | |
| obj: tp.Any, | |
| others: tp.Sequence, | |
| combine_func_nb: tp.Callable, | |
| *args, | |
| ) -> tp.AnyArray: | |
| """Numba-compiled version of `select_and_combine`.""" | |
| return combine_func_nb(obj, others[i], *args) | |
| @register_jitted | |
| def combine_and_concat_nb( | |
| obj: tp.Any, | |
| others: tp.Sequence, | |
| combine_func_nb: tp.Callable, | |
| *args, | |
| ) -> tp.Array2d: | |
| """Numba-compiled version of `combine_and_concat`.""" | |
| return apply_and_concat_one_nb(len(others), select_and_combine_nb, obj, others, combine_func_nb, *args) | |
| def select_and_combine( | |
| i: int, | |
| obj: tp.Any, | |
| others: tp.Sequence, | |
| combine_func: tp.Callable, | |
| *args, | |
| **kwargs, | |
| ) -> tp.AnyArray: | |
| """Combine `obj` with an array at position `i` in `others` using `combine_func`.""" | |
| return combine_func(obj, others[i], *args, **kwargs) | |
| def combine_and_concat( | |
| obj: tp.Any, | |
| others: tp.Sequence, | |
| combine_func: tp.Callable, | |
| *args, | |
| jitted_loop: bool = False, | |
| **kwargs, | |
| ) -> tp.Array2d: | |
| """Combine `obj` with each in `others` using `combine_func` and concatenate. | |
| `select_and_combine_nb` is resolved using `vectorbtpro.registries.jit_registry.JITRegistry.resolve`.""" | |
| if jitted_loop: | |
| apply_func = jit_reg.resolve(select_and_combine_nb) | |
| else: | |
| apply_func = select_and_combine | |
| return apply_and_concat( | |
| len(others), | |
| apply_func, | |
| obj, | |
| others, | |
| combine_func, | |
| *args, | |
| n_outputs=1, | |
| jitted_loop=jitted_loop, | |
| **kwargs, | |
| ) | |
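| # A minimal sketch (illustrative inputs): the first array is combined with each of the others | |
| # via addition, and the two results are concatenated along columns: | |
| # >>> import numpy as np | |
| # >>> combine_and_concat(np.array([1, 2]), (np.array([3, 4]), np.array([5, 6])), np.add) | |
| # array([[4, 6], | |
| #        [6, 8]]) | |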
| @register_jitted | |
| def combine_multiple_nb( | |
| objs: tp.Sequence, | |
| combine_func_nb: tp.Callable, | |
| *args, | |
| ) -> tp.Any: | |
| """Numba-compiled version of `combine_multiple`.""" | |
| result = objs[0] | |
| for i in range(1, len(objs)): | |
| result = combine_func_nb(result, objs[i], *args) | |
| return result | |
| def combine_multiple( | |
| objs: tp.Sequence, | |
| combine_func: tp.Callable, | |
| *args, | |
| jitted_loop: bool = False, | |
| **kwargs, | |
| ) -> tp.Any: | |
| """Combine `objs` pairwise into a single object. | |
| Set `jitted_loop` to True to use the JIT-compiled version. | |
| `combine_multiple_nb` is resolved using `vectorbtpro.registries.jit_registry.JITRegistry.resolve`. | |
| !!! note | |
| Numba doesn't support variable keyword arguments.""" | |
| if jitted_loop: | |
| func = jit_reg.resolve(combine_multiple_nb) | |
| return func(objs, combine_func, *args) | |
| result = objs[0] | |
| for i in range(1, len(objs)): | |
| result = combine_func(result, objs[i], *args, **kwargs) | |
| return result | |
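| # A minimal sketch (illustrative inputs) of the pairwise reduction ((a + b) + c): | |
| # >>> import numpy as np | |
| # >>> combine_multiple([np.array([1, 2]), np.array([3, 4]), np.array([5, 6])], np.add) | |
| # array([ 9, 12]) | |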
| </file> | |
| <file path="base/decorators.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Class decorators for base classes.""" | |
| from functools import cached_property as cachedproperty | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.utils import checks | |
| from vectorbtpro.utils.config import Config, HybridConfig, merge_dicts | |
| __all__ = [] | |
| def override_arg_config(config: Config, merge_configs: bool = True) -> tp.ClassWrapper: | |
| """Class decorator to override the argument config of a class subclassing | |
| `vectorbtpro.base.preparing.BasePreparer`. | |
| Instead of overriding the `_arg_config` class attribute, you can pass `config` directly to this decorator. | |
| Disable `merge_configs` to skip merging, which effectively disables field inheritance. | |
| def wrapper(cls: tp.Type[tp.T]) -> tp.Type[tp.T]: | |
| checks.assert_subclass_of(cls, "BasePreparer") | |
| if merge_configs: | |
| new_config = merge_dicts(cls.arg_config, config) | |
| else: | |
| new_config = config | |
| if not isinstance(new_config, Config): | |
| new_config = HybridConfig(new_config) | |
| setattr(cls, "_arg_config", new_config) | |
| return cls | |
| return wrapper | |
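| # A minimal sketch (hypothetical preparer and field names) of overriding a single field of the | |
| # inherited argument config while keeping the remaining fields intact: | |
| # >>> from vectorbtpro.base.preparing import BasePreparer | |
| # >>> @override_arg_config(dict(my_arg=dict(broadcast=True))) | |
| # ... class MyPreparer(BasePreparer): | |
| # ...     pass | |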
| def attach_arg_properties(cls: tp.Type[tp.T]) -> tp.Type[tp.T]: | |
| """Class decorator to attach properties for arguments defined in the argument config | |
| of a `vectorbtpro.base.preparing.BasePreparer` subclass.""" | |
| checks.assert_subclass_of(cls, "BasePreparer") | |
| for arg_name, settings in cls.arg_config.items(): | |
| attach = settings.get("attach", None) | |
| broadcast = settings.get("broadcast", False) | |
| substitute_templates = settings.get("substitute_templates", False) | |
| if (isinstance(attach, bool) and attach) or (attach is None and (broadcast or substitute_templates)): | |
| if broadcast: | |
| return_type = tp.ArrayLike | |
| else: | |
| return_type = object | |
| target_pre_name = "_pre_" + arg_name | |
| if not hasattr(cls, target_pre_name): | |
| def pre_arg_prop(self, _arg_name: str = arg_name) -> return_type: | |
| return self.get_arg(_arg_name) | |
| pre_arg_prop.__name__ = target_pre_name | |
| pre_arg_prop.__module__ = cls.__module__ | |
| pre_arg_prop.__qualname__ = f"{cls.__name__}.{pre_arg_prop.__name__}" | |
| if broadcast and substitute_templates: | |
| pre_arg_prop.__doc__ = f"Argument `{arg_name}` before broadcasting and template substitution." | |
| elif broadcast: | |
| pre_arg_prop.__doc__ = f"Argument `{arg_name}` before broadcasting." | |
| else: | |
| pre_arg_prop.__doc__ = f"Argument `{arg_name}` before template substitution." | |
| setattr(cls, pre_arg_prop.__name__, cachedproperty(pre_arg_prop)) | |
| getattr(cls, pre_arg_prop.__name__).__set_name__(cls, pre_arg_prop.__name__) | |
| target_post_name = "_post_" + arg_name | |
| if not hasattr(cls, target_post_name): | |
| def post_arg_prop(self, _arg_name: str = arg_name) -> return_type: | |
| return self.prepare_post_arg(_arg_name) | |
| post_arg_prop.__name__ = target_post_name | |
| post_arg_prop.__module__ = cls.__module__ | |
| post_arg_prop.__qualname__ = f"{cls.__name__}.{post_arg_prop.__name__}" | |
| if broadcast and substitute_templates: | |
| post_arg_prop.__doc__ = f"Argument `{arg_name}` after broadcasting and template substitution." | |
| elif broadcast: | |
| post_arg_prop.__doc__ = f"Argument `{arg_name}` after broadcasting." | |
| else: | |
| post_arg_prop.__doc__ = f"Argument `{arg_name}` after template substitution." | |
| setattr(cls, post_arg_prop.__name__, cachedproperty(post_arg_prop)) | |
| getattr(cls, post_arg_prop.__name__).__set_name__(cls, post_arg_prop.__name__) | |
| target_name = arg_name | |
| if not hasattr(cls, target_name): | |
| def arg_prop(self, _target_post_name: str = target_post_name) -> return_type: | |
| return getattr(self, _target_post_name) | |
| arg_prop.__name__ = target_name | |
| arg_prop.__module__ = cls.__module__ | |
| arg_prop.__qualname__ = f"{cls.__name__}.{arg_prop.__name__}" | |
| arg_prop.__doc__ = f"Argument `{arg_name}`." | |
| setattr(cls, arg_prop.__name__, cachedproperty(arg_prop)) | |
| getattr(cls, arg_prop.__name__).__set_name__(cls, arg_prop.__name__) | |
| elif (isinstance(attach, bool) and attach) or attach is None: | |
| if not hasattr(cls, arg_name): | |
| def arg_prop(self, _arg_name: str = arg_name) -> tp.Any: | |
| return self.get_arg(_arg_name) | |
| arg_prop.__name__ = arg_name | |
| arg_prop.__module__ = cls.__module__ | |
| arg_prop.__qualname__ = f"{cls.__name__}.{arg_prop.__name__}" | |
| arg_prop.__doc__ = f"Argument `{arg_name}`." | |
| setattr(cls, arg_prop.__name__, cachedproperty(arg_prop)) | |
| getattr(cls, arg_prop.__name__).__set_name__(cls, arg_prop.__name__) | |
| return cls | |
| </file> | |
| <file path="base/flex_indexing.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Classes and functions for flexible indexing.""" | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro._settings import settings | |
| from vectorbtpro.registries.jit_registry import register_jitted | |
| __all__ = [ | |
| "flex_select_1d_nb", | |
| "flex_select_1d_pr_nb", | |
| "flex_select_1d_pc_nb", | |
| "flex_select_nb", | |
| "flex_select_row_nb", | |
| "flex_select_col_nb", | |
| "flex_select_2d_row_nb", | |
| "flex_select_2d_col_nb", | |
| ] | |
| _rotate_rows = settings["indexing"]["rotate_rows"] | |
| _rotate_cols = settings["indexing"]["rotate_cols"] | |
| @register_jitted(cache=True) | |
| def flex_choose_i_1d_nb(arr: tp.FlexArray1d, i: int) -> int: | |
| """Choose a position in an array as if it has been broadcast against rows or columns. | |
| !!! note | |
| Array must be one-dimensional.""" | |
| if arr.shape[0] == 1: | |
| flex_i = 0 | |
| else: | |
| flex_i = i | |
| return int(flex_i) | |
| @register_jitted(cache=True) | |
| def flex_select_1d_nb(arr: tp.FlexArray1d, i: int) -> tp.Scalar: | |
| """Select an element of an array as if it has been broadcast against rows or columns. | |
| !!! note | |
| Array must be one-dimensional.""" | |
| flex_i = flex_choose_i_1d_nb(arr, i) | |
| return arr[flex_i] | |
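| # A minimal sketch (illustrative inputs): a one-element array behaves like a broadcast constant, | |
| # while a full-length array is indexed as usual: | |
| # >>> import numpy as np | |
| # >>> flex_select_1d_nb(np.array([7]), 3) | |
| # 7 | |
| # >>> flex_select_1d_nb(np.array([1, 2, 3]), 2) | |
| # 3 | |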
| @register_jitted(cache=True) | |
| def flex_choose_i_pr_1d_nb(arr: tp.FlexArray1d, i: int, rotate_rows: bool = _rotate_rows) -> int: | |
| """Choose a position in an array as if it has been broadcast against rows. | |
| Can use rotational indexing along rows. | |
| !!! note | |
| Array must be one-dimensional.""" | |
| if arr.shape[0] == 1: | |
| flex_i = 0 | |
| else: | |
| flex_i = i | |
| if rotate_rows: | |
| return int(flex_i) % arr.shape[0] | |
| return int(flex_i) | |
| @register_jitted(cache=True) | |
| def flex_choose_i_pr_nb(arr: tp.FlexArray2d, i: int, rotate_rows: bool = _rotate_rows) -> int: | |
| """Choose a position in an array as if it has been broadcast against rows. | |
| Can use rotational indexing along rows. | |
| !!! note | |
| Array must be two-dimensional.""" | |
| if arr.shape[0] == 1: | |
| flex_i = 0 | |
| else: | |
| flex_i = i | |
| if rotate_rows: | |
| return int(flex_i) % arr.shape[0] | |
| return int(flex_i) | |
| @register_jitted(cache=True) | |
| def flex_select_1d_pr_nb(arr: tp.FlexArray1d, i: int, rotate_rows: bool = _rotate_rows) -> tp.Scalar: | |
| """Select an element of an array as if it has been broadcast against rows. | |
| Can use rotational indexing along rows. | |
| !!! note | |
| Array must be one-dimensional.""" | |
| flex_i = flex_choose_i_pr_1d_nb(arr, i, rotate_rows=rotate_rows) | |
| return arr[flex_i] | |
| @register_jitted(cache=True) | |
| def flex_choose_i_pc_1d_nb(arr: tp.FlexArray1d, col: int, rotate_cols: bool = _rotate_cols) -> int: | |
| """Choose a position in an array as if it has been broadcast against columns. | |
| Can use rotational indexing along columns. | |
| !!! note | |
| Array must be one-dimensional.""" | |
| if arr.shape[0] == 1: | |
| flex_col = 0 | |
| else: | |
| flex_col = col | |
| if rotate_cols: | |
| return int(flex_col) % arr.shape[0] | |
| return int(flex_col) | |
| @register_jitted(cache=True) | |
| def flex_choose_i_pc_nb(arr: tp.FlexArray2d, col: int, rotate_cols: bool = _rotate_cols) -> int: | |
| """Choose a position in an array as if it has been broadcast against columns. | |
| Can use rotational indexing along columns. | |
| !!! note | |
| Array must be two-dimensional.""" | |
| if arr.shape[1] == 1: | |
| flex_col = 0 | |
| else: | |
| flex_col = col | |
| if rotate_cols: | |
| return int(flex_col) % arr.shape[1] | |
| return int(flex_col) | |
| @register_jitted(cache=True) | |
| def flex_select_1d_pc_nb(arr: tp.FlexArray1d, col: int, rotate_cols: bool = _rotate_cols) -> tp.Scalar: | |
| """Select an element of an array as if it has been broadcast against columns. | |
| Can use rotational indexing along columns. | |
| !!! note | |
| Array must be one-dimensional.""" | |
| flex_col = flex_choose_i_pc_1d_nb(arr, col, rotate_cols=rotate_cols) | |
| return arr[flex_col] | |
| @register_jitted(cache=True) | |
| def flex_choose_i_and_col_nb( | |
| arr: tp.FlexArray2d, | |
| i: int, | |
| col: int, | |
| rotate_rows: bool = _rotate_rows, | |
| rotate_cols: bool = _rotate_cols, | |
| ) -> tp.Tuple[int, int]: | |
| """Choose a position in an array as if it has been broadcast rows and columns. | |
| Can use rotational indexing along rows and columns. | |
| !!! note | |
| Array must be two-dimensional.""" | |
| if arr.shape[0] == 1: | |
| flex_i = 0 | |
| else: | |
| flex_i = i | |
| if arr.shape[1] == 1: | |
| flex_col = 0 | |
| else: | |
| flex_col = col | |
| if rotate_rows and rotate_cols: | |
| return int(flex_i) % arr.shape[0], int(flex_col) % arr.shape[1] | |
| if rotate_rows: | |
| return int(flex_i) % arr.shape[0], int(flex_col) | |
| if rotate_cols: | |
| return int(flex_i), int(flex_col) % arr.shape[1] | |
| return int(flex_i), int(flex_col) | |
| @register_jitted(cache=True) | |
| def flex_select_nb( | |
| arr: tp.FlexArray2d, | |
| i: int, | |
| col: int, | |
| rotate_rows: bool = _rotate_rows, | |
| rotate_cols: bool = _rotate_cols, | |
| ) -> tp.Scalar: | |
| """Select element of an array as if it has been broadcast rows and columns. | |
| Can use rotational indexing along rows and columns. | |
| !!! note | |
| Array must be two-dimensional.""" | |
| flex_i, flex_col = flex_choose_i_and_col_nb( | |
| arr, | |
| i, | |
| col, | |
| rotate_rows=rotate_rows, | |
| rotate_cols=rotate_cols, | |
| ) | |
| return arr[flex_i, flex_col] | |
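| # A minimal sketch (illustrative inputs): a (1, 3) array broadcasts along rows, so any row index | |
| # resolves to row 0; with rotational indexing the column index wraps around instead of overflowing: | |
| # >>> import numpy as np | |
| # >>> flex_select_nb(np.array([[10, 20, 30]]), 5, 2) | |
| # 30 | |
| # >>> flex_select_nb(np.array([[10, 20, 30]]), 0, 4, rotate_cols=True) | |
| # 20 | |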
| @register_jitted(cache=True) | |
| def flex_select_row_nb(arr: tp.FlexArray2d, i: int, rotate_rows: bool = _rotate_rows) -> tp.Array1d: | |
| """Select a row from a flexible 2-dim array. Returns a 1-dim array. | |
| !!! note | |
| Array must be two-dimensional.""" | |
| flex_i = flex_choose_i_pr_nb(arr, i, rotate_rows=rotate_rows) | |
| return arr[flex_i] | |
| @register_jitted(cache=True) | |
| def flex_select_col_nb(arr: tp.FlexArray2d, col: int, rotate_cols: bool = _rotate_cols) -> tp.Array1d: | |
| """Select a column from a flexible 2-dim array. Returns a 1-dim array. | |
| !!! note | |
| Array must be two-dimensional.""" | |
| flex_col = flex_choose_i_pc_nb(arr, col, rotate_cols=rotate_cols) | |
| return arr[:, flex_col] | |
| @register_jitted(cache=True) | |
| def flex_select_2d_row_nb(arr: tp.FlexArray2d, i: int, rotate_rows: bool = _rotate_rows) -> tp.Array2d: | |
| """Select a row from a flexible 2-dim array. Returns a 2-dim array. | |
| !!! note | |
| Array must be two-dimensional.""" | |
| flex_i = flex_choose_i_pr_nb(arr, i, rotate_rows=rotate_rows) | |
| return arr[flex_i : flex_i + 1] | |
| @register_jitted(cache=True) | |
| def flex_select_2d_col_nb(arr: tp.FlexArray2d, col: int, rotate_cols: bool = _rotate_cols) -> tp.Array2d: | |
| """Select a column from a flexible 2-dim array. Returns a 2-dim array. | |
| !!! note | |
| Array must be two-dimensional.""" | |
| flex_col = flex_choose_i_pc_nb(arr, col, rotate_cols=rotate_cols) | |
| return arr[:, flex_col : flex_col + 1] | |
| </file> | |
| <file path="base/indexes.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Functions for working with indexes: index and columns. | |
| They perform operations on index objects, such as stacking, combining, and cleansing MultiIndex levels. | |
| !!! note | |
| "Index" in pandas context is referred to both index and columns.""" | |
| from datetime import datetime, timedelta | |
| import numpy as np | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro._dtypes import * | |
| from vectorbtpro.registries.jit_registry import jit_reg, register_jitted | |
| from vectorbtpro.utils import checks | |
| from vectorbtpro.utils.attr_ import DefineMixin, define | |
| from vectorbtpro.utils.base import Base | |
| __all__ = [ | |
| "ExceptLevel", | |
| "repeat_index", | |
| "tile_index", | |
| "stack_indexes", | |
| "combine_indexes", | |
| ] | |
| @define | |
| class ExceptLevel(DefineMixin): | |
| """Class for grouping except one or more levels.""" | |
| value: tp.MaybeLevelSequence = define.field() | |
| """One or more level positions or names.""" | |
| def to_any_index(index_like: tp.IndexLike) -> tp.Index: | |
| """Convert any index-like object to an index. | |
| Index objects are kept as-is.""" | |
| if checks.is_np_array(index_like) and index_like.ndim == 0: | |
| index_like = index_like[None] | |
| if not checks.is_index(index_like): | |
| return pd.Index(index_like) | |
| return index_like | |
| def get_index(obj: tp.SeriesFrame, axis: int) -> tp.Index: | |
| """Get index of `obj` by `axis`.""" | |
| checks.assert_instance_of(obj, (pd.Series, pd.DataFrame)) | |
| checks.assert_in(axis, (0, 1)) | |
| if axis == 0: | |
| return obj.index | |
| else: | |
| if checks.is_series(obj): | |
| if obj.name is not None: | |
| return pd.Index([obj.name]) | |
| return pd.Index([0]) # same as how pandas does it | |
| else: | |
| return obj.columns | |
| def index_from_values( | |
| values: tp.Sequence, | |
| single_value: bool = False, | |
| name: tp.Optional[tp.Hashable] = None, | |
| ) -> tp.Index: | |
| """Create a new `pd.Index` with `name` by parsing an iterable `values`. | |
| Each element in `values` will correspond to an element in the new index.""" | |
| scalar_types = (int, float, complex, str, bool, datetime, timedelta, np.generic) | |
| type_id_number = {} | |
| value_names = [] | |
| if len(values) == 1: | |
| single_value = True | |
| for i in range(len(values)): | |
| if i > 0 and single_value: | |
| break | |
| v = values[i] | |
| if v is None or isinstance(v, scalar_types): | |
| value_names.append(v) | |
| elif isinstance(v, np.ndarray): | |
| all_same = False | |
| if np.issubdtype(v.dtype, np.floating): | |
| if np.isclose(v, v.item(0), equal_nan=True).all(): | |
| all_same = True | |
| elif v.dtype.names is not None: | |
| all_same = False | |
| else: | |
| if np.equal(v, v.item(0)).all(): | |
| all_same = True | |
| if all_same: | |
| value_names.append(v.item(0)) | |
| else: | |
| if single_value: | |
| value_names.append("array") | |
| else: | |
| if "array" not in type_id_number: | |
| type_id_number["array"] = {} | |
| if id(v) not in type_id_number["array"]: | |
| type_id_number["array"][id(v)] = len(type_id_number["array"]) | |
| value_names.append("array_%d" % (type_id_number["array"][id(v)])) | |
| else: | |
| type_name = str(type(v).__name__) | |
| if single_value: | |
| value_names.append("%s" % type_name) | |
| else: | |
| if type_name not in type_id_number: | |
| type_id_number[type_name] = {} | |
| if id(v) not in type_id_number[type_name]: | |
| type_id_number[type_name][id(v)] = len(type_id_number[type_name]) | |
| value_names.append("%s_%d" % (type_name, type_id_number[type_name][id(v)])) | |
| if single_value and len(values) > 1: | |
| value_names *= len(values) | |
| return pd.Index(value_names, name=name) | |
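| # A minimal sketch (illustrative inputs): scalars and constant arrays are represented by their | |
| # value, while distinct non-constant arrays receive enumerated placeholder names: | |
| # >>> import numpy as np | |
| # >>> index_from_values([0.5, np.array([1.0, 1.0, 1.0]), np.array([1, 2, 3])], name='key') | |
| # Index([0.5, 1.0, 'array_0'], dtype='object', name='key') | |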
| def repeat_index(index: tp.IndexLike, n: int, ignore_ranges: tp.Optional[bool] = None) -> tp.Index: | |
| """Repeat each element in `index` `n` times. | |
| Set `ignore_ranges` to True to ignore indexes of type `pd.RangeIndex`.""" | |
| from vectorbtpro._settings import settings | |
| broadcasting_cfg = settings["broadcasting"] | |
| if ignore_ranges is None: | |
| ignore_ranges = broadcasting_cfg["ignore_ranges"] | |
| index = to_any_index(index) | |
| if n == 1: | |
| return index | |
| if checks.is_default_index(index) and ignore_ranges: # ignore simple ranges without name | |
| return pd.RangeIndex(start=0, stop=len(index) * n, step=1) | |
| return index.repeat(n) | |
| def tile_index(index: tp.IndexLike, n: int, ignore_ranges: tp.Optional[bool] = None) -> tp.Index: | |
| """Tile the whole `index` `n` times. | |
| Set `ignore_ranges` to True to ignore indexes of type `pd.RangeIndex`.""" | |
| from vectorbtpro._settings import settings | |
| broadcasting_cfg = settings["broadcasting"] | |
| if ignore_ranges is None: | |
| ignore_ranges = broadcasting_cfg["ignore_ranges"] | |
| index = to_any_index(index) | |
| if n == 1: | |
| return index | |
| if checks.is_default_index(index) and ignore_ranges: # ignore simple ranges without name | |
| return pd.RangeIndex(start=0, stop=len(index) * n, step=1) | |
| if isinstance(index, pd.MultiIndex): | |
| return pd.MultiIndex.from_tuples(np.tile(index, n), names=index.names) | |
| return pd.Index(np.tile(index, n), name=index.name) | |
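| # A minimal sketch (illustrative inputs): repeating duplicates each element in place, while | |
| # tiling repeats the whole index: | |
| # >>> import pandas as pd | |
| # >>> repeat_index(pd.Index(['a', 'b']), 2) | |
| # Index(['a', 'a', 'b', 'b'], dtype='object') | |
| # >>> tile_index(pd.Index(['a', 'b']), 2) | |
| # Index(['a', 'b', 'a', 'b'], dtype='object') | |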
| def clean_index( | |
| index: tp.IndexLike, | |
| drop_duplicates: tp.Optional[bool] = None, | |
| keep: tp.Optional[str] = None, | |
| drop_redundant: tp.Optional[bool] = None, | |
| ) -> tp.Index: | |
| """Clean index. | |
| Set `drop_duplicates` to True to remove duplicate levels. | |
| For details on `keep`, see `drop_duplicate_levels`. | |
| Set `drop_redundant` to True to use `drop_redundant_levels`.""" | |
| from vectorbtpro._settings import settings | |
| broadcasting_cfg = settings["broadcasting"] | |
| if drop_duplicates is None: | |
| drop_duplicates = broadcasting_cfg["drop_duplicates"] | |
| if keep is None: | |
| keep = broadcasting_cfg["keep"] | |
| if drop_redundant is None: | |
| drop_redundant = broadcasting_cfg["drop_redundant"] | |
| index = to_any_index(index) | |
| if drop_duplicates: | |
| index = drop_duplicate_levels(index, keep=keep) | |
| if drop_redundant: | |
| index = drop_redundant_levels(index) | |
| return index | |
| def stack_indexes(*indexes: tp.MaybeTuple[tp.IndexLike], **clean_index_kwargs) -> tp.Index: | |
| """Stack each index in `indexes` on top of each other, from top to bottom.""" | |
| if len(indexes) == 1: | |
| indexes = indexes[0] | |
| indexes = list(indexes) | |
| levels = [] | |
| for i in range(len(indexes)): | |
| index = indexes[i] | |
| if not isinstance(index, pd.MultiIndex): | |
| levels.append(to_any_index(index)) | |
| else: | |
| for j in range(index.nlevels): | |
| levels.append(index.get_level_values(j)) | |
| max_len = max(map(len, levels)) | |
| for i in range(len(levels)): | |
| if len(levels[i]) < max_len: | |
| if len(levels[i]) != 1: | |
| raise ValueError(f"Index at level {i} could not be broadcast to shape ({max_len},) ") | |
| levels[i] = repeat_index(levels[i], max_len, ignore_ranges=False) | |
| new_index = pd.MultiIndex.from_arrays(levels) | |
| return clean_index(new_index, **clean_index_kwargs) | |
| def combine_indexes(*indexes: tp.MaybeTuple[tp.IndexLike], **kwargs) -> tp.Index: | |
| """Combine each index in `indexes` using Cartesian product. | |
| Keyword arguments will be passed to `stack_indexes`.""" | |
| if len(indexes) == 1: | |
| indexes = indexes[0] | |
| indexes = list(indexes) | |
| new_index = to_any_index(indexes[0]) | |
| for i in range(1, len(indexes)): | |
| index1, index2 = new_index, to_any_index(indexes[i]) | |
| new_index1 = repeat_index(index1, len(index2), ignore_ranges=False) | |
| new_index2 = tile_index(index2, len(index1), ignore_ranges=False) | |
| new_index = stack_indexes([new_index1, new_index2], **kwargs) | |
| return new_index | |
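| # Illustrative sketch (added note, not part of the original module): a Cartesian product of | |
| # two indexes. Assuming a = pd.Index([1, 2], name="a") and b = pd.Index(["x", "y"], name="b"): | |
| # combine_indexes([a, b]) -> MultiIndex([(1, "x"), (1, "y"), (2, "x"), (2, "y")], | |
| #                                        names=["a", "b"]) | |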
| def combine_index_with_keys(index: tp.IndexLike, keys: tp.IndexLike, lens: tp.Sequence[int], **kwargs) -> tp.Index: | |
| """Build keys based on index lengths.""" | |
| if not isinstance(index, pd.Index): | |
| index = pd.Index(index) | |
| if not isinstance(keys, pd.Index): | |
| keys = pd.Index(keys) | |
| new_index = None | |
| new_keys = None | |
| start_idx = 0 | |
| for i in range(len(keys)): | |
| _index = index[start_idx : start_idx + lens[i]] | |
| if new_index is None: | |
| new_index = _index | |
| else: | |
| new_index = new_index.append(_index) | |
| start_idx += lens[i] | |
| new_key = keys[[i]].repeat(lens[i]) | |
| if new_keys is None: | |
| new_keys = new_key | |
| else: | |
| new_keys = new_keys.append(new_key) | |
| return stack_indexes([new_keys, new_index], **kwargs) | |
| def concat_indexes( | |
| *indexes: tp.MaybeTuple[tp.IndexLike], | |
| index_concat_method: tp.MaybeTuple[tp.Union[str, tp.Callable]] = "append", | |
| keys: tp.Optional[tp.IndexLike] = None, | |
| clean_index_kwargs: tp.KwargsLike = None, | |
| verify_integrity: bool = True, | |
| axis: int = 1, | |
| ) -> tp.Index: | |
| """Concatenate indexes. | |
| The following index concatenation methods are supported: | |
| * 'append': append one index to another | |
| * 'union': build a union of indexes | |
| * 'pd_concat': convert indexes to Pandas Series or DataFrames and use `pd.concat` | |
| * 'factorize': factorize the concatenated index | |
| * 'factorize_each': factorize each index and concatenate while keeping numbers unique | |
| * 'reset': reset the concatenated index without applying `keys` | |
| * Callable: a custom callable that takes the indexes and returns the concatenated index | |
| Argument `index_concat_method` also accepts a tuple of two options: the second option gets applied | |
| if the first one fails. | |
| Use `keys` as an index with the same number of elements as there are indexes to add | |
| another index level on top of the concatenated indexes. | |
| If `verify_integrity` is True and `keys` is None, performs various checks depending on the axis.""" | |
| if len(indexes) == 1: | |
| indexes = indexes[0] | |
| indexes = list(indexes) | |
| if clean_index_kwargs is None: | |
| clean_index_kwargs = {} | |
| if axis == 0: | |
| factorized_name = "row_idx" | |
| elif axis == 1: | |
| factorized_name = "col_idx" | |
| else: | |
| factorized_name = "group_idx" | |
| if keys is None: | |
| all_ranges = True | |
| for index in indexes: | |
| if not checks.is_default_index(index): | |
| all_ranges = False | |
| break | |
| if all_ranges: | |
| return pd.RangeIndex(stop=sum(map(len, indexes))) | |
| if isinstance(index_concat_method, tuple): | |
| try: | |
| return concat_indexes( | |
| *indexes, | |
| index_concat_method=index_concat_method[0], | |
| keys=keys, | |
| clean_index_kwargs=clean_index_kwargs, | |
| verify_integrity=verify_integrity, | |
| axis=axis, | |
| ) | |
| except Exception as e: | |
| return concat_indexes( | |
| *indexes, | |
| index_concat_method=index_concat_method[1], | |
| keys=keys, | |
| clean_index_kwargs=clean_index_kwargs, | |
| verify_integrity=verify_integrity, | |
| axis=axis, | |
| ) | |
| if not isinstance(index_concat_method, str): | |
| new_index = index_concat_method(indexes) | |
| elif index_concat_method.lower() == "append": | |
| new_index = None | |
| for index in indexes: | |
| if new_index is None: | |
| new_index = index | |
| else: | |
| new_index = new_index.append(index) | |
| elif index_concat_method.lower() == "union": | |
| if keys is not None: | |
| raise ValueError("Cannot apply keys after concatenating indexes through union") | |
| new_index = None | |
| for index in indexes: | |
| if new_index is None: | |
| new_index = index | |
| else: | |
| new_index = new_index.union(index) | |
| elif index_concat_method.lower() == "pd_concat": | |
| new_index = None | |
| for index in indexes: | |
| if isinstance(index, pd.MultiIndex): | |
| index = index.to_frame().reset_index(drop=True) | |
| else: | |
| index = index.to_series().reset_index(drop=True) | |
| if new_index is None: | |
| new_index = index | |
| else: | |
| if isinstance(new_index, pd.DataFrame): | |
| if isinstance(index, pd.Series): | |
| index = index.to_frame() | |
| elif isinstance(index, pd.DataFrame): | |
| if isinstance(new_index, pd.Series): | |
| new_index = new_index.to_frame() | |
| new_index = pd.concat((new_index, index), ignore_index=True) | |
| if isinstance(new_index, pd.Series): | |
| new_index = pd.Index(new_index) | |
| else: | |
| new_index = pd.MultiIndex.from_frame(new_index) | |
| elif index_concat_method.lower() == "factorize": | |
| new_index = concat_indexes( | |
| *indexes, | |
| index_concat_method="append", | |
| clean_index_kwargs=clean_index_kwargs, | |
| verify_integrity=False, | |
| axis=axis, | |
| ) | |
| new_index = pd.Index(pd.factorize(new_index)[0], name=factorized_name) | |
| elif index_concat_method.lower() == "factorize_each": | |
| new_index = None | |
| for index in indexes: | |
| index = pd.Index(pd.factorize(index)[0], name=factorized_name) | |
| if new_index is None: | |
| new_index = index | |
| next_min = index.max() + 1 | |
| else: | |
| new_index = new_index.append(index + next_min) | |
| next_min = index.max() + 1 + next_min | |
| elif index_concat_method.lower() == "reset": | |
| return pd.RangeIndex(stop=sum(map(len, indexes))) | |
| else: | |
| if axis == 0: | |
| raise ValueError(f"Invalid index concatenation method: '{index_concat_method}'") | |
| elif axis == 1: | |
| raise ValueError(f"Invalid column concatenation method: '{index_concat_method}'") | |
| else: | |
| raise ValueError(f"Invalid group concatenation method: '{index_concat_method}'") | |
| if keys is not None: | |
| if isinstance(keys[0], pd.Index): | |
| keys = concat_indexes( | |
| *keys, | |
| index_concat_method="append", | |
| clean_index_kwargs=clean_index_kwargs, | |
| verify_integrity=False, | |
| axis=axis, | |
| ) | |
| new_index = stack_indexes((keys, new_index), **clean_index_kwargs) | |
| keys = None | |
| elif not isinstance(keys, pd.Index): | |
| keys = pd.Index(keys) | |
| if keys is not None: | |
| top_index = None | |
| for i, index in enumerate(indexes): | |
| repeated_index = repeat_index(keys[[i]], len(index)) | |
| if top_index is None: | |
| top_index = repeated_index | |
| else: | |
| top_index = top_index.append(repeated_index) | |
| new_index = stack_indexes((top_index, new_index), **clean_index_kwargs) | |
| if verify_integrity: | |
| if keys is None: | |
| if axis == 0: | |
| if not new_index.is_monotonic_increasing: | |
| raise ValueError("Concatenated index is not monotonically increasing") | |
| if "mixed" in new_index.inferred_type: | |
| raise ValueError("Concatenated index is mixed") | |
| if new_index.has_duplicates: | |
| raise ValueError("Concatenated index contains duplicates") | |
| if axis == 1: | |
| if new_index.has_duplicates: | |
| raise ValueError("Concatenated columns contain duplicates") | |
| if axis == 2: | |
| if new_index.has_duplicates: | |
| len_sum = 0 | |
| for index in indexes: | |
| if len_sum > 0: | |
| prev_index = new_index[:len_sum] | |
| this_index = new_index[len_sum : len_sum + len(index)] | |
| if len(prev_index.intersection(this_index)) > 0: | |
| raise ValueError("Concatenated groups contain duplicates") | |
| len_sum += len(index) | |
| return new_index | |
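| # Illustrative sketch (added note, not part of the original module): typical `concat_indexes` | |
| # calls. Assuming i1 = pd.Index([1, 2]) and i2 = pd.Index([3, 4]): | |
| # concat_indexes(i1, i2, axis=0) -> Index([1, 2, 3, 4])  # default "append" method | |
| # concat_indexes(i1, i2, axis=1, keys=pd.Index(["l", "r"])) stacks a top level so the | |
| # result becomes MultiIndex([("l", 1), ("l", 2), ("r", 3), ("r", 4)]). | |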
| def drop_levels( | |
| index: tp.Index, | |
| levels: tp.Union[ExceptLevel, tp.MaybeLevelSequence], | |
| strict: bool = True, | |
| ) -> tp.Index: | |
| """Drop `levels` in `index` by their name(s)/position(s). | |
| Provide `levels` as an instance of `ExceptLevel` to drop everything apart from the specified levels.""" | |
| if not isinstance(index, pd.MultiIndex): | |
| if strict: | |
| raise TypeError("Index must be a multi-index") | |
| return index | |
| if isinstance(levels, ExceptLevel): | |
| levels = levels.value | |
| except_mode = True | |
| else: | |
| except_mode = False | |
| levels_to_drop = set() | |
| if isinstance(levels, str) or not checks.is_sequence(levels): | |
| levels = [levels] | |
| for level in levels: | |
| if level in index.names: | |
| for level_pos in [i for i, x in enumerate(index.names) if x == level]: | |
| levels_to_drop.add(level_pos) | |
| elif checks.is_int(level): | |
| if level < 0: | |
| new_level = index.nlevels + level | |
| if new_level < 0: | |
| raise KeyError(f"Level at position {level} not found") | |
| level = new_level | |
| if 0 <= level < index.nlevels: | |
| levels_to_drop.add(level) | |
| else: | |
| raise KeyError(f"Level at position {level} not found") | |
| elif strict: | |
| raise KeyError(f"Level '{level}' not found") | |
| if except_mode: | |
| levels_to_drop = set(range(index.nlevels)).difference(levels_to_drop) | |
| if len(levels_to_drop) == 0: | |
| if strict: | |
| raise ValueError("No levels to drop") | |
| return index | |
| if len(levels_to_drop) >= index.nlevels: | |
| if strict: | |
| raise ValueError( | |
| f"Cannot remove {len(levels_to_drop)} levels from an index with {index.nlevels} levels: " | |
| "at least one level must be left" | |
| ) | |
| return index | |
| return index.droplevel(list(levels_to_drop)) | |
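| # Illustrative sketch (added note, not part of the original module): dropping levels by name. | |
| # Assuming mi = pd.MultiIndex.from_arrays([[1, 2], ["x", "y"]], names=["a", "b"]): | |
| # drop_levels(mi, "a")              -> Index(["x", "y"], name="b") | |
| # drop_levels(mi, ExceptLevel("a")) -> Index([1, 2], name="a")  # keep only "a" | |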
| def rename_levels(index: tp.Index, mapper: tp.MaybeMappingSequence[tp.Level], strict: bool = True) -> tp.Index: | |
| """Rename levels in `index` by `mapper`. | |
| Mapper can be a single or multiple levels to rename to, or a dictionary that maps | |
| old level names to new level names.""" | |
| if isinstance(index, pd.MultiIndex): | |
| nlevels = index.nlevels | |
| if isinstance(mapper, (int, str)): | |
| mapper = dict(zip(index.names, [mapper])) | |
| elif checks.is_complex_sequence(mapper): | |
| mapper = dict(zip(index.names, mapper)) | |
| else: | |
| nlevels = 1 | |
| if isinstance(mapper, (int, str)): | |
| mapper = dict(zip([index.name], [mapper])) | |
| elif checks.is_complex_sequence(mapper): | |
| mapper = dict(zip([index.name], mapper)) | |
| for k, v in mapper.items(): | |
| if k in index.names: | |
| if isinstance(index, pd.MultiIndex): | |
| index = index.rename(v, level=k) | |
| else: | |
| index = index.rename(v) | |
| elif checks.is_int(k): | |
| if k < 0: | |
| new_k = nlevels + k | |
| if new_k < 0: | |
| raise KeyError(f"Level at position {k} not found") | |
| k = new_k | |
| if 0 <= k < nlevels: | |
| if isinstance(index, pd.MultiIndex): | |
| index = index.rename(v, level=k) | |
| else: | |
| index = index.rename(v) | |
| else: | |
| raise KeyError(f"Level at position {k} not found") | |
| elif strict: | |
| raise KeyError(f"Level '{k}' not found") | |
| return index | |
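| # Illustrative sketch (added note, not part of the original module): renaming by mapping or | |
| # by position. Assuming mi = pd.MultiIndex.from_arrays([[1, 2], ["x", "y"]], names=["a", "b"]): | |
| # rename_levels(mi, {"a": "a2"}) renames level "a" to "a2", | |
| # rename_levels(mi, {1: "b2"})   renames the level at position 1 to "b2". | |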
| def select_levels( | |
| index: tp.Index, | |
| levels: tp.Union[ExceptLevel, tp.MaybeLevelSequence], | |
| strict: bool = True, | |
| ) -> tp.Index: | |
| """Build a new index by selecting one or multiple `levels` from `index`. | |
| Provide `levels` as an instance of `ExceptLevel` to select everything apart from the specified levels.""" | |
| was_multiindex = True | |
| if not isinstance(index, pd.MultiIndex): | |
| was_multiindex = False | |
| index = pd.MultiIndex.from_arrays([index]) | |
| if isinstance(levels, ExceptLevel): | |
| levels = levels.value | |
| except_mode = True | |
| else: | |
| except_mode = False | |
| levels_to_select = list() | |
| if isinstance(levels, str) or not checks.is_sequence(levels): | |
| levels = [levels] | |
| single_mode = True | |
| else: | |
| single_mode = False | |
| for level in levels: | |
| if level in index.names: | |
| for level_pos in [i for i, x in enumerate(index.names) if x == level]: | |
| if level_pos not in levels_to_select: | |
| levels_to_select.append(level_pos) | |
| elif checks.is_int(level): | |
| if level < 0: | |
| new_level = index.nlevels + level | |
| if new_level < 0: | |
| raise KeyError(f"Level at position {level} not found") | |
| level = new_level | |
| if 0 <= level < index.nlevels: | |
| if level not in levels_to_select: | |
| levels_to_select.append(level) | |
| else: | |
| raise KeyError(f"Level at position {level} not found") | |
| elif strict: | |
| raise KeyError(f"Level '{level}' not found") | |
| if except_mode: | |
| levels_to_select = list(set(range(index.nlevels)).difference(levels_to_select)) | |
| if len(levels_to_select) == 0: | |
| if strict: | |
| raise ValueError("No levels to select") | |
| if not was_multiindex: | |
| return index.get_level_values(0) | |
| return index | |
| if len(levels_to_select) == 1 and single_mode: | |
| return index.get_level_values(levels_to_select[0]) | |
| levels = [index.get_level_values(level) for level in levels_to_select] | |
| return pd.MultiIndex.from_arrays(levels) | |
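| # Illustrative sketch (added note, not part of the original module): selecting levels. | |
| # Assuming mi = pd.MultiIndex.from_arrays([[1, 2], ["x", "y"]], names=["a", "b"]): | |
| # select_levels(mi, "b")        -> Index(["x", "y"], name="b") | |
| # select_levels(mi, ["b", "a"]) -> MultiIndex([("x", 1), ("y", 2)], names=["b", "a"]) | |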
| def drop_redundant_levels(index: tp.Index) -> tp.Index: | |
| """Drop levels in `index` that either have a single unnamed value or a range from 0 to n.""" | |
| if not isinstance(index, pd.MultiIndex): | |
| return index | |
| levels_to_drop = [] | |
| for i in range(index.nlevels): | |
| if len(index.levels[i]) == 1 and index.levels[i].name is None: | |
| levels_to_drop.append(i) | |
| elif checks.is_default_index(index.get_level_values(i)): | |
| levels_to_drop.append(i) | |
| if len(levels_to_drop) < index.nlevels: | |
| return index.droplevel(levels_to_drop) | |
| return index | |
| def drop_duplicate_levels(index: tp.Index, keep: tp.Optional[str] = None) -> tp.Index: | |
| """Drop levels in `index` with the same name and values. | |
| Set `keep` to 'last' to keep last levels, otherwise 'first'. | |
| Set `keep` to None to use the default.""" | |
| from vectorbtpro._settings import settings | |
| broadcasting_cfg = settings["broadcasting"] | |
| if keep is None: | |
| keep = broadcasting_cfg["keep"] | |
| if not isinstance(index, pd.MultiIndex): | |
| return index | |
| checks.assert_in(keep.lower(), ["first", "last"]) | |
| levels_to_drop = set() | |
| level_values = [index.get_level_values(i) for i in range(index.nlevels)] | |
| for i in range(index.nlevels): | |
| level1 = level_values[i] | |
| for j in range(i + 1, index.nlevels): | |
| level2 = level_values[j] | |
| if level1.name is None or level2.name is None or level1.name == level2.name: | |
| if checks.is_index_equal(level1, level2, check_names=False): | |
| if level1.name is None and level2.name is not None: | |
| levels_to_drop.add(i) | |
| elif level1.name is not None and level2.name is None: | |
| levels_to_drop.add(j) | |
| else: | |
| if keep.lower() == "first": | |
| levels_to_drop.add(j) | |
| else: | |
| levels_to_drop.add(i) | |
| return index.droplevel(list(levels_to_drop)) | |
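| # Illustrative sketch (added note, not part of the original module): duplicate levels (same | |
| # name and values) are collapsed into one. Assuming | |
| # mi = pd.MultiIndex.from_arrays([[1, 2], [1, 2]], names=["a", "a"]): | |
| # drop_duplicate_levels(mi, keep="first") -> Index([1, 2], name="a") | |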
| @register_jitted(cache=True) | |
| def align_arr_indices_nb(a: tp.Array1d, b: tp.Array1d) -> tp.Array1d: | |
| """Return indices required to align `a` to `b`.""" | |
| idxs = np.empty(b.shape[0], dtype=int_) | |
| g = 0 | |
| for i in range(b.shape[0]): | |
| for j in range(a.shape[0]): | |
| if b[i] == a[j]: | |
| idxs[g] = j | |
| g += 1 | |
| break | |
| return idxs | |
| def align_index_to(index1: tp.Index, index2: tp.Index, jitted: tp.JittedOption = None) -> tp.IndexSlice: | |
| """Align `index1` to have the same shape as `index2` if they have any levels in common. | |
| Returns an index slice that performs the alignment.""" | |
| if not isinstance(index1, pd.MultiIndex): | |
| index1 = pd.MultiIndex.from_arrays([index1]) | |
| if not isinstance(index2, pd.MultiIndex): | |
| index2 = pd.MultiIndex.from_arrays([index2]) | |
| if checks.is_index_equal(index1, index2): | |
| return pd.IndexSlice[:] | |
| if len(index1) > len(index2): | |
| raise ValueError("Longer index cannot be aligned to shorter index") | |
| mapper = {} | |
| for i in range(index1.nlevels): | |
| name1 = index1.names[i] | |
| for j in range(index2.nlevels): | |
| name2 = index2.names[j] | |
| if name1 is None or name2 is None or name1 == name2: | |
| if set(index2.levels[j]).issubset(set(index1.levels[i])): | |
| if i in mapper: | |
| raise ValueError(f"There are multiple candidate levels with name {name1} in second index") | |
| mapper[i] = j | |
| continue | |
| if name1 == name2 and name1 is not None: | |
| raise ValueError(f"Level {name1} in second index contains values not in first index") | |
| if len(mapper) == 0: | |
| if len(index1) == len(index2): | |
| return pd.IndexSlice[:] | |
| raise ValueError("Cannot find common levels to align indexes") | |
| factorized = [] | |
| for k, v in mapper.items(): | |
| factorized.append( | |
| pd.factorize( | |
| pd.concat( | |
| ( | |
| index1.get_level_values(k).to_series(), | |
| index2.get_level_values(v).to_series(), | |
| ) | |
| ) | |
| )[0], | |
| ) | |
| stacked = np.transpose(np.stack(factorized)) | |
| indices1 = stacked[: len(index1)] | |
| indices2 = stacked[len(index1) :] | |
| if len(indices1) < len(indices2): | |
| if len(np.unique(indices1, axis=0)) != len(indices1): | |
| raise ValueError("Cannot align indexes") | |
| if len(index2) % len(index1) == 0: | |
| tile_times = len(index2) // len(index1) | |
| index1_tiled = np.tile(indices1, (tile_times, 1)) | |
| if np.array_equal(index1_tiled, indices2): | |
| return pd.IndexSlice[np.tile(np.arange(len(index1)), tile_times)] | |
| unique_indices = np.unique(stacked, axis=0, return_inverse=True)[1] | |
| unique1 = unique_indices[: len(index1)] | |
| unique2 = unique_indices[len(index1) :] | |
| if len(indices1) == len(indices2): | |
| if np.array_equal(unique1, unique2): | |
| return pd.IndexSlice[:] | |
| func = jit_reg.resolve_option(align_arr_indices_nb, jitted) | |
| return pd.IndexSlice[func(unique1, unique2)] | |
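| # Illustrative sketch (added note, not part of the original module): aligning a shorter index | |
| # to a longer one that shares the level "p". Assuming | |
| # i1 = pd.Index(["a", "b"], name="p") and | |
| # i2 = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=["p", "q"]), | |
| # align_index_to(i1, i2) resolves to positions [0, 0, 1, 1] of i1 (wrapped in pd.IndexSlice). | |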
| def align_indexes( | |
| *indexes: tp.MaybeTuple[tp.Index], | |
| return_new_index: bool = False, | |
| **kwargs, | |
| ) -> tp.Union[tp.Tuple[tp.IndexSlice, ...], tp.Tuple[tp.Tuple[tp.IndexSlice, ...], tp.Index]]: | |
| """Align multiple indexes to each other with `align_index_to`.""" | |
| if len(indexes) == 1: | |
| indexes = indexes[0] | |
| indexes = list(indexes) | |
| index_items = sorted([(i, indexes[i]) for i in range(len(indexes))], key=lambda x: len(x[1])) | |
| index_slices = [] | |
| for i in range(len(index_items)): | |
| index_slice = align_index_to(index_items[i][1], index_items[-1][1], **kwargs) | |
| index_slices.append((index_items[i][0], index_slice)) | |
| index_slices = list(map(lambda x: x[1], sorted(index_slices, key=lambda x: x[0]))) | |
| if return_new_index: | |
| new_index = stack_indexes( | |
| *[indexes[i][index_slices[i]] for i in range(len(indexes))], | |
| drop_duplicates=True, | |
| ) | |
| return tuple(index_slices), new_index | |
| return tuple(index_slices) | |
| @register_jitted(cache=True) | |
| def block_index_product_nb( | |
| block_group_map1: tp.GroupMap, | |
| block_group_map2: tp.GroupMap, | |
| factorized1: tp.Array1d, | |
| factorized2: tp.Array1d, | |
| ) -> tp.Tuple[tp.Array1d, tp.Array1d]: | |
| """Return indices required for building a block-wise Cartesian product of two factorized indexes.""" | |
| group_idxs1, group_lens1 = block_group_map1 | |
| group_idxs2, group_lens2 = block_group_map2 | |
| group_start_idxs1 = np.cumsum(group_lens1) - group_lens1 | |
| group_start_idxs2 = np.cumsum(group_lens2) - group_lens2 | |
| matched1 = np.empty(len(factorized1), dtype=np.bool_) | |
| matched2 = np.empty(len(factorized2), dtype=np.bool_) | |
| indices1 = np.empty(len(factorized1) * len(factorized2), dtype=int_) | |
| indices2 = np.empty(len(factorized1) * len(factorized2), dtype=int_) | |
| k1 = 0 | |
| k2 = 0 | |
| for g1 in range(len(group_lens1)): | |
| group_len1 = group_lens1[g1] | |
| group_start1 = group_start_idxs1[g1] | |
| for g2 in range(len(group_lens2)): | |
| group_len2 = group_lens2[g2] | |
| group_start2 = group_start_idxs2[g2] | |
| for c1 in range(group_len1): | |
| i = group_idxs1[group_start1 + c1] | |
| for c2 in range(group_len2): | |
| j = group_idxs2[group_start2 + c2] | |
| if factorized1[i] == factorized2[j]: | |
| matched1[i] = True | |
| matched2[j] = True | |
| indices1[k1] = i | |
| indices2[k2] = j | |
| k1 += 1 | |
| k2 += 1 | |
| if not np.all(matched1) or not np.all(matched2): | |
| raise ValueError("Cannot match some block level values") | |
| return indices1[:k1], indices2[:k2] | |
| def cross_index_with( | |
| index1: tp.Index, | |
| index2: tp.Index, | |
| return_new_index: bool = False, | |
| ) -> tp.Union[tp.Tuple[tp.IndexSlice, tp.IndexSlice], tp.Tuple[tp.Tuple[tp.IndexSlice, tp.IndexSlice], tp.Index]]: | |
| """Build a Cartesian product of one index with another while taking into account levels they have in common. | |
| Returns index slices that perform the alignment.""" | |
| from vectorbtpro.base.grouping.nb import get_group_map_nb | |
| index1_default = checks.is_default_index(index1, check_names=True) | |
| index2_default = checks.is_default_index(index2, check_names=True) | |
| if not isinstance(index1, pd.MultiIndex): | |
| index1 = pd.MultiIndex.from_arrays([index1]) | |
| if not isinstance(index2, pd.MultiIndex): | |
| index2 = pd.MultiIndex.from_arrays([index2]) | |
| if not index1_default and not index2_default and checks.is_index_equal(index1, index2): | |
| if return_new_index: | |
| new_index = stack_indexes(index1, index2, drop_duplicates=True) | |
| return (pd.IndexSlice[:], pd.IndexSlice[:]), new_index | |
| return pd.IndexSlice[:], pd.IndexSlice[:] | |
| levels1 = [] | |
| levels2 = [] | |
| for i in range(index1.nlevels): | |
| if checks.is_default_index(index1.get_level_values(i), check_names=True): | |
| continue | |
| for j in range(index2.nlevels): | |
| if checks.is_default_index(index2.get_level_values(j), check_names=True): | |
| continue | |
| name1 = index1.names[i] | |
| name2 = index2.names[j] | |
| if name1 == name2: | |
| if set(index2.levels[j]) == set(index1.levels[i]): | |
| if i in levels1 or j in levels2: | |
| raise ValueError(f"There are multiple candidate block levels with name {name1}") | |
| levels1.append(i) | |
| levels2.append(j) | |
| continue | |
| if name1 is not None: | |
| raise ValueError(f"Candidate block level {name1} in both indexes has different values") | |
| if len(levels1) == 0: | |
| # Regular index product | |
| indices1 = np.repeat(np.arange(len(index1)), len(index2)) | |
| indices2 = np.tile(np.arange(len(index2)), len(index1)) | |
| else: | |
| # Block index product | |
| index_levels1 = select_levels(index1, levels1) | |
| index_levels2 = select_levels(index2, levels2) | |
| block_levels1 = list(set(range(index1.nlevels)).difference(levels1)) | |
| block_levels2 = list(set(range(index2.nlevels)).difference(levels2)) | |
| if len(block_levels1) > 0: | |
| index_block_levels1 = select_levels(index1, block_levels1) | |
| else: | |
| index_block_levels1 = pd.Index(np.full(len(index1), 0)) | |
| if len(block_levels2) > 0: | |
| index_block_levels2 = select_levels(index2, block_levels2) | |
| else: | |
| index_block_levels2 = pd.Index(np.full(len(index2), 0)) | |
| factorized = pd.factorize(pd.concat((index_levels1.to_series(), index_levels2.to_series())))[0] | |
| factorized1 = factorized[: len(index_levels1)] | |
| factorized2 = factorized[len(index_levels1) :] | |
| block_factorized1, block_unique1 = pd.factorize(index_block_levels1) | |
| block_factorized2, block_unique2 = pd.factorize(index_block_levels2) | |
| block_group_map1 = get_group_map_nb(block_factorized1, len(block_unique1)) | |
| block_group_map2 = get_group_map_nb(block_factorized2, len(block_unique2)) | |
| indices1, indices2 = block_index_product_nb( | |
| block_group_map1, | |
| block_group_map2, | |
| factorized1, | |
| factorized2, | |
| ) | |
| if return_new_index: | |
| new_index = stack_indexes(index1[indices1], index2[indices2], drop_duplicates=True) | |
| return (pd.IndexSlice[indices1], pd.IndexSlice[indices2]), new_index | |
| return pd.IndexSlice[indices1], pd.IndexSlice[indices2] | |
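| # Illustrative sketch (added note, not part of the original module): with no levels in common, | |
| # `cross_index_with` falls back to a plain Cartesian product of positions. Assuming i1 has | |
| # 2 elements and i2 has 3 elements, the returned slices select | |
| # positions [0, 0, 0, 1, 1, 1] of i1 and [0, 1, 2, 0, 1, 2] of i2. | |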
| def cross_indexes( | |
| *indexes: tp.MaybeTuple[tp.Index], | |
| return_new_index: bool = False, | |
| ) -> tp.Union[tp.Tuple[tp.IndexSlice, ...], tp.Tuple[tp.Tuple[tp.IndexSlice, ...], tp.Index]]: | |
| """Cross multiple indexes with `cross_index_with`.""" | |
| if len(indexes) == 1: | |
| indexes = indexes[0] | |
| indexes = list(indexes) | |
| if len(indexes) == 2: | |
| return cross_index_with(indexes[0], indexes[1], return_new_index=return_new_index) | |
| index = None | |
| index_slices = [] | |
| for i in range(len(indexes) - 2, -1, -1): | |
| index1 = indexes[i] | |
| if i == len(indexes) - 2: | |
| index2 = indexes[i + 1] | |
| else: | |
| index2 = index | |
| (index_slice1, index_slice2), index = cross_index_with(index1, index2, return_new_index=True) | |
| if i == len(indexes) - 2: | |
| index_slices.append(index_slice2) | |
| else: | |
| for j in range(len(index_slices)): | |
| if isinstance(index_slices[j], slice): | |
| index_slices[j] = np.arange(len(index2))[index_slices[j]] | |
| index_slices[j] = index_slices[j][index_slice2] | |
| index_slices.append(index_slice1) | |
| if return_new_index: | |
| return tuple(index_slices[::-1]), index | |
| return tuple(index_slices[::-1]) | |
| OptionalLevelSequence = tp.Optional[tp.Sequence[tp.Union[None, tp.Level]]] | |
| def pick_levels( | |
| index: tp.Index, | |
| required_levels: OptionalLevelSequence = None, | |
| optional_levels: OptionalLevelSequence = None, | |
| ) -> tp.Tuple[tp.List[int], tp.List[int]]: | |
| """Pick optional and required levels and return their indices. | |
| Raises an exception if the index has fewer or more levels than expected.""" | |
| if required_levels is None: | |
| required_levels = [] | |
| if optional_levels is None: | |
| optional_levels = [] | |
| checks.assert_instance_of(index, pd.MultiIndex) | |
| n_opt_set = len(list(filter(lambda x: x is not None, optional_levels))) | |
| n_req_set = len(list(filter(lambda x: x is not None, required_levels))) | |
| n_levels_left = index.nlevels - n_opt_set | |
| if n_req_set < len(required_levels): | |
| if n_levels_left != len(required_levels): | |
| n_expected = len(required_levels) + n_opt_set | |
| raise ValueError(f"Expected {n_expected} levels, found {index.nlevels}") | |
| levels_left = list(range(index.nlevels)) | |
| _optional_levels = [] | |
| for level in optional_levels: | |
| level_pos = None | |
| if level is not None: | |
| checks.assert_instance_of(level, (int, str)) | |
| if isinstance(level, str): | |
| level_pos = index.names.index(level) | |
| else: | |
| level_pos = level | |
| if level_pos < 0: | |
| level_pos = index.nlevels + level_pos | |
| levels_left.remove(level_pos) | |
| _optional_levels.append(level_pos) | |
| _required_levels = [] | |
| for level in required_levels: | |
| level_pos = None | |
| if level is not None: | |
| checks.assert_instance_of(level, (int, str)) | |
| if isinstance(level, str): | |
| level_pos = index.names.index(level) | |
| else: | |
| level_pos = level | |
| if level_pos < 0: | |
| level_pos = index.nlevels + level_pos | |
| levels_left.remove(level_pos) | |
| _required_levels.append(level_pos) | |
| for i, level in enumerate(_required_levels): | |
| if level is None: | |
| _required_levels[i] = levels_left.pop(0) | |
| return _required_levels, _optional_levels | |
| def find_first_occurrence(index_value: tp.Any, index: tp.Index) -> int: | |
| """Return index of the first occurrence in `index`.""" | |
| loc = index.get_loc(index_value) | |
| if isinstance(loc, slice): | |
| return loc.start | |
| elif isinstance(loc, list): | |
| return loc[0] | |
| elif isinstance(loc, np.ndarray): | |
| return np.flatnonzero(loc)[0] | |
| return loc | |
| IndexApplierT = tp.TypeVar("IndexApplierT", bound="IndexApplier") | |
| class IndexApplier(Base): | |
| """Abstract class that can apply a function on an index.""" | |
| def apply_to_index(self: IndexApplierT, apply_func: tp.Callable, *args, **kwargs) -> IndexApplierT: | |
| """Apply function `apply_func` on the index of the instance and return a new instance.""" | |
| raise NotImplementedError | |
| def add_levels( | |
| self: IndexApplierT, | |
| *indexes: tp.Index, | |
| on_top: bool = True, | |
| drop_duplicates: tp.Optional[bool] = None, | |
| keep: tp.Optional[str] = None, | |
| drop_redundant: tp.Optional[bool] = None, | |
| **kwargs, | |
| ) -> IndexApplierT: | |
| """Append or prepend levels using `stack_indexes`. | |
| Set `on_top` to False to stack at bottom. | |
| See `IndexApplier.apply_to_index` for other keyword arguments.""" | |
| def _apply_func(index): | |
| if on_top: | |
| return stack_indexes( | |
| [*indexes, index], | |
| drop_duplicates=drop_duplicates, | |
| keep=keep, | |
| drop_redundant=drop_redundant, | |
| ) | |
| return stack_indexes( | |
| [index, *indexes], | |
| drop_duplicates=drop_duplicates, | |
| keep=keep, | |
| drop_redundant=drop_redundant, | |
| ) | |
| return self.apply_to_index(_apply_func, **kwargs) | |
| def drop_levels( | |
| self: IndexApplierT, | |
| levels: tp.Union[ExceptLevel, tp.MaybeLevelSequence], | |
| strict: bool = True, | |
| **kwargs, | |
| ) -> IndexApplierT: | |
| """Drop levels using `drop_levels`. | |
| See `IndexApplier.apply_to_index` for other keyword arguments.""" | |
| def _apply_func(index): | |
| return drop_levels(index, levels, strict=strict) | |
| return self.apply_to_index(_apply_func, **kwargs) | |
| def rename_levels( | |
| self: IndexApplierT, | |
| mapper: tp.MaybeMappingSequence[tp.Level], | |
| strict: bool = True, | |
| **kwargs, | |
| ) -> IndexApplierT: | |
| """Rename levels using `rename_levels`. | |
| See `IndexApplier.apply_to_index` for other keyword arguments.""" | |
| def _apply_func(index): | |
| return rename_levels(index, mapper, strict=strict) | |
| return self.apply_to_index(_apply_func, **kwargs) | |
| def select_levels( | |
| self: IndexApplierT, | |
| level_names: tp.Union[ExceptLevel, tp.MaybeLevelSequence], | |
| strict: bool = True, | |
| **kwargs, | |
| ) -> IndexApplierT: | |
| """Select levels using `select_levels`. | |
| See `IndexApplier.apply_to_index` for other keyword arguments.""" | |
| def _apply_func(index): | |
| return select_levels(index, level_names, strict=strict) | |
| return self.apply_to_index(_apply_func, **kwargs) | |
| def drop_redundant_levels(self: IndexApplierT, **kwargs) -> IndexApplierT: | |
| """Drop any redundant levels using `drop_redundant_levels`. | |
| See `IndexApplier.apply_to_index` for other keyword arguments.""" | |
| def _apply_func(index): | |
| return drop_redundant_levels(index) | |
| return self.apply_to_index(_apply_func, **kwargs) | |
| def drop_duplicate_levels(self: IndexApplierT, keep: tp.Optional[str] = None, **kwargs) -> IndexApplierT: | |
| """Drop any duplicate levels using `drop_duplicate_levels`. | |
| See `IndexApplier.apply_to_index` for other keyword arguments.""" | |
| def _apply_func(index): | |
| return drop_duplicate_levels(index, keep=keep) | |
| return self.apply_to_index(_apply_func, **kwargs) | |
| </file> | |
| <file path="base/indexing.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Classes and functions for indexing.""" | |
| import functools | |
| from datetime import time | |
| from functools import partial | |
| import numpy as np | |
| import pandas as pd | |
| from pandas.tseries.offsets import BaseOffset | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro._dtypes import * | |
| from vectorbtpro.registries.jit_registry import jit_reg | |
| from vectorbtpro.utils import checks, datetime_ as dt, datetime_nb as dt_nb | |
| from vectorbtpro.utils.attr_ import DefineMixin, define, MISSING | |
| from vectorbtpro.utils.base import Base | |
| from vectorbtpro.utils.config import hdict, merge_dicts | |
| from vectorbtpro.utils.mapping import to_field_mapping | |
| from vectorbtpro.utils.pickling import pdict | |
| from vectorbtpro.utils.selection import PosSel, LabelSel | |
| from vectorbtpro.utils.template import CustomTemplate | |
| __all__ = [ | |
| "PandasIndexer", | |
| "ExtPandasIndexer", | |
| "hslice", | |
| "get_index_points", | |
| "get_index_ranges", | |
| "get_idxs", | |
| "index_dict", | |
| "IdxSetter", | |
| "IdxSetterFactory", | |
| "IdxDict", | |
| "IdxSeries", | |
| "IdxFrame", | |
| "IdxRecords", | |
| "posidx", | |
| "maskidx", | |
| "lbidx", | |
| "dtidx", | |
| "dtcidx", | |
| "pointidx", | |
| "rangeidx", | |
| "autoidx", | |
| "rowidx", | |
| "colidx", | |
| "idx", | |
| ] | |
| __pdoc__ = {} | |
| class IndexingError(Exception): | |
| """Exception raised when an indexing error has occurred.""" | |
| IndexingBaseT = tp.TypeVar("IndexingBaseT", bound="IndexingBase") | |
| class IndexingBase(Base): | |
| """Class that supports indexing through `IndexingBase.indexing_func`.""" | |
| def indexing_func(self: IndexingBaseT, pd_indexing_func: tp.Callable, **kwargs) -> IndexingBaseT: | |
| """Apply `pd_indexing_func` on all pandas objects in question and return a new instance of the class. | |
| Should be overridden.""" | |
| raise NotImplementedError | |
| def indexing_setter_func(self, pd_indexing_setter_func: tp.Callable, **kwargs) -> None: | |
| """Apply `pd_indexing_setter_func` on all pandas objects in question. | |
| Should be overridden.""" | |
| raise NotImplementedError | |
| class LocBase(Base): | |
| """Class that implements location-based indexing.""" | |
| def __init__( | |
| self, | |
| indexing_func: tp.Callable, | |
| indexing_setter_func: tp.Optional[tp.Callable] = None, | |
| **kwargs, | |
| ) -> None: | |
| self._indexing_func = indexing_func | |
| self._indexing_setter_func = indexing_setter_func | |
| self._indexing_kwargs = kwargs | |
| @property | |
| def indexing_func(self) -> tp.Callable: | |
| """Indexing function.""" | |
| return self._indexing_func | |
| @property | |
| def indexing_setter_func(self) -> tp.Optional[tp.Callable]: | |
| """Indexing setter function.""" | |
| return self._indexing_setter_func | |
| @property | |
| def indexing_kwargs(self) -> dict: | |
| """Keyword arguments passed to `LocBase.indexing_func`.""" | |
| return self._indexing_kwargs | |
| def __getitem__(self, key: tp.Any) -> tp.Any: | |
| raise NotImplementedError | |
| def __setitem__(self, key: tp.Any, value: tp.Any) -> None: | |
| raise NotImplementedError | |
| def __iter__(self): | |
| raise TypeError(f"'{type(self).__name__}' object is not iterable") | |
| class pdLoc(LocBase): | |
| """Forwards a Pandas-like indexing operation to each Series/DataFrame and returns a new class instance.""" | |
| @classmethod | |
| def pd_indexing_func(cls, obj: tp.SeriesFrame, key: tp.Any) -> tp.MaybeSeriesFrame: | |
| """Pandas-like indexing operation.""" | |
| raise NotImplementedError | |
| @classmethod | |
| def pd_indexing_setter_func(cls, obj: tp.SeriesFrame, key: tp.Any, value: tp.Any) -> None: | |
| """Pandas-like indexing setter operation.""" | |
| raise NotImplementedError | |
| def __getitem__(self, key: tp.Any) -> tp.Any: | |
| return self.indexing_func(partial(self.pd_indexing_func, key=key), **self.indexing_kwargs) | |
| def __setitem__(self, key: tp.Any, value: tp.Any) -> None: | |
| self.indexing_setter_func(partial(self.pd_indexing_setter_func, key=key, value=value), **self.indexing_kwargs) | |
| class iLoc(pdLoc): | |
| """Forwards `pd.Series.iloc`/`pd.DataFrame.iloc` operation to each | |
| Series/DataFrame and returns a new class instance.""" | |
| @classmethod | |
| def pd_indexing_func(cls, obj: tp.SeriesFrame, key: tp.Any) -> tp.MaybeSeriesFrame: | |
| return obj.iloc.__getitem__(key) | |
| @classmethod | |
| def pd_indexing_setter_func(cls, obj: tp.SeriesFrame, key: tp.Any, value: tp.Any) -> None: | |
| obj.iloc.__setitem__(key, value) | |
| class Loc(pdLoc): | |
| """Forwards `pd.Series.loc`/`pd.DataFrame.loc` operation to each | |
| Series/DataFrame and returns a new class instance.""" | |
| @classmethod | |
| def pd_indexing_func(cls, obj: tp.SeriesFrame, key: tp.Any) -> tp.MaybeSeriesFrame: | |
| return obj.loc.__getitem__(key) | |
| @classmethod | |
| def pd_indexing_setter_func(cls, obj: tp.SeriesFrame, key: tp.Any, value: tp.Any) -> None: | |
| obj.loc.__setitem__(key, value) | |
| PandasIndexerT = tp.TypeVar("PandasIndexerT", bound="PandasIndexer") | |
| class PandasIndexer(IndexingBase): | |
| """Implements indexing using `iloc`, `loc`, `xs` and `__getitem__`. | |
| Usage: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> from vectorbtpro.base.indexing import PandasIndexer | |
| >>> class C(PandasIndexer): | |
| ... def __init__(self, df1, df2): | |
| ... self.df1 = df1 | |
| ... self.df2 = df2 | |
| ... super().__init__() | |
| ... | |
| ... def indexing_func(self, pd_indexing_func): | |
| ... return type(self)( | |
| ... pd_indexing_func(self.df1), | |
| ... pd_indexing_func(self.df2) | |
| ... ) | |
| >>> df1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) | |
| >>> df2 = pd.DataFrame({'a': [5, 6], 'b': [7, 8]}) | |
| >>> c = C(df1, df2) | |
| >>> c.iloc[:, 0] | |
| <__main__.C object at 0x1a1cacbbe0> | |
| >>> c.iloc[:, 0].df1 | |
| 0 1 | |
| 1 2 | |
| Name: a, dtype: int64 | |
| >>> c.iloc[:, 0].df2 | |
| 0 5 | |
| 1 6 | |
| Name: a, dtype: int64 | |
| ``` | |
| """ | |
| def __init__(self, **kwargs) -> None: | |
| self._iloc = iLoc(self.indexing_func, indexing_setter_func=self.indexing_setter_func, **kwargs) | |
| self._loc = Loc(self.indexing_func, indexing_setter_func=self.indexing_setter_func, **kwargs) | |
| self._indexing_kwargs = kwargs | |
| @property | |
| def indexing_kwargs(self) -> dict: | |
| """Indexing keyword arguments.""" | |
| return self._indexing_kwargs | |
| @property | |
| def iloc(self) -> iLoc: | |
| """Purely integer-location based indexing for selection by position.""" | |
| return self._iloc | |
| iloc.__doc__ = iLoc.__doc__ | |
| @property | |
| def loc(self) -> Loc: | |
| """Purely label-location based indexer for selection by label.""" | |
| return self._loc | |
| loc.__doc__ = Loc.__doc__ | |
| def xs(self: PandasIndexerT, *args, **kwargs) -> PandasIndexerT: | |
| """Forwards `pd.Series.xs`/`pd.DataFrame.xs` | |
| operation to each Series/DataFrame and returns a new class instance.""" | |
| return self.indexing_func(lambda x: x.xs(*args, **kwargs), **self.indexing_kwargs) | |
| def __getitem__(self: PandasIndexerT, key: tp.Any) -> PandasIndexerT: | |
| def __getitem__func(x, _key=key): | |
| return x.__getitem__(_key) | |
| return self.indexing_func(__getitem__func, **self.indexing_kwargs) | |
| def __setitem__(self, key: tp.Any, value: tp.Any) -> None: | |
| def __setitem__func(x, _key=key, _value=value): | |
| return x.__setitem__(_key, _value) | |
| self.indexing_setter_func(__setitem__func, **self.indexing_kwargs) | |
| def __iter__(self): | |
| raise TypeError(f"'{type(self).__name__}' object is not iterable") | |
| class xLoc(iLoc): | |
| """Subclass of `iLoc` that transforms an `Idxr`-based operation with | |
| `get_idxs` to an `iLoc` operation.""" | |
| @classmethod | |
| def pd_indexing_func(cls, obj: tp.SeriesFrame, key: tp.Any) -> tp.MaybeSeriesFrame: | |
| from vectorbtpro.base.indexes import get_index | |
| if isinstance(key, tuple): | |
| key = Idxr(*key) | |
| index = get_index(obj, 0) | |
| columns = get_index(obj, 1) | |
| freq = dt.infer_index_freq(index) | |
| row_idxs, col_idxs = get_idxs(key, index=index, columns=columns, freq=freq) | |
| if isinstance(row_idxs, np.ndarray) and row_idxs.ndim == 2: | |
| row_idxs = normalize_idxs(row_idxs, target_len=len(index)) | |
| if isinstance(col_idxs, np.ndarray) and col_idxs.ndim == 2: | |
| col_idxs = normalize_idxs(col_idxs, target_len=len(columns)) | |
| if isinstance(obj, pd.Series): | |
| if not isinstance(col_idxs, (slice, hslice)) or ( | |
| col_idxs.start is not None or col_idxs.stop is not None or col_idxs.step is not None | |
| ): | |
| raise IndexingError("Too many indexers") | |
| return obj.iloc.__getitem__(row_idxs) | |
| return obj.iloc.__getitem__((row_idxs, col_idxs)) | |
| @classmethod | |
| def pd_indexing_setter_func(cls, obj: tp.SeriesFrame, key: tp.Any, value: tp.Any) -> None: | |
| IdxSetter([(key, value)]).set_pd(obj) | |
| class ExtPandasIndexer(PandasIndexer): | |
| """Extension of `PandasIndexer` that also implements indexing using `xLoc`.""" | |
| def __init__(self, **kwargs) -> None: | |
| self._xloc = xLoc(self.indexing_func, indexing_setter_func=self.indexing_setter_func, **kwargs) | |
| PandasIndexer.__init__(self, **kwargs) | |
| @property | |
| def xloc(self) -> xLoc: | |
| """`Idxr`-based indexing.""" | |
| return self._xloc | |
| xloc.__doc__ = xLoc.__doc__ | |
| class ParamLoc(LocBase): | |
| """Access a group of columns by parameter using `pd.Series.loc`. | |
| Uses `mapper` to establish a link between columns and parameter values.""" | |
| @classmethod | |
| def encode_key(cls, key: tp.Any): | |
| """Encode key.""" | |
| if isinstance(key, tuple): | |
| return str(tuple(map(lambda k: k.item() if isinstance(k, np.generic) else k, key))) | |
| key_str = str(key) | |
| return str(key.item()) if isinstance(key, np.generic) else key_str | |
| def __init__( | |
| self, | |
| mapper: tp.Series, | |
| indexing_func: tp.Callable, | |
| indexing_setter_func: tp.Optional[tp.Callable] = None, | |
| level_name: tp.Level = None, | |
| **kwargs, | |
| ) -> None: | |
| checks.assert_instance_of(mapper, pd.Series) | |
| if mapper.dtype == "O": | |
| if isinstance(mapper.iloc[0], tuple): | |
| mapper = mapper.apply(self.encode_key) | |
| else: | |
| mapper = mapper.astype(str) | |
| self._mapper = mapper | |
| self._level_name = level_name | |
| LocBase.__init__(self, indexing_func, indexing_setter_func=indexing_setter_func, **kwargs) | |
| @property | |
| def mapper(self) -> tp.Series: | |
| """Mapper.""" | |
| return self._mapper | |
| @property | |
| def level_name(self) -> tp.Level: | |
| """Level name.""" | |
| return self._level_name | |
| def get_idxs(self, key: tp.Any) -> tp.Array1d: | |
| """Get array of indices affected by this key.""" | |
| if self.mapper.dtype == "O": | |
| if isinstance(key, (slice, hslice)): | |
| start = self.encode_key(key.start) if key.start is not None else None | |
| stop = self.encode_key(key.stop) if key.stop is not None else None | |
| key = slice(start, stop, key.step) | |
| elif isinstance(key, (list, np.ndarray)): | |
| key = list(map(self.encode_key, key)) | |
| else: | |
| key = self.encode_key(key) | |
| mapper = pd.Series(np.arange(len(self.mapper.index)), index=self.mapper.values) | |
| idxs = mapper.loc.__getitem__(key) | |
| if isinstance(idxs, pd.Series): | |
| idxs = idxs.values | |
| return idxs | |
| def __getitem__(self, key: tp.Any) -> tp.Any: | |
| idxs = self.get_idxs(key) | |
| is_multiple = isinstance(key, (slice, hslice, list, np.ndarray)) | |
| def pd_indexing_func(obj: tp.SeriesFrame) -> tp.MaybeSeriesFrame: | |
| from vectorbtpro.base.indexes import drop_levels | |
| new_obj = obj.iloc[:, idxs] | |
| if not is_multiple: | |
| if self.level_name is not None: | |
| if checks.is_frame(new_obj): | |
| if isinstance(new_obj.columns, pd.MultiIndex): | |
| new_obj.columns = drop_levels(new_obj.columns, self.level_name) | |
| return new_obj | |
| return self.indexing_func(pd_indexing_func, **self.indexing_kwargs) | |
| def __setitem__(self, key: tp.Any, value: tp.Any) -> None: | |
| idxs = self.get_idxs(key) | |
| def pd_indexing_setter_func(obj: tp.SeriesFrame) -> None: | |
| obj.iloc[:, idxs] = value | |
| return self.indexing_setter_func(pd_indexing_setter_func, **self.indexing_kwargs) | |
| def indexing_on_mapper( | |
| mapper: tp.Series, | |
| ref_obj: tp.SeriesFrame, | |
| pd_indexing_func: tp.Callable, | |
| ) -> tp.Optional[tp.Series]: | |
| """Broadcast `mapper` Series to `ref_obj` and perform pandas indexing using `pd_indexing_func`.""" | |
| from vectorbtpro.base.reshaping import broadcast_to | |
| checks.assert_instance_of(mapper, pd.Series) | |
| checks.assert_instance_of(ref_obj, (pd.Series, pd.DataFrame)) | |
| if isinstance(ref_obj, pd.Series): | |
| range_mapper = broadcast_to(0, ref_obj) | |
| else: | |
| range_mapper = broadcast_to(np.arange(len(mapper.index))[None], ref_obj) | |
| loced_range_mapper = pd_indexing_func(range_mapper) | |
| new_mapper = mapper.iloc[loced_range_mapper.values[0]] | |
| if checks.is_frame(loced_range_mapper): | |
| return pd.Series(new_mapper.values, index=loced_range_mapper.columns, name=mapper.name) | |
| elif checks.is_series(loced_range_mapper): | |
| return pd.Series([new_mapper], index=[loced_range_mapper.name], name=mapper.name) | |
| return None | |
| def build_param_indexer( | |
| param_names: tp.Sequence[str], | |
| class_name: str = "ParamIndexer", | |
| module_name: tp.Optional[str] = None, | |
| ) -> tp.Type[IndexingBase]: | |
| """A factory to create a class with parameter indexing. | |
| Parameter indexer enables accessing a group of rows and columns by a parameter array (similar to `loc`). | |
| This way, one can query index/columns by another Series called a parameter mapper, which is just a | |
| `pd.Series` that maps columns (its index) to params (its values). | |
| Parameter indexing is important, since querying by column/index labels alone is not always the best option. | |
| For example, `pandas` doesn't let you query by list at a specific index/column level. | |
| Args: | |
| param_names (list of str): Names of the parameters. | |
| class_name (str): Name of the generated class. | |
| module_name (str): Name of the module to which the class should be bound. | |
| Usage: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> from vectorbtpro.base.indexing import build_param_indexer, indexing_on_mapper | |
| >>> MyParamIndexer = build_param_indexer(['my_param']) | |
| >>> class C(MyParamIndexer): | |
| ... def __init__(self, df, param_mapper): | |
| ... self.df = df | |
| ... self._my_param_mapper = param_mapper | |
| ... super().__init__([param_mapper]) | |
| ... | |
| ... def indexing_func(self, pd_indexing_func): | |
| ... return type(self)( | |
| ... pd_indexing_func(self.df), | |
| ... indexing_on_mapper(self._my_param_mapper, self.df, pd_indexing_func) | |
| ... ) | |
| >>> df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) | |
| >>> param_mapper = pd.Series(['First', 'Second'], index=['a', 'b']) | |
| >>> c = C(df, param_mapper) | |
| >>> c.my_param_loc['First'].df | |
| 0 1 | |
| 1 2 | |
| Name: a, dtype: int64 | |
| >>> c.my_param_loc['Second'].df | |
| 0 3 | |
| 1 4 | |
| Name: b, dtype: int64 | |
| >>> c.my_param_loc[['First', 'First', 'Second', 'Second']].df | |
| a b | |
| 0 1 1 3 3 | |
| 1 2 2 4 4 | |
| ``` | |
| """ | |
| class ParamIndexer(IndexingBase): | |
| """Class with parameter indexing.""" | |
| def __init__( | |
| self, | |
| param_mappers: tp.Sequence[tp.Series], | |
| level_names: tp.Optional[tp.LevelSequence] = None, | |
| **kwargs, | |
| ) -> None: | |
| checks.assert_len_equal(param_names, param_mappers) | |
| for i, param_name in enumerate(param_names): | |
| level_name = level_names[i] if level_names is not None else None | |
| _param_loc = ParamLoc(param_mappers[i], self.indexing_func, level_name=level_name, **kwargs) | |
| setattr(self, f"_{param_name}_loc", _param_loc) | |
| for i, param_name in enumerate(param_names): | |
| def param_loc(self, _param_name=param_name) -> ParamLoc: | |
| return getattr(self, f"_{_param_name}_loc") | |
| param_loc.__doc__ = f"""Access a group of columns by parameter `{param_name}` using `pd.Series.loc`. | |
| Forwards this operation to each Series/DataFrame and returns a new class instance. | |
| """ | |
| setattr(ParamIndexer, param_name + "_loc", property(param_loc)) | |
| ParamIndexer.__name__ = class_name | |
| ParamIndexer.__qualname__ = ParamIndexer.__name__ | |
| if module_name is not None: | |
| ParamIndexer.__module__ = module_name | |
| return ParamIndexer | |
| hsliceT = tp.TypeVar("hsliceT", bound="hslice") | |
| @define | |
| class hslice(DefineMixin): | |
| """Hashable slice.""" | |
| start: object = define.field() | |
| """Start.""" | |
| stop: object = define.field() | |
| """Stop.""" | |
| step: object = define.field() | |
| """Step.""" | |
| def __init__(self, start: object = MISSING, stop: object = MISSING, step: object = MISSING) -> None: | |
| if start is not MISSING and stop is MISSING and step is MISSING: | |
| stop = start | |
| start, step = None, None | |
| else: | |
| if start is MISSING: | |
| start = None | |
| if stop is MISSING: | |
| stop = None | |
| if step is MISSING: | |
| step = None | |
| DefineMixin.__init__(self, start=start, stop=stop, step=step) | |
| @classmethod | |
| def from_slice(cls: tp.Type[hsliceT], slice_: slice) -> hsliceT: | |
| """Construct from a slice.""" | |
| return cls(slice_.start, slice_.stop, slice_.step) | |
| def to_slice(self) -> slice: | |
| """Convert to a slice.""" | |
| return slice(self.start, self.stop, self.step) | |
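| # Illustrative sketch (added note, not part of the original module): `hslice` mirrors the | |
| # built-in slice but is hashable, so it can be used in dict keys and attrs-based classes. | |
| # hslice(1, 5).to_slice() == slice(1, 5, None), and hslice(5) behaves like slice(5), | |
| # i.e. only the stop bound is set. | |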
| class IdxrBase(Base): | |
| """Abstract class for resolving indices.""" | |
| def get(self, *args, **kwargs) -> tp.Any: | |
| """Get indices.""" | |
| raise NotImplementedError | |
| @classmethod | |
| def slice_indexer( | |
| cls, | |
| index: tp.Index, | |
| slice_: tp.Slice, | |
| closed_start: bool = True, | |
| closed_end: bool = False, | |
| ) -> slice: | |
| """Compute the slice indexer for input labels and step.""" | |
| start = slice_.start | |
| end = slice_.stop | |
| if start is not None: | |
| left_start = index.get_slice_bound(start, side="left") | |
| right_start = index.get_slice_bound(start, side="right") | |
| if left_start == right_start or not closed_start: | |
| start = right_start | |
| else: | |
| start = left_start | |
| if end is not None: | |
| left_end = index.get_slice_bound(end, side="left") | |
| right_end = index.get_slice_bound(end, side="right") | |
| if left_end == right_end or closed_end: | |
| end = right_end | |
| else: | |
| end = left_end | |
| return slice(start, end, slice_.step) | |
| def check_idxs(self, idxs: tp.MaybeIndexArray, check_minus_one: bool = False) -> None: | |
| """Check indices after resolving them.""" | |
| if isinstance(idxs, slice): | |
| if idxs.start is not None and not checks.is_int(idxs.start): | |
| raise TypeError("Start of a returned index slice must be an integer or None") | |
| if idxs.stop is not None and not checks.is_int(idxs.stop): | |
| raise TypeError("Stop of a returned index slice must be an integer or None") | |
| if idxs.step is not None and not checks.is_int(idxs.step): | |
| raise TypeError("Step of a returned index slice must be an integer or None") | |
| if check_minus_one and idxs.start == -1: | |
| raise ValueError("Range start index couldn't be matched") | |
| elif check_minus_one and idxs.stop == -1: | |
| raise ValueError("Range end index couldn't be matched") | |
| elif checks.is_int(idxs): | |
| if check_minus_one and idxs == -1: | |
| raise ValueError("Index couldn't be matched") | |
| elif checks.is_sequence(idxs) and not np.isscalar(idxs): | |
| if len(idxs) == 0: | |
| raise ValueError("No indices could be matched") | |
| if not isinstance(idxs, np.ndarray): | |
| raise ValueError(f"Indices must be a NumPy array, not {type(idxs)}") | |
| if not np.issubdtype(idxs.dtype, np.integer) or np.issubdtype(idxs.dtype, np.bool_): | |
| raise ValueError(f"Indices must be of integer data type, not {idxs.dtype}") | |
| if check_minus_one and -1 in idxs: | |
| raise ValueError("Some indices couldn't be matched") | |
| if idxs.ndim not in (1, 2): | |
| raise ValueError("Indices array must have either 1 or 2 dimensions") | |
| if idxs.ndim == 2 and idxs.shape[1] != 2: | |
| raise ValueError("Indices array provided as ranges must have exactly two columns") | |
| else: | |
| raise TypeError( | |
| f"Indices must be an integer, a slice, a NumPy array, or a tuple of two NumPy arrays, not {type(idxs)}" | |
| ) | |
| def normalize_idxs(idxs: tp.MaybeIndexArray, target_len: int) -> tp.Array1d: | |
| """Normalize indexes into a 1-dim integer array.""" | |
| if isinstance(idxs, hslice): | |
| idxs = idxs.to_slice() | |
| if isinstance(idxs, slice): | |
| idxs = np.arange(target_len)[idxs] | |
| if checks.is_int(idxs): | |
| idxs = np.array([idxs]) | |
| if idxs.ndim == 2: | |
| from vectorbtpro.base.merging import concat_arrays | |
| idxs = concat_arrays(tuple(map(lambda x: np.arange(x[0], x[1]), idxs))) | |
| if (idxs < 0).any(): | |
| idxs = np.where(idxs >= 0, idxs, target_len + idxs) | |
| return idxs | |
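| # Illustrative sketch (added note, not part of the original module): how various inputs | |
| # normalize, assuming target_len=5: | |
| # normalize_idxs(slice(None, 3), 5)             -> array([0, 1, 2]) | |
| # normalize_idxs(-1, 5)                         -> array([4]) | |
| # normalize_idxs(np.array([[0, 2], [3, 5]]), 5) -> array([0, 1, 3, 4])  # 2-dim ranges | |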
| class UniIdxr(IdxrBase): | |
| """Abstract class for resolving indices based on a single index.""" | |
| def get( | |
| self, | |
| index: tp.Optional[tp.Index] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| ) -> tp.MaybeIndexArray: | |
| raise NotImplementedError | |
| def __invert__(self): | |
| def _op_func(x, index=None, freq=None): | |
| if index is None: | |
| raise ValueError("Index is required") | |
| x = normalize_idxs(x, len(index)) | |
| idxs = np.setdiff1d(np.arange(len(index)), x) | |
| self.check_idxs(idxs) | |
| return idxs | |
| return UniIdxrOp(_op_func, self) | |
| def __and__(self, other): | |
| def _op_func(x, y, index=None, freq=None): | |
| if index is None: | |
| raise ValueError("Index is required") | |
| x = normalize_idxs(x, len(index)) | |
| y = normalize_idxs(y, len(index)) | |
| idxs = np.intersect1d(x, y) | |
| self.check_idxs(idxs) | |
| return idxs | |
| return UniIdxrOp(_op_func, self, other) | |
| def __or__(self, other): | |
| def _op_func(x, y, index=None, freq=None): | |
| if index is None: | |
| raise ValueError("Index is required") | |
| x = normalize_idxs(x, len(index)) | |
| y = normalize_idxs(y, len(index)) | |
| idxs = np.union1d(x, y) | |
| self.check_idxs(idxs) | |
| return idxs | |
| return UniIdxrOp(_op_func, self, other) | |
| def __sub__(self, other): | |
| def _op_func(x, y, index=None, freq=None): | |
| if index is None: | |
| raise ValueError("Index is required") | |
| x = normalize_idxs(x, len(index)) | |
| y = normalize_idxs(y, len(index)) | |
| idxs = np.setdiff1d(x, y) | |
| self.check_idxs(idxs) | |
| return idxs | |
| return UniIdxrOp(_op_func, self, other) | |
| def __xor__(self, other): | |
| def _op_func(x, y, index=None, freq=None): | |
| if index is None: | |
| raise ValueError("Index is required") | |
| x = normalize_idxs(x, len(index)) | |
| y = normalize_idxs(y, len(index)) | |
| idxs = np.setxor1d(x, y) | |
| self.check_idxs(idxs) | |
| return idxs | |
| return UniIdxrOp(_op_func, self, other) | |
| def __lshift__(self, other): | |
| def _op_func(x, y, index=None, freq=None): | |
| if not checks.is_int(y): | |
| raise TypeError("Second operand in __lshift__ must be an integer") | |
| if index is None: | |
| raise ValueError("Index is required") | |
| x = normalize_idxs(x, len(index)) | |
| shifted = x - y | |
| idxs = shifted[shifted >= 0] | |
| self.check_idxs(idxs) | |
| return idxs | |
| return UniIdxrOp(_op_func, self, other) | |
| def __rshift__(self, other): | |
| def _op_func(x, y, index=None, freq=None): | |
| if not checks.is_int(y): | |
| raise TypeError("Second operand in __rshift__ must be an integer") | |
| if index is None: | |
| raise ValueError("Index is required") | |
| x = normalize_idxs(x, len(index)) | |
| shifted = x + y | |
| idxs = shifted[shifted >= 0] | |
| self.check_idxs(idxs) | |
| return idxs | |
| return UniIdxrOp(_op_func, self, other) | |
| @define | |
| class UniIdxrOp(UniIdxr, DefineMixin): | |
| """Class for applying an operation to one or more indexers. | |
| Produces a single set of indices.""" | |
| op_func: tp.Callable = define.field() | |
| """Operation function that takes the indices of each indexer (as `*args`), `index` (keyword argument), | |
| and `freq` (keyword argument), and returns new indices.""" | |
| idxrs: tp.Tuple[object, ...] = define.field() | |
| """A tuple of one or more indexers.""" | |
| def __init__(self, op_func: tp.Callable, *idxrs) -> None: | |
| if len(idxrs) == 1 and checks.is_iterable(idxrs[0]): | |
| idxrs = idxrs[0] | |
| DefineMixin.__init__(self, op_func=op_func, idxrs=idxrs) | |
| def get( | |
| self, | |
| index: tp.Optional[tp.Index] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| ) -> tp.MaybeIndexArray: | |
| idxr_indices = [] | |
| for idxr in self.idxrs: | |
| if isinstance(idxr, IdxrBase): | |
| checks.assert_instance_of(idxr, UniIdxr) | |
| idxr_indices.append(idxr.get(index=index, freq=freq)) | |
| else: | |
| idxr_indices.append(idxr) | |
| return self.op_func(*idxr_indices, index=index, freq=freq) | |
| @define | |
| class PosIdxr(UniIdxr, DefineMixin): | |
| """Class for resolving indices provided as integer positions.""" | |
| value: tp.Union[None, tp.MaybeSequence[tp.MaybeSequence[int]], tp.Slice] = define.field() | |
| """One or more integer positions.""" | |
| def get( | |
| self, | |
| index: tp.Optional[tp.Index] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| ) -> tp.MaybeIndexArray: | |
| if self.value is None: | |
| return slice(None, None, None) | |
| idxs = self.value | |
| if checks.is_sequence(idxs) and not np.isscalar(idxs): | |
| idxs = np.asarray(idxs) | |
| if isinstance(idxs, hslice): | |
| idxs = idxs.to_slice() | |
| self.check_idxs(idxs) | |
| return idxs | |
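| # A minimal sketch of PosIdxr: positions are passed through unchanged (sequences become | |
| # NumPy arrays), so no index is required. | |
| # >>> PosIdxr([0, 2]).get() | |
| # array([0, 2]) | |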
| @define | |
| class MaskIdxr(UniIdxr, DefineMixin): | |
| """Class for resolving indices provided as a mask.""" | |
| value: tp.Union[None, tp.Sequence[bool]] = define.field() | |
| """Mask.""" | |
| def get( | |
| self, | |
| index: tp.Optional[tp.Index] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| ) -> tp.MaybeIndexArray: | |
| if self.value is None: | |
| return slice(None, None, None) | |
| idxs = np.flatnonzero(self.value) | |
| self.check_idxs(idxs) | |
| return idxs | |
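| # A minimal sketch of MaskIdxr: a boolean mask is converted into integer positions | |
| # via np.flatnonzero. | |
| # >>> MaskIdxr([False, True, True]).get() | |
| # array([1, 2]) | |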
| @define | |
| class LabelIdxr(UniIdxr, DefineMixin): | |
| """Class for resolving indices provided as labels.""" | |
| value: tp.Union[None, tp.MaybeSequence[tp.Label], tp.Slice] = define.field() | |
| """One or more labels.""" | |
| closed_start: bool = define.field(default=True) | |
| """Whether slice start should be inclusive.""" | |
| closed_end: bool = define.field(default=True) | |
| """Whether slice end should be inclusive.""" | |
| level: tp.MaybeLevelSequence = define.field(default=None) | |
| """One or more levels.""" | |
| def get( | |
| self, | |
| index: tp.Optional[tp.Index] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| ) -> tp.MaybeIndexArray: | |
| if self.value is None: | |
| return slice(None, None, None) | |
| if index is None: | |
| raise ValueError("Index is required") | |
| if self.level is not None: | |
| from vectorbtpro.base.indexes import select_levels | |
| index = select_levels(index, self.level) | |
| if isinstance(self.value, (slice, hslice)): | |
| idxs = self.slice_indexer( | |
| index, | |
| self.value, | |
| closed_start=self.closed_start, | |
| closed_end=self.closed_end, | |
| ) | |
| elif (checks.is_sequence(self.value) and not np.isscalar(self.value)) and ( | |
| not isinstance(index, pd.MultiIndex) | |
| or (isinstance(index, pd.MultiIndex) and isinstance(self.value[0], tuple)) | |
| ): | |
| idxs = index.get_indexer_for(self.value) | |
| else: | |
| idxs = index.get_loc(self.value) | |
| if isinstance(idxs, np.ndarray) and np.issubdtype(idxs.dtype, np.bool_): | |
| idxs = np.flatnonzero(idxs) | |
| self.check_idxs(idxs, check_minus_one=True) | |
| return idxs | |
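| # A minimal sketch of LabelIdxr, assuming a plain (non-multi) string index: label | |
| # sequences resolve via pd.Index.get_indexer_for, single labels via pd.Index.get_loc. | |
| # >>> index = pd.Index(["a", "b", "c"]) | |
| # >>> LabelIdxr(["a", "c"]).get(index=index) | |
| # array([0, 2]) | |
| # >>> LabelIdxr("b").get(index=index) | |
| # 1 | |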
| @define | |
| class DatetimeIdxr(UniIdxr, DefineMixin): | |
| """Class for resolving indices provided as datetime-like objects.""" | |
| value: tp.Union[None, tp.MaybeSequence[tp.DatetimeLike], tp.Slice] = define.field() | |
| """One or more datetime-like objects.""" | |
| closed_start: bool = define.field(default=True) | |
| """Whether slice start should be inclusive.""" | |
| closed_end: bool = define.field(default=False) | |
| """Whether slice end should be inclusive.""" | |
| indexer_method: tp.Optional[str] = define.field(default="bfill") | |
| """Method for `pd.Index.get_indexer`. | |
| Allows two additional values: "before" and "after".""" | |
| below_to_zero: bool = define.field(default=False) | |
| """Whether to place 0 instead of -1 if `DatetimeIdxr.value` is below the first index.""" | |
| above_to_len: bool = define.field(default=False) | |
| """Whether to place `len(index)` instead of -1 if `DatetimeIdxr.value` is above the last index.""" | |
| def get( | |
| self, | |
| index: tp.Optional[tp.Index] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| ) -> tp.MaybeIndexArray: | |
| if self.value is None: | |
| return slice(None, None, None) | |
| if index is None: | |
| raise ValueError("Index is required") | |
| index = dt.prepare_dt_index(index) | |
| checks.assert_instance_of(index, pd.DatetimeIndex) | |
| if not index.is_unique: | |
| raise ValueError("Datetime index must be unique") | |
| if not index.is_monotonic_increasing: | |
| raise ValueError("Datetime index must be monotonically increasing") | |
| if isinstance(self.value, (slice, hslice)): | |
| start = dt.try_align_dt_to_index(self.value.start, index) | |
| stop = dt.try_align_dt_to_index(self.value.stop, index) | |
| new_value = slice(start, stop, self.value.step) | |
| idxs = self.slice_indexer(index, new_value, closed_start=self.closed_start, closed_end=self.closed_end) | |
| elif checks.is_sequence(self.value) and not np.isscalar(self.value): | |
| new_value = dt.try_align_to_dt_index(self.value, index) | |
| idxs = index.get_indexer(new_value, method=self.indexer_method) | |
| if self.below_to_zero: | |
| idxs = np.where(new_value < index[0], 0, idxs) | |
| if self.above_to_len: | |
| idxs = np.where(new_value > index[-1], len(index), idxs) | |
| else: | |
| new_value = dt.try_align_dt_to_index(self.value, index) | |
| if new_value < index[0] and self.below_to_zero: | |
| idxs = 0 | |
| elif new_value > index[-1] and self.above_to_len: | |
| idxs = len(index) | |
| else: | |
| if self.indexer_method is None or new_value in index: | |
| idxs = index.get_loc(new_value) | |
| if isinstance(idxs, np.ndarray) and np.issubdtype(idxs.dtype, np.bool_): | |
| idxs = np.flatnonzero(idxs) | |
| else: | |
| indexer_method = self.indexer_method | |
| if indexer_method is not None: | |
| indexer_method = indexer_method.lower() | |
| if indexer_method == "before": | |
| new_value = new_value - pd.Timedelta(1, "ns") | |
| indexer_method = "ffill" | |
| elif indexer_method == "after": | |
| new_value = new_value + pd.Timedelta(1, "ns") | |
| indexer_method = "bfill" | |
| idxs = index.get_indexer([new_value], method=indexer_method)[0] | |
| self.check_idxs(idxs, check_minus_one=True) | |
| return idxs | |
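| # A minimal sketch of DatetimeIdxr, assuming a unique and monotonic datetime index: a | |
| # value present in the index resolves via get_loc, while slices use a closed start and | |
| # an open end by default. | |
| # >>> index = pd.date_range("2020-01-01", periods=5, freq="D") | |
| # >>> DatetimeIdxr("2020-01-03").get(index=index) | |
| # 2 | |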
| @define | |
| class DTCIdxr(UniIdxr, DefineMixin): | |
| """Class for resolving indices provided as datetime-like components.""" | |
| value: tp.Union[None, tp.MaybeSequence[tp.DTCLike], tp.Slice] = define.field() | |
| """One or more datetime-like components.""" | |
| parse_kwargs: tp.KwargsLike = define.field(default=None) | |
| """Keyword arguments passed to `vectorbtpro.utils.datetime_.DTC.parse`.""" | |
| closed_start: bool = define.field(default=True) | |
| """Whether slice start should be inclusive.""" | |
| closed_end: bool = define.field(default=False) | |
| """Whether slice end should be inclusive.""" | |
| jitted: tp.JittedOption = define.field(default=None) | |
| """Jitting option passed to `vectorbtpro.utils.datetime_nb.index_matches_dtc_nb` | |
| and `vectorbtpro.utils.datetime_nb.index_within_dtc_range_nb`.""" | |
| @staticmethod | |
| def get_dtc_namedtuple(value: tp.Optional[tp.DTCLike] = None, **parse_kwargs) -> dt.DTCNT: | |
| """Convert a value to a `vectorbtpro.utils.datetime_.DTCNT` instance.""" | |
| if value is None: | |
| return dt.DTC().to_namedtuple() | |
| if isinstance(value, dt.DTC): | |
| return value.to_namedtuple() | |
| if isinstance(value, dt.DTCNT): | |
| return value | |
| return dt.DTC.parse(value, **parse_kwargs).to_namedtuple() | |
| def get( | |
| self, | |
| index: tp.Optional[tp.Index] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| ) -> tp.MaybeIndexArray: | |
| if self.value is None: | |
| return slice(None, None, None) | |
| parse_kwargs = self.parse_kwargs | |
| if parse_kwargs is None: | |
| parse_kwargs = {} | |
| if index is None: | |
| raise ValueError("Index is required") | |
| index = dt.prepare_dt_index(index) | |
| ns_index = dt.to_ns(index) | |
| checks.assert_instance_of(index, pd.DatetimeIndex) | |
| if not index.is_unique: | |
| raise ValueError("Datetime index must be unique") | |
| if not index.is_monotonic_increasing: | |
| raise ValueError("Datetime index must be monotonically increasing") | |
| if isinstance(self.value, (slice, hslice)): | |
| if self.value.step is not None: | |
| raise ValueError("Step must be None") | |
| if self.value.start is None and self.value.stop is None: | |
| return slice(None, None, None) | |
| start_dtc = self.get_dtc_namedtuple(self.value.start, **parse_kwargs) | |
| end_dtc = self.get_dtc_namedtuple(self.value.stop, **parse_kwargs) | |
| func = jit_reg.resolve_option(dt_nb.index_within_dtc_range_nb, self.jitted) | |
| mask = func(ns_index, start_dtc, end_dtc, closed_start=self.closed_start, closed_end=self.closed_end) | |
| elif checks.is_sequence(self.value) and not np.isscalar(self.value): | |
| func = jit_reg.resolve_option(dt_nb.index_matches_dtc_nb, self.jitted) | |
| dtcs = map(lambda x: self.get_dtc_namedtuple(x, **parse_kwargs), self.value) | |
| masks = map(lambda x: func(ns_index, x), dtcs) | |
| mask = functools.reduce(np.logical_or, masks) | |
| else: | |
| dtc = self.get_dtc_namedtuple(self.value, **parse_kwargs) | |
| func = jit_reg.resolve_option(dt_nb.index_matches_dtc_nb, self.jitted) | |
| mask = func(ns_index, dtc) | |
| return MaskIdxr(mask).get(index=index, freq=freq) | |
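| # A rough sketch of DTCIdxr: the value is parsed into datetime components (DTC) and | |
| # matched element-wise against the index, so partial components select every row that | |
| # shares them. For example, DTCIdxr("12:00") on an hourly index would select the | |
| # positions of all timestamps whose time of day is 12:00 (assuming "12:00" parses to | |
| # the hour and minute components). | |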
| @define | |
| class PointIdxr(UniIdxr, DefineMixin): | |
| """Class for resolving index points.""" | |
| every: tp.Optional[tp.FrequencyLike] = define.field(default=None) | |
| """Frequency either as an integer or timedelta. | |
| Gets translated into `on` array by creating a range. If integer, an index sequence from `start` to `end` | |
| (exclusive) is created and 'indices' as `kind` is used. If timedelta-like, a date sequence from | |
| `start` to `end` (inclusive) is created and 'labels' as `kind` is used. | |
| If `at_time` is not None and `every` and `on` are None, `every` defaults to one day.""" | |
| normalize_every: bool = define.field(default=False) | |
| """Normalize start/end dates to midnight before generating date range.""" | |
| at_time: tp.Optional[tp.TimeLike] = define.field(default=None) | |
| """Time of the day either as a (human-readable) string or `datetime.time`. | |
| Every datetime in `on` gets floored to the daily frequency, while `at_time` gets converted into | |
| a timedelta using `vectorbtpro.utils.datetime_.time_to_timedelta` and added to `add_delta`. | |
| Index must be datetime-like.""" | |
| start: tp.Optional[tp.Union[int, tp.DatetimeLike]] = define.field(default=None) | |
| """Start index/date. | |
| If (human-readable) string, gets converted into a datetime. | |
| If `every` is None, gets used to filter the final index array.""" | |
| end: tp.Optional[tp.Union[int, tp.DatetimeLike]] = define.field(default=None) | |
| """End index/date. | |
| If (human-readable) string, gets converted into a datetime. | |
| If `every` is None, gets used to filter the final index array.""" | |
| exact_start: bool = define.field(default=False) | |
| """Whether the first index should be exactly `start`. | |
| Depending on `every`, the first index picked by `pd.date_range` may happen after `start`. | |
| In such a case, `start` gets injected before the first index generated by `pd.date_range`.""" | |
| on: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = define.field(default=None) | |
| """Index/label or a sequence of such. | |
| Gets converted into datetime format whenever possible.""" | |
| add_delta: tp.Optional[tp.FrequencyLike] = define.field(default=None) | |
| """Offset to be added to each in `on`. | |
| Gets converted to a proper offset/timedelta using `vectorbtpro.utils.datetime_.to_freq`.""" | |
| kind: tp.Optional[str] = define.field(default=None) | |
| """Kind of data in `on`: indices or labels. | |
| If None, gets assigned to `indices` if `on` contains integer data, otherwise to `labels`. | |
| If `kind` is 'labels', `on` gets converted into indices using `pd.Index.get_indexer`. | |
| Prior to this, gets its timezone aligned to the timezone of the index. If `kind` is 'indices', | |
| `on` gets wrapped with NumPy.""" | |
| indexer_method: str = define.field(default="bfill") | |
| """Method for `pd.Index.get_indexer`. | |
| Allows two additional values: "before" and "after".""" | |
| indexer_tolerance: tp.Optional[tp.Union[int, tp.TimedeltaLike, tp.IndexLike]] = define.field(default=None) | |
| """Tolerance for `pd.Index.get_indexer`. | |
| If `at_time` is set and `indexer_method` is neither exact nor nearest, `indexer_tolerance` | |
| becomes such that the next element must be within the current day.""" | |
| skip_not_found: bool = define.field(default=True) | |
| """Whether to drop indices that are -1 (not found).""" | |
| def get( | |
| self, | |
| index: tp.Optional[tp.Index] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| ) -> tp.MaybeIndexArray: | |
| if index is None: | |
| raise ValueError("Index is required") | |
| idxs = get_index_points(index, **self.asdict()) | |
| self.check_idxs(idxs, check_minus_one=True) | |
| return idxs | |
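| # PointIdxr is a thin wrapper: its fields are forwarded verbatim to get_index_points | |
| # below, so the following mirrors the documented examples. | |
| # >>> index = pd.date_range("2020-01", "2020-02", freq="1d") | |
| # >>> PointIdxr(every="W").get(index=index) | |
| # array([ 4, 11, 18, 25]) | |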
| point_idxr_defaults = {a.name: a.default for a in PointIdxr.fields} | |
| def get_index_points( | |
| index: tp.Index, | |
| every: tp.Optional[tp.FrequencyLike] = point_idxr_defaults["every"], | |
| normalize_every: bool = point_idxr_defaults["normalize_every"], | |
| at_time: tp.Optional[tp.TimeLike] = point_idxr_defaults["at_time"], | |
| start: tp.Optional[tp.Union[int, tp.DatetimeLike]] = point_idxr_defaults["start"], | |
| end: tp.Optional[tp.Union[int, tp.DatetimeLike]] = point_idxr_defaults["end"], | |
| exact_start: bool = point_idxr_defaults["exact_start"], | |
| on: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = point_idxr_defaults["on"], | |
| add_delta: tp.Optional[tp.FrequencyLike] = point_idxr_defaults["add_delta"], | |
| kind: tp.Optional[str] = point_idxr_defaults["kind"], | |
| indexer_method: str = point_idxr_defaults["indexer_method"], | |
| indexer_tolerance: tp.Optional[tp.Union[int, tp.TimedeltaLike, tp.IndexLike]] = point_idxr_defaults["indexer_tolerance"], | |
| skip_not_found: bool = point_idxr_defaults["skip_not_found"], | |
| ) -> tp.Array1d: | |
| """Translate indices or labels into index points. | |
| See `PointIdxr` for arguments. | |
| Usage: | |
| * Provide nothing to generate at the beginning: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> index = pd.date_range("2020-01", "2020-02", freq="1d") | |
| >>> vbt.get_index_points(index) | |
| array([0]) | |
| ``` | |
| * Provide `every` as an integer frequency to generate index points using NumPy: | |
| ```pycon | |
| >>> # Generate a point every five rows | |
| >>> vbt.get_index_points(index, every=5) | |
| array([ 0, 5, 10, 15, 20, 25, 30]) | |
| >>> # Generate a point every five rows starting at 6th row | |
| >>> vbt.get_index_points(index, every=5, start=5) | |
| array([ 5, 10, 15, 20, 25, 30]) | |
| >>> # Generate a point every five rows from 6th to 16th row | |
| >>> vbt.get_index_points(index, every=5, start=5, end=15) | |
| array([ 5, 10]) | |
| ``` | |
| * Provide `every` as a time delta frequency to generate index points using Pandas: | |
| ```pycon | |
| >>> # Generate a point every week | |
| >>> vbt.get_index_points(index, every="W") | |
| array([ 4, 11, 18, 25]) | |
| >>> # Generate a point every second day of the week | |
| >>> vbt.get_index_points(index, every="W", add_delta="2d") | |
| array([ 6, 13, 20, 27]) | |
| >>> # Generate a point every week, starting at 11th row | |
| >>> vbt.get_index_points(index, every="W", start=10) | |
| array([11, 18, 25]) | |
| >>> # Generate a point every week, starting exactly at 11th row | |
| >>> vbt.get_index_points(index, every="W", start=10, exact_start=True) | |
| array([10, 11, 18, 25]) | |
| >>> # Generate a point every week, starting at 2020-01-10 | |
| >>> vbt.get_index_points(index, every="W", start="2020-01-10") | |
| array([11, 18, 25]) | |
| ``` | |
| * Instead of using `every`, provide indices explicitly: | |
| ```pycon | |
| >>> # Generate one point | |
| >>> vbt.get_index_points(index, on="2020-01-07") | |
| array([6]) | |
| >>> # Generate multiple points | |
| >>> vbt.get_index_points(index, on=["2020-01-07", "2020-01-14"]) | |
| array([ 6, 13]) | |
| ``` | |
| """ | |
| index = dt.prepare_dt_index(index) | |
| if on is not None and isinstance(on, str): | |
| on = dt.try_align_dt_to_index(on, index) | |
| if start is not None and isinstance(start, str): | |
| start = dt.try_align_dt_to_index(start, index) | |
| if end is not None and isinstance(end, str): | |
| end = dt.try_align_dt_to_index(end, index) | |
| if every is not None and not checks.is_int(every): | |
| every = dt.to_freq(every) | |
| start_used = False | |
| end_used = False | |
| if at_time is not None and every is None and on is None: | |
| every = pd.Timedelta(days=1) | |
| if every is not None: | |
| start_used = True | |
| end_used = True | |
| if checks.is_int(every): | |
| if start is None: | |
| start = 0 | |
| if end is None: | |
| end = len(index) | |
| on = np.arange(start, end, every) | |
| kind = "indices" | |
| else: | |
| if start is None: | |
| start = 0 | |
| if checks.is_int(start): | |
| start_date = index[start] | |
| else: | |
| start_date = start | |
| if end is None: | |
| end = len(index) - 1 | |
| if checks.is_int(end): | |
| end_date = index[end] | |
| else: | |
| end_date = end | |
| on = dt.date_range( | |
| start_date, | |
| end_date, | |
| freq=every, | |
| tz=index.tz, | |
| normalize=normalize_every, | |
| inclusive="both", | |
| ) | |
| if exact_start and on[0] > start_date: | |
| on = on.insert(0, start_date) | |
| kind = "labels" | |
| if kind is None: | |
| if on is None: | |
| if start is not None: | |
| if checks.is_int(start): | |
| kind = "indices" | |
| else: | |
| kind = "labels" | |
| else: | |
| kind = "indices" | |
| else: | |
| on = dt.prepare_dt_index(on) | |
| if pd.api.types.is_integer_dtype(on): | |
| kind = "indices" | |
| else: | |
| kind = "labels" | |
| checks.assert_in(kind, ("indices", "labels")) | |
| if on is None: | |
| if start is not None: | |
| on = start | |
| start_used = True | |
| else: | |
| if kind.lower() in ("labels",): | |
| on = index | |
| else: | |
| on = np.arange(len(index)) | |
| on = dt.prepare_dt_index(on) | |
| if at_time is not None: | |
| checks.assert_instance_of(on, pd.DatetimeIndex) | |
| on = on.floor("D") | |
| add_time_delta = dt.time_to_timedelta(at_time) | |
| if indexer_tolerance is None: | |
| indexer_method = indexer_method.lower() | |
| if indexer_method in ("pad", "ffill"): | |
| indexer_tolerance = add_time_delta | |
| elif indexer_method in ("backfill", "bfill"): | |
| indexer_tolerance = pd.Timedelta(days=1) - pd.Timedelta(1, "ns") - add_time_delta | |
| if add_delta is None: | |
| add_delta = add_time_delta | |
| else: | |
| add_delta += add_time_delta | |
| if add_delta is not None: | |
| on += dt.to_freq(add_delta) | |
| if kind.lower() == "labels": | |
| on = dt.try_align_to_dt_index(on, index) | |
| if indexer_method is not None: | |
| indexer_method = indexer_method.lower() | |
| if indexer_method == "before": | |
| on = on - pd.Timedelta(1, "ns") | |
| indexer_method = "ffill" | |
| elif indexer_method == "after": | |
| on = on + pd.Timedelta(1, "ns") | |
| indexer_method = "bfill" | |
| index_points = index.get_indexer(on, method=indexer_method, tolerance=indexer_tolerance) | |
| else: | |
| index_points = np.asarray(on) | |
| if start is not None and not start_used: | |
| if not checks.is_int(start): | |
| start = index.get_indexer([start], method="bfill").item(0) | |
| index_points = index_points[index_points >= start] | |
| if end is not None and not end_used: | |
| if not checks.is_int(end): | |
| end = index.get_indexer([end], method="ffill").item(0) | |
| index_points = index_points[index_points <= end] | |
| else: | |
| index_points = index_points[index_points < end] | |
| if skip_not_found: | |
| index_points = index_points[index_points != -1] | |
| return index_points | |
| @define | |
| class RangeIdxr(UniIdxr, DefineMixin): | |
| """Class for resolving index ranges.""" | |
| every: tp.Optional[tp.FrequencyLike] = define.field(default=None) | |
| """Frequency either as an integer or timedelta. | |
| Gets translated into `start` and `end` arrays by creating a range. If integer, an index sequence from `start` | |
| to `end` (exclusive) is created and 'indices' as `kind` is used. If timedelta-like, a date sequence | |
| from `start` to `end` (inclusive) is created and 'bounds' as `kind` is used. | |
| If `start_time` and `end_time` are not None and `every`, `start`, and `end` are None, | |
| `every` defaults to one day.""" | |
| normalize_every: bool = define.field(default=False) | |
| """Normalize start/end dates to midnight before generating date range.""" | |
| split_every: bool = define.field(default=True) | |
| """Whether to split the sequence generated using `every` into `start` and `end` arrays. | |
| After creation, and if `split_every` is True, an index range is created from each pair of elements in | |
| the generated sequence. Otherwise, the entire sequence is assigned to `start` and `end`, and only time | |
| and delta instructions can be used to further differentiate between them. | |
| Forced to False if `every`, `start_time`, and `end_time` are not None and `fixed_start` is False.""" | |
| start_time: tp.Optional[tp.TimeLike] = define.field(default=None) | |
| """Start time of the day either as a (human-readable) string or `datetime.time`. | |
| Every datetime in `start` gets floored to the daily frequency, while `start_time` gets converted into | |
| a timedelta using `vectorbtpro.utils.datetime_.time_to_timedelta` and added to `add_start_delta`. | |
| Index must be datetime-like.""" | |
| end_time: tp.Optional[tp.TimeLike] = define.field(default=None) | |
| """End time of the day either as a (human-readable) string or `datetime.time`. | |
| Every datetime in `end` gets floored to the daily frequency, while `end_time` gets converted into | |
| a timedelta using `vectorbtpro.utils.datetime_.time_to_timedelta` and added to `add_end_delta`. | |
| Index must be datetime-like.""" | |
| lookback_period: tp.Optional[tp.FrequencyLike] = define.field(default=None) | |
| """Lookback period either as an integer or offset. | |
| If `lookback_period` is set, `start` becomes `end-lookback_period`. If `every` is not None, | |
| the sequence is generated from `start+lookback_period` to `end` and then assigned to `end`. | |
| If string, gets converted to a proper offset/timedelta using `vectorbtpro.utils.datetime_.to_freq`. | |
| If integer, gets multiplied by the frequency of the index if the index is not integer.""" | |
| start: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = define.field(default=None) | |
| """Start index/label or a sequence of such. | |
| Gets converted into datetime format whenever possible. | |
| Gets broadcasted together with `end`.""" | |
| end: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = define.field(default=None) | |
| """End index/label or a sequence of such. | |
| Gets converted into datetime format whenever possible. | |
| Gets broadcasted together with `start`.""" | |
| exact_start: bool = define.field(default=False) | |
| """Whether the first index in the `start` array should be exactly `start`. | |
| Depending on `every`, the first index picked by `pd.date_range` may happen after `start`. | |
| In such a case, `start` gets injected before the first index generated by `pd.date_range`. | |
| Cannot be used together with `lookback_period`.""" | |
| fixed_start: bool = define.field(default=False) | |
| """Whether all indices in the `start` array should be exactly `start`. | |
| Works only together with `every`. | |
| Cannot be used together with `lookback_period`.""" | |
| closed_start: bool = define.field(default=True) | |
| """Whether `start` should be inclusive.""" | |
| closed_end: bool = define.field(default=False) | |
| """Whether `end` should be inclusive.""" | |
| add_start_delta: tp.Optional[tp.FrequencyLike] = define.field(default=None) | |
| """Offset to be added to each in `start`. | |
| If string, gets converted to a proper offset/timedelta using `vectorbtpro.utils.datetime_.to_freq`.""" | |
| add_end_delta: tp.Optional[tp.FrequencyLike] = define.field(default=None) | |
| """Offset to be added to each in `end`. | |
| If string, gets converted to a proper offset/timedelta using `vectorbtpro.utils.datetime_.to_freq`.""" | |
| kind: tp.Optional[str] = define.field(default=None) | |
| """Kind of data in `on`: indices, labels or bounds. | |
| If None, gets assigned to `indices` if `start` and `end` contain integer data, to `bounds` | |
| if `start`, `end`, and index are datetime-like, otherwise to `labels`. | |
| If `kind` is 'labels', `start` and `end` get converted into indices using `pd.Index.get_indexer`. | |
| Prior to this, get their timezone aligned to the timezone of the index. If `kind` is 'indices', | |
| `start` and `end` get wrapped with NumPy. If `kind` is 'bounds', | |
| `vectorbtpro.base.resampling.base.Resampler.map_bounds_to_source_ranges` is used.""" | |
| skip_not_found: bool = define.field(default=True) | |
| """Whether to drop indices that are -1 (not found).""" | |
| jitted: tp.JittedOption = define.field(default=None) | |
| """Jitting option passed to `vectorbtpro.base.resampling.base.Resampler.map_bounds_to_source_ranges`.""" | |
| def get( | |
| self, | |
| index: tp.Optional[tp.Index] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| ) -> tp.MaybeIndexArray: | |
| if index is None: | |
| raise ValueError("Index is required") | |
| from vectorbtpro.base.merging import column_stack_arrays | |
| start_idxs, end_idxs = get_index_ranges(index, index_freq=freq, **self.asdict()) | |
| idxs = column_stack_arrays((start_idxs, end_idxs)) | |
| self.check_idxs(idxs, check_minus_one=True) | |
| return idxs | |
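| # RangeIdxr forwards its fields to get_index_ranges below and column-stacks the | |
| # returned start/end arrays into a single 2-dim position array, mirroring the | |
| # documented examples. | |
| # >>> index = pd.date_range("2020-01", "2020-02", freq="1d") | |
| # >>> RangeIdxr(every="W").get(index=index) | |
| # array([[ 4, 11], | |
| #        [11, 18], | |
| #        [18, 25]]) | |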
| range_idxr_defaults = {a.name: a.default for a in RangeIdxr.fields} | |
| def get_index_ranges( | |
| index: tp.Index, | |
| index_freq: tp.Optional[tp.FrequencyLike] = None, | |
| every: tp.Optional[tp.FrequencyLike] = range_idxr_defaults["every"], | |
| normalize_every: bool = range_idxr_defaults["normalize_every"], | |
| split_every: bool = range_idxr_defaults["split_every"], | |
| start_time: tp.Optional[tp.TimeLike] = range_idxr_defaults["start_time"], | |
| end_time: tp.Optional[tp.TimeLike] = range_idxr_defaults["end_time"], | |
| lookback_period: tp.Optional[tp.FrequencyLike] = range_idxr_defaults["lookback_period"], | |
| start: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = range_idxr_defaults["start"], | |
| end: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = range_idxr_defaults["end"], | |
| exact_start: bool = range_idxr_defaults["exact_start"], | |
| fixed_start: bool = range_idxr_defaults["fixed_start"], | |
| closed_start: bool = range_idxr_defaults["closed_start"], | |
| closed_end: bool = range_idxr_defaults["closed_end"], | |
| add_start_delta: tp.Optional[tp.FrequencyLike] = range_idxr_defaults["add_start_delta"], | |
| add_end_delta: tp.Optional[tp.FrequencyLike] = range_idxr_defaults["add_end_delta"], | |
| kind: tp.Optional[str] = range_idxr_defaults["kind"], | |
| skip_not_found: bool = range_idxr_defaults["skip_not_found"], | |
| jitted: tp.JittedOption = range_idxr_defaults["jitted"], | |
| ) -> tp.Tuple[tp.Array1d, tp.Array1d]: | |
| """Translate indices, labels, or bounds into index ranges. | |
| See `RangeIdxr` for arguments. | |
| Usage: | |
| * Provide nothing to generate one largest index range: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> index = pd.date_range("2020-01", "2020-02", freq="1d") | |
| >>> np.column_stack(vbt.get_index_ranges(index)) | |
| array([[ 0, 32]]) | |
| ``` | |
| * Provide `every` as an integer frequency to generate index ranges using NumPy: | |
| ```pycon | |
| >>> # Generate a range every five rows | |
| >>> np.column_stack(vbt.get_index_ranges(index, every=5)) | |
| array([[ 0, 5], | |
| [ 5, 10], | |
| [10, 15], | |
| [15, 20], | |
| [20, 25], | |
| [25, 30]]) | |
| >>> # Generate a range every five rows, starting at 6th row | |
| >>> np.column_stack(vbt.get_index_ranges( | |
| ... index, | |
| ... every=5, | |
| ... start=5 | |
| ... )) | |
| array([[ 5, 10], | |
| [10, 15], | |
| [15, 20], | |
| [20, 25], | |
| [25, 30]]) | |
| >>> # Generate a range every five rows from 6th to 16th row | |
| >>> np.column_stack(vbt.get_index_ranges( | |
| ... index, | |
| ... every=5, | |
| ... start=5, | |
| ... end=15 | |
| ... )) | |
| array([[ 5, 10], | |
| [10, 15]]) | |
| ``` | |
| * Provide `every` as a time delta frequency to generate index ranges using Pandas: | |
| ```pycon | |
| >>> # Generate a range every week | |
| >>> np.column_stack(vbt.get_index_ranges(index, every="W")) | |
| array([[ 4, 11], | |
| [11, 18], | |
| [18, 25]]) | |
| >>> # Generate a range every second day of the week | |
| >>> np.column_stack(vbt.get_index_ranges( | |
| ... index, | |
| ... every="W", | |
| ... add_start_delta="2d" | |
| ... )) | |
| array([[ 6, 11], | |
| [13, 18], | |
| [20, 25]]) | |
| >>> # Generate a range every week, starting at 11th row | |
| >>> np.column_stack(vbt.get_index_ranges( | |
| ... index, | |
| ... every="W", | |
| ... start=10 | |
| ... )) | |
| array([[11, 18], | |
| [18, 25]]) | |
| >>> # Generate a range every week, starting exactly at 11th row | |
| >>> np.column_stack(vbt.get_index_ranges( | |
| ... index, | |
| ... every="W", | |
| ... start=10, | |
| ... exact_start=True | |
| ... )) | |
| array([[10, 11], | |
| [11, 18], | |
| [18, 25]]) | |
| >>> # Generate a range every week, starting at 2020-01-10 | |
| >>> np.column_stack(vbt.get_index_ranges( | |
| ... index, | |
| ... every="W", | |
| ... start="2020-01-10" | |
| ... )) | |
| array([[11, 18], | |
| [18, 25]]) | |
| >>> # Generate a range every week, each starting at 2020-01-10 | |
| >>> np.column_stack(vbt.get_index_ranges( | |
| ... index, | |
| ... every="W", | |
| ... start="2020-01-10", | |
| ... fixed_start=True | |
| ... )) | |
| array([[11, 18], | |
| [11, 25]]) | |
| >>> # Generate an expanding range that increments by week | |
| >>> np.column_stack(vbt.get_index_ranges( | |
| ... index, | |
| ... every="W", | |
| ... start=0, | |
| ... exact_start=True, | |
| ... fixed_start=True | |
| ... )) | |
| array([[ 0, 4], | |
| [ 0, 11], | |
| [ 0, 18], | |
| [ 0, 25]]) | |
| ``` | |
| * Use a look-back period (instead of an end index): | |
| ```pycon | |
| >>> # Generate a range every week, looking 5 days back | |
| >>> np.column_stack(vbt.get_index_ranges( | |
| ... index, | |
| ... every="W", | |
| ... lookback_period=5 | |
| ... )) | |
| array([[ 6, 11], | |
| [13, 18], | |
| [20, 25]]) | |
| >>> # Generate a range every week, looking 2 weeks back | |
| >>> np.column_stack(vbt.get_index_ranges( | |
| ... index, | |
| ... every="W", | |
| ... lookback_period="2W" | |
| ... )) | |
| array([[ 0, 11], | |
| [ 4, 18], | |
| [11, 25]]) | |
| ``` | |
| * Instead of using `every`, provide start and end indices explicitly: | |
| ```pycon | |
| >>> # Generate one range | |
| >>> np.column_stack(vbt.get_index_ranges( | |
| ... index, | |
| ... start="2020-01-01", | |
| ... end="2020-01-07" | |
| ... )) | |
| array([[0, 6]]) | |
| >>> # Generate ranges between multiple dates | |
| >>> np.column_stack(vbt.get_index_ranges( | |
| ... index, | |
| ... start=["2020-01-01", "2020-01-07"], | |
| ... end=["2020-01-07", "2020-01-14"] | |
| ... )) | |
| array([[ 0, 6], | |
| [ 6, 13]]) | |
| >>> # Generate ranges with a fixed start | |
| >>> np.column_stack(vbt.get_index_ranges( | |
| ... index, | |
| ... start="2020-01-01", | |
| ... end=["2020-01-07", "2020-01-14"] | |
| ... )) | |
| array([[ 0, 6], | |
| [ 0, 13]]) | |
| ``` | |
| * Use `closed_start` and `closed_end` to exclude any of the bounds: | |
| ```pycon | |
| >>> # Generate ranges between multiple dates | |
| >>> # by excluding the start date and including the end date | |
| >>> np.column_stack(vbt.get_index_ranges( | |
| ... index, | |
| ... start=["2020-01-01", "2020-01-07"], | |
| ... end=["2020-01-07", "2020-01-14"], | |
| ... closed_start=False, | |
| ... closed_end=True | |
| ... )) | |
| array([[ 1, 7], | |
| [ 7, 14]]) | |
| ``` | |
| """ | |
| from vectorbtpro.base.indexes import repeat_index | |
| from vectorbtpro.base.resampling.base import Resampler | |
| index = dt.prepare_dt_index(index) | |
| if isinstance(index, pd.DatetimeIndex): | |
| if start is not None: | |
| start = dt.try_align_to_dt_index(start, index) | |
| if isinstance(start, pd.DatetimeIndex): | |
| start = start.tz_localize(None) | |
| if end is not None: | |
| end = dt.try_align_to_dt_index(end, index) | |
| if isinstance(end, pd.DatetimeIndex): | |
| end = end.tz_localize(None) | |
| naive_index = index.tz_localize(None) | |
| else: | |
| if start is not None: | |
| if not isinstance(start, pd.Index): | |
| try: | |
| start = pd.Index(start) | |
| except Exception as e: | |
| start = pd.Index([start]) | |
| if end is not None: | |
| if not isinstance(end, pd.Index): | |
| try: | |
| end = pd.Index(end) | |
| except Exception as e: | |
| end = pd.Index([end]) | |
| naive_index = index | |
| if every is not None and not checks.is_int(every): | |
| every = dt.to_freq(every) | |
| if lookback_period is not None and not checks.is_int(lookback_period): | |
| lookback_period = dt.to_freq(lookback_period) | |
| if fixed_start and lookback_period is not None: | |
| raise ValueError("Cannot use fixed_start and lookback_period together") | |
| if exact_start and lookback_period is not None: | |
| raise ValueError("Cannot use exact_start and lookback_period together") | |
| if start_time is not None or end_time is not None: | |
| if every is None and start is None and end is None: | |
| every = pd.Timedelta(days=1) | |
| if every is not None: | |
| if not fixed_start: | |
| if start_time is None and end_time is not None: | |
| start_time = time(0, 0, 0, 0) | |
| closed_start = True | |
| if start_time is not None and end_time is None: | |
| end_time = time(0, 0, 0, 0) | |
| closed_end = False | |
| if start_time is not None and end_time is not None and not fixed_start: | |
| split_every = False | |
| if checks.is_int(every): | |
| if start is None: | |
| start = 0 | |
| else: | |
| start = start[0] | |
| if end is None: | |
| end = len(naive_index) | |
| else: | |
| end = end[-1] | |
| if closed_end: | |
| end -= 1 | |
| if lookback_period is None: | |
| new_index = np.arange(start, end + 1, every) | |
| if not split_every: | |
| start = end = new_index | |
| else: | |
| if fixed_start: | |
| start = np.full(len(new_index) - 1, new_index[0]) | |
| else: | |
| start = new_index[:-1] | |
| end = new_index[1:] | |
| else: | |
| end = np.arange(start + lookback_period, end + 1, every) | |
| start = end - lookback_period | |
| kind = "indices" | |
| lookback_period = None | |
| else: | |
| if start is None: | |
| start = 0 | |
| else: | |
| start = start[0] | |
| if checks.is_int(start): | |
| start_date = naive_index[start] | |
| else: | |
| start_date = start | |
| if end is None: | |
| end = len(naive_index) - 1 | |
| else: | |
| end = end[-1] | |
| if checks.is_int(end): | |
| end_date = naive_index[end] | |
| else: | |
| end_date = end | |
| if lookback_period is None: | |
| new_index = dt.date_range( | |
| start_date, | |
| end_date, | |
| freq=every, | |
| normalize=normalize_every, | |
| inclusive="both", | |
| ) | |
| if exact_start and new_index[0] > start_date: | |
| new_index = new_index.insert(0, start_date) | |
| if not split_every: | |
| start = end = new_index | |
| else: | |
| if fixed_start: | |
| start = repeat_index(new_index[[0]], len(new_index) - 1) | |
| else: | |
| start = new_index[:-1] | |
| end = new_index[1:] | |
| else: | |
| if checks.is_int(lookback_period): | |
| lookback_period *= dt.infer_index_freq(naive_index, freq=index_freq) | |
| if isinstance(lookback_period, BaseOffset): | |
| end = dt.date_range( | |
| start_date, | |
| end_date, | |
| freq=every, | |
| normalize=normalize_every, | |
| inclusive="both", | |
| ) | |
| start = end - lookback_period | |
| start_mask = start >= start_date | |
| start = start[start_mask] | |
| end = end[start_mask] | |
| else: | |
| end = dt.date_range( | |
| start_date + lookback_period, | |
| end_date, | |
| freq=every, | |
| normalize=normalize_every, | |
| inclusive="both", | |
| ) | |
| start = end - lookback_period | |
| kind = "bounds" | |
| lookback_period = None | |
| if kind is None: | |
| if start is None and end is None: | |
| kind = "indices" | |
| else: | |
| if start is not None: | |
| ref_index = start | |
| if end is not None: | |
| ref_index = end | |
| if pd.api.types.is_integer_dtype(ref_index): | |
| kind = "indices" | |
| elif isinstance(ref_index, pd.DatetimeIndex) and isinstance(naive_index, pd.DatetimeIndex): | |
| kind = "bounds" | |
| else: | |
| kind = "labels" | |
| checks.assert_in(kind, ("indices", "labels", "bounds")) | |
| if end is None: | |
| if kind.lower() in ("labels", "bounds"): | |
| end = pd.Index([naive_index[-1]]) | |
| else: | |
| end = pd.Index([len(naive_index)]) | |
| if start is not None and lookback_period is not None: | |
| raise ValueError("Cannot use start and lookback_period together") | |
| if start is None: | |
| if lookback_period is None: | |
| if kind.lower() in ("labels", "bounds"): | |
| start = pd.Index([naive_index[0]]) | |
| else: | |
| start = pd.Index([0]) | |
| else: | |
| if checks.is_int(lookback_period) and not pd.api.types.is_integer_dtype(end): | |
| lookback_period *= dt.infer_index_freq(naive_index, freq=index_freq) | |
| start = end - lookback_period | |
| if len(start) == 1 and len(end) > 1: | |
| start = repeat_index(start, len(end)) | |
| elif len(start) > 1 and len(end) == 1: | |
| end = repeat_index(end, len(start)) | |
| checks.assert_len_equal(start, end) | |
| if start_time is not None: | |
| checks.assert_instance_of(start, pd.DatetimeIndex) | |
| start = start.floor("D") | |
| add_start_time_delta = dt.time_to_timedelta(start_time) | |
| if add_start_delta is None: | |
| add_start_delta = add_start_time_delta | |
| else: | |
| add_start_delta += add_start_time_delta | |
| else: | |
| add_start_time_delta = None | |
| if end_time is not None: | |
| checks.assert_instance_of(end, pd.DatetimeIndex) | |
| end = end.floor("D") | |
| add_end_time_delta = dt.time_to_timedelta(end_time) | |
| if add_start_time_delta is not None: | |
| if add_end_time_delta < add_start_delta: | |
| add_end_time_delta += pd.Timedelta(days=1) | |
| if add_end_delta is None: | |
| add_end_delta = add_end_time_delta | |
| else: | |
| add_end_delta += add_end_time_delta | |
| if add_start_delta is not None: | |
| start += dt.to_freq(add_start_delta) | |
| if add_end_delta is not None: | |
| end += dt.to_freq(add_end_delta) | |
| if kind.lower() == "bounds": | |
| range_starts, range_ends = Resampler.map_bounds_to_source_ranges( | |
| source_index=naive_index.values, | |
| target_lbound_index=start.values, | |
| target_rbound_index=end.values, | |
| closed_lbound=closed_start, | |
| closed_rbound=closed_end, | |
| skip_not_found=skip_not_found, | |
| jitted=jitted, | |
| ) | |
| else: | |
| if kind.lower() == "labels": | |
| range_starts = np.empty(len(start), dtype=int_) | |
| range_ends = np.empty(len(end), dtype=int_) | |
| range_index = pd.Series(np.arange(len(naive_index)), index=naive_index) | |
| for i in range(len(range_starts)): | |
| selected_range = range_index[start[i] : end[i]] | |
| if len(selected_range) > 0 and not closed_start and selected_range.index[0] == start[i]: | |
| selected_range = selected_range.iloc[1:] | |
| if len(selected_range) > 0 and not closed_end and selected_range.index[-1] == end[i]: | |
| selected_range = selected_range.iloc[:-1] | |
| if len(selected_range) > 0: | |
| range_starts[i] = selected_range.iloc[0] | |
| range_ends[i] = selected_range.iloc[-1] | |
| else: | |
| range_starts[i] = -1 | |
| range_ends[i] = -1 | |
| else: | |
| if not closed_start: | |
| start = start + 1 | |
| if closed_end: | |
| end = end + 1 | |
| range_starts = np.asarray(start) | |
| range_ends = np.asarray(end) | |
| if skip_not_found: | |
| valid_mask = (range_starts != -1) & (range_ends != -1) | |
| range_starts = range_starts[valid_mask] | |
| range_ends = range_ends[valid_mask] | |
| if np.any(range_starts >= range_ends): | |
| raise ValueError("Some start indices are equal to or higher than end indices") | |
| return range_starts, range_ends | |
| @define | |
| class AutoIdxr(UniIdxr, DefineMixin): | |
| """Class for resolving indices, datetime-like objects, frequency-like objects, and labels for one axis.""" | |
| value: tp.Union[ | |
| None, | |
| tp.PosSel, | |
| tp.LabelSel, | |
| tp.MaybeSequence[tp.MaybeSequence[int]], | |
| tp.MaybeSequence[tp.Label], | |
| tp.MaybeSequence[tp.DatetimeLike], | |
| tp.MaybeSequence[tp.DTCLike], | |
| tp.FrequencyLike, | |
| tp.Slice, | |
| ] = define.field() | |
| """One or more integer indices, datetime-like objects, frequency-like objects, or labels. | |
| Can also be an instance of `vectorbtpro.utils.selection.PosSel` holding position(s) | |
| or `vectorbtpro.utils.selection.LabelSel` holding label(s).""" | |
| closed_start: bool = define.optional_field() | |
| """Whether slice start should be inclusive.""" | |
| closed_end: bool = define.optional_field() | |
| """Whether slice end should be inclusive.""" | |
| indexer_method: tp.Optional[str] = define.optional_field() | |
| """Method for `pd.Index.get_indexer`.""" | |
| below_to_zero: bool = define.optional_field() | |
| """Whether to place 0 instead of -1 if `AutoIdxr.value` is below the first index.""" | |
| above_to_len: bool = define.optional_field() | |
| """Whether to place `len(index)` instead of -1 if `AutoIdxr.value` is above the last index.""" | |
| level: tp.MaybeLevelSequence = define.field(default=None) | |
| """One or more levels. | |
| If `level` is not None and `kind` is None, `kind` becomes "labels".""" | |
| kind: tp.Optional[str] = define.field(default=None) | |
| """Kind of value. | |
| Allowed are | |
| * "position(s)" for `PosIdxr` | |
| * "mask" for `MaskIdxr` | |
| * "label(s)" for `LabelIdxr` | |
| * "datetime" for `DatetimeIdxr` | |
| * "dtc": for `DTCIdxr` | |
| * "frequency" for `PointIdxr` | |
| If None, will (try to) determine automatically based on the type of indices.""" | |
| idxr_kwargs: tp.KwargsLike = define.field(default=None) | |
| """Keyword arguments passed to the selected indexer.""" | |
| def __init__(self, *args, **kwargs) -> None: | |
| idxr_kwargs = kwargs.pop("idxr_kwargs", None) | |
| if idxr_kwargs is None: | |
| idxr_kwargs = {} | |
| else: | |
| idxr_kwargs = dict(idxr_kwargs) | |
| builtin_keys = {a.name for a in self.fields} | |
| for k in list(kwargs.keys()): | |
| if k not in builtin_keys: | |
| idxr_kwargs[k] = kwargs.pop(k) | |
| DefineMixin.__init__(self, *args, idxr_kwargs=idxr_kwargs, **kwargs) | |
| def get( | |
| self, | |
| index: tp.Optional[tp.Index] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| ) -> tp.MaybeIndexArray: | |
| if self.value is None: | |
| return slice(None, None, None) | |
| value = self.value | |
| kind = self.kind | |
| if self.level is not None: | |
| from vectorbtpro.base.indexes import select_levels | |
| if index is None: | |
| raise ValueError("Index is required") | |
| index = select_levels(index, self.level) | |
| if kind is None: | |
| kind = "labels" | |
| if self.idxr_kwargs is not None: | |
| idxr_kwargs = self.idxr_kwargs | |
| else: | |
| idxr_kwargs = None | |
| if idxr_kwargs is None: | |
| idxr_kwargs = {} | |
| def _dtc_check_func(dtc): | |
| return ( | |
| not dtc.has_full_datetime() | |
| and self.indexer_method in (MISSING, None) | |
| and self.below_to_zero is MISSING | |
| and self.above_to_len is MISSING | |
| ) | |
| if kind is None: | |
| if isinstance(value, PosSel): | |
| kind = "positions" | |
| value = value.value | |
| elif isinstance(value, LabelSel): | |
| kind = "labels" | |
| value = value.value | |
| elif isinstance(value, (slice, hslice)): | |
| if checks.is_int(value.start) or checks.is_int(value.stop): | |
| kind = "positions" | |
| elif value.start is None and value.stop is None: | |
| kind = "positions" | |
| else: | |
| if index is None: | |
| raise ValueError("Index is required") | |
| if isinstance(index, pd.DatetimeIndex): | |
| if dt.DTC.is_parsable(value.start, check_func=_dtc_check_func) or dt.DTC.is_parsable( | |
| value.stop, check_func=_dtc_check_func | |
| ): | |
| kind = "dtc" | |
| else: | |
| kind = "datetime" | |
| else: | |
| kind = "labels" | |
| elif (checks.is_sequence(value) and not np.isscalar(value)) and ( | |
| index is None | |
| or ( | |
| not isinstance(index, pd.MultiIndex) | |
| or (isinstance(index, pd.MultiIndex) and isinstance(value[0], tuple)) | |
| ) | |
| ): | |
| if checks.is_bool(value[0]): | |
| kind = "mask" | |
| elif checks.is_int(value[0]): | |
| kind = "positions" | |
| elif ( | |
| (index is None or not isinstance(index, pd.MultiIndex) or not isinstance(value[0], tuple)) | |
| and checks.is_sequence(value[0]) | |
| and len(value[0]) == 2 | |
| and checks.is_int(value[0][0]) | |
| and checks.is_int(value[0][1]) | |
| ): | |
| kind = "positions" | |
| else: | |
| if index is None: | |
| raise ValueError("Index is required") | |
| elif isinstance(index, pd.DatetimeIndex): | |
| if dt.DTC.is_parsable(value[0], check_func=_dtc_check_func): | |
| kind = "dtc" | |
| else: | |
| kind = "datetime" | |
| else: | |
| kind = "labels" | |
| else: | |
| if checks.is_bool(value): | |
| kind = "mask" | |
| elif checks.is_int(value): | |
| kind = "positions" | |
| else: | |
| if index is None: | |
| raise ValueError("Index is required") | |
| if isinstance(index, pd.DatetimeIndex): | |
| if dt.DTC.is_parsable(value, check_func=_dtc_check_func): | |
| kind = "dtc" | |
| elif isinstance(value, str): | |
| try: | |
| if not value.isupper() and not value.islower(): | |
| raise Exception # "2020" shouldn't be a frequency | |
| _ = dt.to_freq(value) | |
| kind = "frequency" | |
| except Exception as e: | |
| try: | |
| _ = dt.to_timestamp(value) | |
| kind = "datetime" | |
| except Exception as e: | |
| raise ValueError(f"'{value}' is neither a frequency nor a datetime") | |
| elif checks.is_frequency(value): | |
| kind = "frequency" | |
| else: | |
| kind = "datetime" | |
| else: | |
| kind = "labels" | |
| def _expand_target_kwargs(target_cls, **target_kwargs): | |
| source_arg_names = {a.name for a in self.fields if a.default is MISSING} | |
| target_arg_names = {a.name for a in target_cls.fields} | |
| for arg_name in source_arg_names: | |
| if arg_name in target_arg_names: | |
| arg_value = getattr(self, arg_name) | |
| if arg_value is not MISSING: | |
| target_kwargs[arg_name] = arg_value | |
| return target_kwargs | |
| if kind.lower() in ("position", "positions"): | |
| idx = PosIdxr(value, **_expand_target_kwargs(PosIdxr, **idxr_kwargs)) | |
| elif kind.lower() == "mask": | |
| idx = MaskIdxr(value, **_expand_target_kwargs(MaskIdxr, **idxr_kwargs)) | |
| elif kind.lower() in ("label", "labels"): | |
| idx = LabelIdxr(value, **_expand_target_kwargs(LabelIdxr, **idxr_kwargs)) | |
| elif kind.lower() == "datetime": | |
| idx = DatetimeIdxr(value, **_expand_target_kwargs(DatetimeIdxr, **idxr_kwargs)) | |
| elif kind.lower() == "dtc": | |
| idx = DTCIdxr(value, **_expand_target_kwargs(DTCIdxr, **idxr_kwargs)) | |
| elif kind.lower() == "frequency": | |
| idx = PointIdxr(every=value, **_expand_target_kwargs(PointIdxr, **idxr_kwargs)) | |
| else: | |
| raise ValueError(f"Invalid kind: '{kind}'") | |
| return idx.get(index=index, freq=freq) | |
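| # A rough sketch of AutoIdxr dispatch when kind is None: booleans go to MaskIdxr, | |
| # integers to PosIdxr, labels on a regular index to LabelIdxr, datetime-like values on | |
| # a DatetimeIndex to DatetimeIdxr/DTCIdxr, and frequency-like strings to PointIdxr. | |
| # >>> AutoIdxr([0, 2]).get()  # integers -> PosIdxr | |
| # array([0, 2]) | |
| # >>> AutoIdxr(["a", "c"]).get(index=pd.Index(["a", "b", "c"]))  # labels -> LabelIdxr | |
| # array([0, 2]) | |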
| @define | |
| class RowIdxr(IdxrBase, DefineMixin): | |
| """Class for resolving row indices.""" | |
| idxr: object = define.field() | |
| """Indexer. | |
| Can be an instance of `UniIdxr`, a custom template, or a value to be wrapped with `AutoIdxr`.""" | |
| idxr_kwargs: tp.KwargsLike = define.field() | |
| """Keyword arguments passed to `AutoIdxr`.""" | |
| def __init__(self, idxr: object, **idxr_kwargs) -> None: | |
| DefineMixin.__init__(self, idxr=idxr, idxr_kwargs=hdict(idxr_kwargs)) | |
| def get( | |
| self, | |
| index: tp.Optional[tp.Index] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| template_context: tp.KwargsLike = None, | |
| ) -> tp.MaybeIndexArray: | |
| idxr = self.idxr | |
| if isinstance(idxr, CustomTemplate): | |
| _template_context = merge_dicts(dict(index=index, freq=freq), template_context) | |
| idxr = idxr.substitute(_template_context, eval_id="idxr") | |
| if not isinstance(idxr, UniIdxr): | |
| if isinstance(idxr, IdxrBase): | |
| raise TypeError(f"Indexer of {type(self)} must be an instance of UniIdxr") | |
| idxr = AutoIdxr(idxr, **self.idxr_kwargs) | |
| return idxr.get(index=index, freq=freq) | |
| @define | |
| class ColIdxr(IdxrBase, DefineMixin): | |
| """Class for resolving column indices.""" | |
| idxr: object = define.field() | |
| """Indexer. | |
| Can be an instance of `UniIdxr`, a custom template, or a value to be wrapped with `AutoIdxr`.""" | |
| idxr_kwargs: tp.KwargsLike = define.field() | |
| """Keyword arguments passed to `AutoIdxr`.""" | |
| def __init__(self, idxr: object, **idxr_kwargs) -> None: | |
| DefineMixin.__init__(self, idxr=idxr, idxr_kwargs=hdict(idxr_kwargs)) | |
| def get( | |
| self, | |
| columns: tp.Optional[tp.Index] = None, | |
| template_context: tp.KwargsLike = None, | |
| ) -> tp.MaybeIndexArray: | |
| idxr = self.idxr | |
| if isinstance(idxr, CustomTemplate): | |
| _template_context = merge_dicts(dict(columns=columns), template_context) | |
| idxr = idxr.substitute(_template_context, eval_id="idxr") | |
| if not isinstance(idxr, UniIdxr): | |
| if isinstance(idxr, IdxrBase): | |
| raise TypeError(f"Indexer of {type(self)} must be an instance of UniIdxr") | |
| idxr = AutoIdxr(idxr, **self.idxr_kwargs) | |
| return idxr.get(index=columns) | |
| @define | |
| class Idxr(IdxrBase, DefineMixin): | |
| """Class for resolving indices.""" | |
| idxrs: tp.Tuple[object, ...] = define.field() | |
| """A tuple of one or more indexers. | |
| If one indexer is provided, can be an instance of `RowIdxr` or `ColIdxr`, | |
| a custom template, or a value to be wrapped with `RowIdxr`. | |
| If two indexers are provided, can be an instance of `RowIdxr` and `ColIdxr` respectively, | |
| or a value to be wrapped with `RowIdxr` and `ColIdxr` respectively.""" | |
| idxr_kwargs: tp.KwargsLike = define.field() | |
| """Keyword arguments passed to `RowIdxr` and `ColIdxr`.""" | |
| def __init__(self, *idxrs: object, **idxr_kwargs) -> None: | |
| DefineMixin.__init__(self, idxrs=idxrs, idxr_kwargs=hdict(idxr_kwargs)) | |
| def get( | |
| self, | |
| index: tp.Optional[tp.Index] = None, | |
| columns: tp.Optional[tp.Index] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| template_context: tp.KwargsLike = None, | |
| ) -> tp.Tuple[tp.MaybeIndexArray, tp.MaybeIndexArray]: | |
| if len(self.idxrs) == 0: | |
| raise ValueError("Must provide at least one indexer") | |
| elif len(self.idxrs) == 1: | |
| idxr = self.idxrs[0] | |
| if isinstance(idxr, CustomTemplate): | |
| _template_context = merge_dicts(dict(index=index, columns=columns, freq=freq), template_context) | |
| idxr = idxr.substitute(_template_context, eval_id="idxr") | |
| if isinstance(idxr, tuple): | |
| return type(self)(*idxr).get( | |
| index=index, | |
| columns=columns, | |
| freq=freq, | |
| template_context=template_context, | |
| ) | |
| return type(self)(idxr).get( | |
| index=index, | |
| columns=columns, | |
| freq=freq, | |
| template_context=template_context, | |
| ) | |
| if isinstance(idxr, ColIdxr): | |
| row_idxr = None | |
| col_idxr = idxr | |
| else: | |
| row_idxr = idxr | |
| col_idxr = None | |
| elif len(self.idxrs) == 2: | |
| row_idxr = self.idxrs[0] | |
| col_idxr = self.idxrs[1] | |
| else: | |
| raise ValueError("Must provide at most two indexers") | |
| if not isinstance(row_idxr, RowIdxr): | |
| if isinstance(row_idxr, (ColIdxr, Idxr)): | |
| raise TypeError(f"Indexer {type(row_idxr)} not supported as a row indexer") | |
| row_idxr = RowIdxr(row_idxr, **self.idxr_kwargs) | |
| row_idxs = row_idxr.get(index=index, freq=freq, template_context=template_context) | |
| if not isinstance(col_idxr, ColIdxr): | |
| if isinstance(col_idxr, (RowIdxr, Idxr)): | |
| raise TypeError(f"Indexer {type(col_idxr)} not supported as a column indexer") | |
| col_idxr = ColIdxr(col_idxr, **self.idxr_kwargs) | |
| col_idxs = col_idxr.get(columns=columns, template_context=template_context) | |
| return row_idxs, col_idxs | |
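| # A minimal sketch of Idxr with two indexers: the first resolves against the row index, | |
| # the second against the columns. | |
| # >>> Idxr([0, 2], "b").get(index=pd.RangeIndex(5), columns=pd.Index(["a", "b"])) | |
| # (array([0, 2]), 1) | |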
| def get_idxs( | |
| idxr: object, | |
| index: tp.Optional[tp.Index] = None, | |
| columns: tp.Optional[tp.Index] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| template_context: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> tp.Tuple[tp.MaybeIndexArray, tp.MaybeIndexArray]: | |
| """Translate indexer to row and column indices. | |
| If `idxr` is not an indexer class, wraps it with `Idxr`. | |
| Keyword arguments are passed when constructing a new `Idxr`.""" | |
| if not isinstance(idxr, Idxr): | |
| idxr = Idxr(idxr, **kwargs) | |
| return idxr.get(index=index, columns=columns, freq=freq, template_context=template_context) | |
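| # A minimal sketch of get_idxs: a non-indexer value is wrapped with Idxr, and a single | |
| # indexer is treated as a row indexer, leaving the columns untouched. | |
| # >>> get_idxs([1, 3], index=pd.RangeIndex(5)) | |
| # (array([1, 3]), slice(None, None, None)) | |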
| class index_dict(pdict): | |
| """Dict that contains indexer objects as keys and values to be set as values. | |
| Each indexer object must be hashable. To make a slice hashable, use `hslice`. | |
| To make an array hashable, convert it into a tuple. | |
| To set a default value, use the `_def` key (case-sensitive!).""" | |
| pass | |
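| # A minimal sketch of index_dict, using only the key types documented above: a hashable | |
| # slice, a tuple of positions, and the "_def" default. | |
| # >>> index_dict({ | |
| # ...     hslice(0, 3): 1.0,  # set the first three rows to 1.0 | |
| # ...     (3, 4): 2.0,        # set rows 3 and 4 to 2.0 | |
| # ...     "_def": 0.0,        # default for everything else | |
| # ... }) | |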
| IdxSetterT = tp.TypeVar("IdxSetterT", bound="IdxSetter") | |
| @define | |
| class IdxSetter(DefineMixin): | |
| """Class for setting values based on indexing.""" | |
| idx_items: tp.List[tp.Tuple[object, tp.ArrayLike]] = define.field() | |
| """Items where the first element is an indexer and the second element is a value to be set.""" | |
| @classmethod | |
| def set_row_idxs(cls, arr: tp.Array, idxs: tp.MaybeIndexArray, v: tp.Any) -> None: | |
| """Set row indices in an array.""" | |
| from vectorbtpro.base.reshaping import broadcast_array_to | |
| if not isinstance(v, np.ndarray): | |
| v = np.asarray(v) | |
| single_v = v.size == 1 or (v.ndim == 2 and v.shape[0] == 1) | |
| if arr.ndim == 2: | |
| single_row = not isinstance(idxs, slice) and (np.isscalar(idxs) or idxs.size == 1) | |
| if not single_row: | |
| if v.ndim == 1 and v.size > 1: | |
| v = v[:, None] | |
| if isinstance(idxs, np.ndarray) and idxs.ndim == 2: | |
| if not single_v: | |
| if arr.ndim == 2: | |
| v = broadcast_array_to(v, (len(idxs), arr.shape[1])) | |
| else: | |
| v = broadcast_array_to(v, (len(idxs),)) | |
| for i in range(len(idxs)): | |
| _slice = slice(idxs[i, 0], idxs[i, 1]) | |
| if not single_v: | |
| cls.set_row_idxs(arr, _slice, v[[i]]) | |
| else: | |
| cls.set_row_idxs(arr, _slice, v) | |
| else: | |
| arr[idxs] = v | |
| @classmethod | |
| def set_col_idxs(cls, arr: tp.Array, idxs: tp.MaybeIndexArray, v: tp.Any) -> None: | |
| """Set column indices in an array.""" | |
| from vectorbtpro.base.reshaping import broadcast_array_to | |
| if not isinstance(v, np.ndarray): | |
| v = np.asarray(v) | |
| single_v = v.size == 1 or (v.ndim == 2 and v.shape[1] == 1) | |
| if isinstance(idxs, np.ndarray) and idxs.ndim == 2: | |
| if not single_v: | |
| v = broadcast_array_to(v, (arr.shape[0], len(idxs))) | |
| for j in range(len(idxs)): | |
| _slice = slice(idxs[j, 0], idxs[j, 1]) | |
| if not single_v: | |
| cls.set_col_idxs(arr, _slice, v[:, [j]]) | |
| else: | |
| cls.set_col_idxs(arr, _slice, v) | |
| else: | |
| arr[:, idxs] = v | |
| @classmethod | |
| def set_row_and_col_idxs( | |
| cls, | |
| arr: tp.Array, | |
| row_idxs: tp.MaybeIndexArray, | |
| col_idxs: tp.MaybeIndexArray, | |
| v: tp.Any, | |
| ) -> None: | |
| """Set row and column indices in an array.""" | |
| from vectorbtpro.base.reshaping import broadcast_array_to | |
| if not isinstance(v, np.ndarray): | |
| v = np.asarray(v) | |
| single_v = v.size == 1 | |
| if ( | |
| isinstance(row_idxs, np.ndarray) | |
| and row_idxs.ndim == 2 | |
| and isinstance(col_idxs, np.ndarray) | |
| and col_idxs.ndim == 2 | |
| ): | |
| if not single_v: | |
| v = broadcast_array_to(v, (len(row_idxs), len(col_idxs))) | |
| for i in range(len(row_idxs)): | |
| for j in range(len(col_idxs)): | |
| row_slice = slice(row_idxs[i, 0], row_idxs[i, 1]) | |
| col_slice = slice(col_idxs[j, 0], col_idxs[j, 1]) | |
| if not single_v: | |
| cls.set_row_and_col_idxs(arr, row_slice, col_slice, v[i, j]) | |
| else: | |
| cls.set_row_and_col_idxs(arr, row_slice, col_slice, v) | |
| elif isinstance(row_idxs, np.ndarray) and row_idxs.ndim == 2: | |
| if not single_v: | |
| if isinstance(col_idxs, slice): | |
| col_idxs = np.arange(arr.shape[1])[col_idxs] | |
| v = broadcast_array_to(v, (len(row_idxs), len(col_idxs))) | |
| for i in range(len(row_idxs)): | |
| row_slice = slice(row_idxs[i, 0], row_idxs[i, 1]) | |
| if not single_v: | |
| cls.set_row_and_col_idxs(arr, row_slice, col_idxs, v[[i]]) | |
| else: | |
| cls.set_row_and_col_idxs(arr, row_slice, col_idxs, v) | |
| elif isinstance(col_idxs, np.ndarray) and col_idxs.ndim == 2: | |
| if not single_v: | |
| if isinstance(row_idxs, slice): | |
| row_idxs = np.arange(arr.shape[0])[row_idxs] | |
| v = broadcast_array_to(v, (len(row_idxs), len(col_idxs))) | |
| for j in range(len(col_idxs)): | |
| col_slice = slice(col_idxs[j, 0], col_idxs[j, 1]) | |
| if not single_v: | |
| cls.set_row_and_col_idxs(arr, row_idxs, col_slice, v[:, [j]]) | |
| else: | |
| cls.set_row_and_col_idxs(arr, row_idxs, col_slice, v) | |
| else: | |
| if np.isscalar(row_idxs) or np.isscalar(col_idxs): | |
| arr[row_idxs, col_idxs] = v | |
| elif np.isscalar(v) and (isinstance(row_idxs, slice) or isinstance(col_idxs, slice)): | |
| arr[row_idxs, col_idxs] = v | |
| elif np.isscalar(v): | |
| arr[np.ix_(row_idxs, col_idxs)] = v | |
| else: | |
| if isinstance(row_idxs, slice): | |
| row_idxs = np.arange(arr.shape[0])[row_idxs] | |
| if isinstance(col_idxs, slice): | |
| col_idxs = np.arange(arr.shape[1])[col_idxs] | |
| v = broadcast_array_to(v, (len(row_idxs), len(col_idxs))) | |
| arr[np.ix_(row_idxs, col_idxs)] = v | |
| def get_set_meta( | |
| self, | |
| shape: tp.ShapeLike, | |
| index: tp.Optional[tp.Index] = None, | |
| columns: tp.Optional[tp.Index] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| template_context: tp.KwargsLike = None, | |
| ) -> tp.Kwargs: | |
| """Get meta of setting operations in `IdxSetter.idx_items`.""" | |
| from vectorbtpro.base.reshaping import to_tuple_shape | |
| shape = to_tuple_shape(shape) | |
| rows_changed = False | |
| cols_changed = False | |
| set_funcs = [] | |
| default = None | |
| for idxr, v in self.idx_items: | |
| if isinstance(idxr, str) and idxr == "_def": | |
| if default is None: | |
| default = v | |
| continue | |
| row_idxs, col_idxs = get_idxs( | |
| idxr, | |
| index=index, | |
| columns=columns, | |
| freq=freq, | |
| template_context=template_context, | |
| ) | |
| if isinstance(v, CustomTemplate): | |
| _template_context = merge_dicts( | |
| dict( | |
| idxr=idxr, | |
| row_idxs=row_idxs, | |
| col_idxs=col_idxs, | |
| ), | |
| template_context, | |
| ) | |
| v = v.substitute(_template_context, eval_id="set") | |
| if not isinstance(v, np.ndarray): | |
| v = np.asarray(v) | |
| def _check_use_idxs(idxs): | |
| use_idxs = True | |
| if isinstance(idxs, slice): | |
| if idxs.start is None and idxs.stop is None and idxs.step is None: | |
| use_idxs = False | |
| if isinstance(idxs, np.ndarray): | |
| if idxs.size == 0: | |
| use_idxs = False | |
| return use_idxs | |
| use_row_idxs = _check_use_idxs(row_idxs) | |
| use_col_idxs = _check_use_idxs(col_idxs) | |
| if use_row_idxs and use_col_idxs: | |
| set_funcs.append(partial(self.set_row_and_col_idxs, row_idxs=row_idxs, col_idxs=col_idxs, v=v)) | |
| rows_changed = True | |
| cols_changed = True | |
| elif use_col_idxs: | |
| set_funcs.append(partial(self.set_col_idxs, idxs=col_idxs, v=v)) | |
| if checks.is_int(col_idxs): | |
| if v.size > 1: | |
| rows_changed = True | |
| else: | |
| if v.ndim == 2: | |
| if v.shape[0] > 1: | |
| rows_changed = True | |
| cols_changed = True | |
| else: | |
| set_funcs.append(partial(self.set_row_idxs, idxs=row_idxs, v=v)) | |
| if use_row_idxs: | |
| rows_changed = True | |
| if len(shape) == 2: | |
| if checks.is_int(row_idxs): | |
| if v.size > 1: | |
| cols_changed = True | |
| else: | |
| if v.ndim == 2: | |
| if v.shape[1] > 1: | |
| cols_changed = True | |
| return dict( | |
| default=default, | |
| set_funcs=set_funcs, | |
| rows_changed=rows_changed, | |
| cols_changed=cols_changed, | |
| ) | |
| def set(self, arr: tp.Array, set_funcs: tp.Optional[tp.Sequence[tp.Callable]] = None, **kwargs) -> None: | |
| """Set values of a NumPy array based on `IdxSetter.get_set_meta`.""" | |
| if set_funcs is None: | |
| set_meta = self.get_set_meta(arr.shape, **kwargs) | |
| set_funcs = set_meta["set_funcs"] | |
| for set_op in set_funcs: | |
| set_op(arr) | |
| def set_pd(self, pd_arr: tp.SeriesFrame, **kwargs) -> None: | |
| """Set values of a Pandas array based on `IdxSetter.get_set_meta`.""" | |
| from vectorbtpro.base.indexes import get_index | |
| index = get_index(pd_arr, 0) | |
| columns = get_index(pd_arr, 1) | |
| freq = dt.infer_index_freq(index) | |
| self.set(pd_arr.values, index=index, columns=columns, freq=freq, **kwargs) | |
| def fill_and_set( | |
| self, | |
| shape: tp.ShapeLike, | |
| keep_flex: bool = False, | |
| fill_value: tp.Scalar = np.nan, | |
| **kwargs, | |
| ) -> tp.Array: | |
| """Fill a new array and set its values based on `IdxSetter.get_set_meta`. | |
| If `keep_flex` is True, will return the most memory-efficient array representation | |
| capable of flexible indexing. | |
| If the `_def` key is present in `IdxSetter.idx_items`, its value overrides `fill_value` | |
| (which defaults to NaN).""" | |
| set_meta = self.get_set_meta(shape, **kwargs) | |
| if set_meta["default"] is not None: | |
| fill_value = set_meta["default"] | |
| if isinstance(fill_value, str): | |
| dtype = object | |
| else: | |
| dtype = None | |
| if keep_flex and not set_meta["cols_changed"] and not set_meta["rows_changed"]: | |
| arr = np.full((1,) if len(shape) == 1 else (1, 1), fill_value, dtype=dtype) | |
| elif keep_flex and not set_meta["cols_changed"]: | |
| arr = np.full(shape if len(shape) == 1 else (shape[0], 1), fill_value, dtype=dtype) | |
| elif keep_flex and not set_meta["rows_changed"]: | |
| arr = np.full((1, shape[1]), fill_value, dtype=dtype) | |
| else: | |
| arr = np.full(shape, fill_value, dtype=dtype) | |
| self.set(arr, set_funcs=set_meta["set_funcs"]) | |
| return arr | |
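| # A small sketch of the low-level setters above applied to a plain NumPy array; | |
| # the shapes and values are illustrative only: | |
| # >>> import numpy as np | |
| # >>> arr = np.zeros((4, 3)) | |
| # >>> IdxSetter.set_row_idxs(arr, np.array([0, 2]), 9.0)   # fill rows 0 and 2 with 9.0 | |
| # >>> IdxSetter.set_col_idxs(arr, np.array([1]), -1.0)     # fill column 1 with -1.0 | |
| # >>> IdxSetter.set_row_and_col_idxs(arr, np.array([3]), np.array([0, 2]), 5.0) | |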
| class IdxSetterFactory(Base): | |
| """Class for building index setters.""" | |
| def get(self) -> tp.Union[IdxSetter, tp.Dict[tp.Label, IdxSetter]]: | |
| """Get an instance of `IdxSetter` or a dict of such instances - one per array name.""" | |
| raise NotImplementedError | |
| @define | |
| class IdxDict(IdxSetterFactory, DefineMixin): | |
| """Class for building an index setter from a dict.""" | |
| index_dct: dict = define.field() | |
| """Dict that contains indexer objects as keys and values to be set as values.""" | |
| def get(self) -> tp.Union[IdxSetter, tp.Dict[tp.Label, IdxSetter]]: | |
| return IdxSetter(list(self.index_dct.items())) | |
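| # Sketch of `IdxDict`: the dict items become `IdxSetter` items as-is (keys illustrative): | |
| # >>> IdxDict({"2020-01-01": 1.0, "_def": 0.0}).get()  # -> IdxSetter with two items | |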
| @define | |
| class IdxSeries(IdxSetterFactory, DefineMixin): | |
| """Class for building an index setter from a Series.""" | |
| sr: tp.AnyArray1d = define.field() | |
| """Series or any array-like object to create the Series from.""" | |
| split: bool = define.field(default=False) | |
| """Whether to split the setting operation. | |
| If False, will set all values using a single operation. | |
| Otherwise, will do one operation per element.""" | |
| idx_kwargs: tp.KwargsLike = define.field(default=None) | |
| """Keyword arguments passed to `idx` if the indexer isn't an instance of `Idxr`.""" | |
| def get(self) -> tp.Union[IdxSetter, tp.Dict[tp.Label, IdxSetter]]: | |
| sr = self.sr | |
| split = self.split | |
| idx_kwargs = self.idx_kwargs | |
| if idx_kwargs is None: | |
| idx_kwargs = {} | |
| if not isinstance(sr, pd.Series): | |
| sr = pd.Series(sr) | |
| if split: | |
| idx_items = list(sr.items()) | |
| else: | |
| idx_items = [(sr.index, sr.values)] | |
| new_idx_items = [] | |
| for idxr, v in idx_items: | |
| if idxr is None: | |
| raise ValueError("Indexer cannot be None") | |
| if not isinstance(idxr, Idxr): | |
| idxr = idx(idxr, **idx_kwargs) | |
| new_idx_items.append((idxr, v)) | |
| return IdxSetter(new_idx_items) | |
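| # Sketch of `IdxSeries` (values illustrative): the Series index provides the indexers | |
| # and the Series values are what gets set: | |
| # >>> import pandas as pd | |
| # >>> sr = pd.Series([1.0, 2.0], index=["2020-01-01", "2020-01-02"]) | |
| # >>> IdxSeries(sr).get()              # one item covering the whole index | |
| # >>> IdxSeries(sr, split=True).get()  # one item per element | |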
| @define | |
| class IdxFrame(IdxSetterFactory, DefineMixin): | |
| """Class for building an index setter from a DataFrame.""" | |
| df: tp.AnyArray2d = define.field() | |
| """DataFrame or any array-like object to create the DataFrame from.""" | |
| split: tp.Union[bool, str] = define.field(default=False) | |
| """Whether to split the setting operation. | |
| If False, will set all values using a single operation. | |
| Otherwise, the following options are supported: | |
| * 'columns': one operation per column | |
| * 'rows': one operation per row | |
| * True or 'elements': one operation per element""" | |
| rowidx_kwargs: tp.KwargsLike = define.field(default=None) | |
| """Keyword arguments passed to `rowidx` if the indexer isn't an instance of `RowIdxr`.""" | |
| colidx_kwargs: tp.KwargsLike = define.field(default=None) | |
| """Keyword arguments passed to `colidx` if the indexer isn't an instance of `ColIdxr`.""" | |
| def get(self) -> tp.Union[IdxSetter, tp.Dict[tp.Label, IdxSetter]]: | |
| df = self.df | |
| split = self.split | |
| rowidx_kwargs = self.rowidx_kwargs | |
| colidx_kwargs = self.colidx_kwargs | |
| if rowidx_kwargs is None: | |
| rowidx_kwargs = {} | |
| if colidx_kwargs is None: | |
| colidx_kwargs = {} | |
| if not isinstance(df, pd.DataFrame): | |
| df = pd.DataFrame(df) | |
| if isinstance(split, bool): | |
| if split: | |
| split = "elements" | |
| else: | |
| split = None | |
| if split is not None: | |
| if split.lower() == "columns": | |
| idx_items = [] | |
| for col, sr in df.items(): | |
| idx_items.append((sr.index, col, sr.values)) | |
| elif split.lower() == "rows": | |
| idx_items = [] | |
| for row, sr in df.iterrows(): | |
| idx_items.append((row, df.columns, sr.values)) | |
| elif split.lower() == "elements": | |
| idx_items = [] | |
| for col, sr in df.items(): | |
| for row, v in sr.items(): | |
| idx_items.append((row, col, v)) | |
| else: | |
| raise ValueError(f"Invalid split: '{split}'") | |
| else: | |
| idx_items = [(df.index, df.columns, df.values)] | |
| new_idx_items = [] | |
| for row_idxr, col_idxr, v in idx_items: | |
| if row_idxr is None: | |
| raise ValueError("Row indexer cannot be None") | |
| if col_idxr is None: | |
| raise ValueError("Column indexer cannot be None") | |
| if row_idxr is not None and not isinstance(row_idxr, RowIdxr): | |
| row_idxr = rowidx(row_idxr, **rowidx_kwargs) | |
| if col_idxr is not None and not isinstance(col_idxr, ColIdxr): | |
| col_idxr = colidx(col_idxr, **colidx_kwargs) | |
| new_idx_items.append((idx(row_idxr, col_idxr), v)) | |
| return IdxSetter(new_idx_items) | |
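| # Sketch of `IdxFrame` and its `split` options (frame contents illustrative): | |
| # >>> import pandas as pd | |
| # >>> df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}, index=["2020-01-01", "2020-01-02"]) | |
| # >>> IdxFrame(df).get()                   # single rows-by-columns operation | |
| # >>> IdxFrame(df, split="columns").get()  # one operation per column | |
| # >>> IdxFrame(df, split=True).get()       # one operation per element | |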
| @define | |
| class IdxRecords(IdxSetterFactory, DefineMixin): | |
| """Class for building index setters from records - one per field.""" | |
| records: tp.RecordsLike = define.field() | |
| """Series, DataFrame, or any sequence of mapping-like objects. | |
| If a Series or DataFrame and the index is not a default range, the index will become a row field. | |
| If a custom row field is provided, the index will be ignored.""" | |
| row_field: tp.Union[None, bool, tp.Label] = define.field(default=None) | |
| """Row field. | |
| If None or True, will search for "row", "index", "open time", and "date" (case-insensitive). | |
| If `IdxRecords.records` is a Series or DataFrame, will also include the index name | |
| if the index is not a default range. | |
| If a record doesn't have a row field, all rows will be set. | |
| If there's neither a row nor a column field, the field value will become the default of the entire array.""" | |
| col_field: tp.Union[None, bool, tp.Label] = define.field(default=None) | |
| """Column field. | |
| If None or True, will search for "col", "column", and "symbol" (case-insensitive). | |
| If a record doesn't have a column field, all columns will be set. | |
| If there's neither a row nor a column field, the field value will become the default of the entire array.""" | |
| rowidx_kwargs: tp.KwargsLike = define.field(default=None) | |
| """Keyword arguments passed to `rowidx` if the indexer isn't an instance of `RowIdxr`.""" | |
| colidx_kwargs: tp.KwargsLike = define.field(default=None) | |
| """Keyword arguments passed to `colidx` if the indexer isn't an instance of `ColIdxr`.""" | |
| def get(self) -> tp.Union[IdxSetter, tp.Dict[tp.Label, IdxSetter]]: | |
| records = self.records | |
| row_field = self.row_field | |
| col_field = self.col_field | |
| rowidx_kwargs = self.rowidx_kwargs | |
| colidx_kwargs = self.colidx_kwargs | |
| if rowidx_kwargs is None: | |
| rowidx_kwargs = {} | |
| if colidx_kwargs is None: | |
| colidx_kwargs = {} | |
| default_index = False | |
| index_field = None | |
| if isinstance(records, pd.Series): | |
| records = records.to_frame() | |
| if isinstance(records, pd.DataFrame): | |
| records = records | |
| if checks.is_default_index(records.index): | |
| default_index = True | |
| records = records.reset_index(drop=default_index) | |
| if not default_index: | |
| index_field = records.columns[0] | |
| records = records.itertuples(index=False) | |
| def _resolve_field_meta(fields): | |
| _row_field = row_field | |
| _row_kind = None | |
| _col_field = col_field | |
| _col_kind = None | |
| row_fields = set() | |
| col_fields = set() | |
| for field in fields: | |
| if isinstance(field, str) and index_field is not None and field == index_field: | |
| row_fields.add((field, None)) | |
| if isinstance(field, str) and field.lower() in ("row", "index"): | |
| row_fields.add((field, None)) | |
| if isinstance(field, str) and field.lower() in ("open time", "date", "datetime"): | |
| if (field, None) in row_fields: | |
| row_fields.remove((field, None)) | |
| row_fields.add((field, "datetime")) | |
| if isinstance(field, str) and field.lower() in ("col", "column"): | |
| col_fields.add((field, None)) | |
| if isinstance(field, str) and field.lower() == "symbol": | |
| if (field, None) in col_fields: | |
| col_fields.remove((field, None)) | |
| col_fields.add((field, "labels")) | |
| if _row_field in (None, True): | |
| if len(row_fields) == 0: | |
| if _row_field is True: | |
| raise ValueError("Cannot find row field") | |
| _row_field = None | |
| elif len(row_fields) == 1: | |
| _row_field, _row_kind = row_fields.pop() | |
| else: | |
| raise ValueError("Multiple row field candidates") | |
| elif _row_field is False: | |
| _row_field = None | |
| if _col_field in (None, True): | |
| if len(col_fields) == 0: | |
| if _col_field is True: | |
| raise ValueError("Cannot find column field") | |
| _col_field = None | |
| elif len(col_fields) == 1: | |
| _col_field, _col_kind = col_fields.pop() | |
| else: | |
| raise ValueError("Multiple column field candidates") | |
| elif _col_field is False: | |
| _col_field = None | |
| field_meta = dict() | |
| field_meta["row_field"] = _row_field | |
| field_meta["row_kind"] = _row_kind | |
| field_meta["col_field"] = _col_field | |
| field_meta["col_kind"] = _col_kind | |
| return field_meta | |
| idx_items = dict() | |
| for r in records: | |
| r = to_field_mapping(r) | |
| field_meta = _resolve_field_meta(r.keys()) | |
| if field_meta["row_field"] is None: | |
| row_idxr = None | |
| else: | |
| row_idxr = r.get(field_meta["row_field"], None) | |
| if row_idxr == "_def": | |
| row_idxr = None | |
| if row_idxr is not None and not isinstance(row_idxr, RowIdxr): | |
| _rowidx_kwargs = dict(rowidx_kwargs) | |
| if field_meta["row_kind"] is not None and "kind" not in _rowidx_kwargs: | |
| _rowidx_kwargs["kind"] = field_meta["row_kind"] | |
| row_idxr = rowidx(row_idxr, **_rowidx_kwargs) | |
| if field_meta["col_field"] is None: | |
| col_idxr = None | |
| else: | |
| col_idxr = r.get(field_meta["col_field"], None) | |
| if col_idxr is not None and not isinstance(col_idxr, ColIdxr): | |
| _colidx_kwargs = dict(colidx_kwargs) | |
| if field_meta["col_kind"] is not None and "kind" not in _colidx_kwargs: | |
| _colidx_kwargs["kind"] = field_meta["col_kind"] | |
| col_idxr = colidx(col_idxr, **_colidx_kwargs) | |
| if isinstance(col_idxr, str) and col_idxr == "_def": | |
| col_idxr = None | |
| item_produced = False | |
| for k, v in r.items(): | |
| if index_field is not None and k == index_field: | |
| continue | |
| if field_meta["row_field"] is not None and k == field_meta["row_field"]: | |
| continue | |
| if field_meta["col_field"] is not None and k == field_meta["col_field"]: | |
| continue | |
| if k not in idx_items: | |
| idx_items[k] = [] | |
| if row_idxr is None and col_idxr is None: | |
| idx_items[k].append(("_def", v)) | |
| else: | |
| idx_items[k].append((idx(row_idxr, col_idxr), v)) | |
| item_produced = True | |
| if not item_produced: | |
| raise ValueError(f"Record {r} has no fields to set") | |
| idx_setters = dict() | |
| for k, v in idx_items.items(): | |
| idx_setters[k] = IdxSetter(v) | |
| return idx_setters | |
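| # Sketch of `IdxRecords` on mapping-like records; the field names "date", "symbol", | |
| # and "size" are illustrative, and a record without row/column fields sets the default: | |
| # >>> records = [ | |
| # ...     dict(date="2020-01-01", symbol="X", size=1.0), | |
| # ...     dict(size=0.0), | |
| # ... ] | |
| # >>> IdxRecords(records).get()  # -> {"size": IdxSetter([...])} | |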
| posidx = PosIdxr | |
| """Shortcut for `PosIdxr`.""" | |
| __pdoc__["posidx"] = False | |
| maskidx = MaskIdxr | |
| """Shortcut for `MaskIdxr`.""" | |
| __pdoc__["maskidx"] = False | |
| lbidx = LabelIdxr | |
| """Shortcut for `LabelIdxr`.""" | |
| __pdoc__["lbidx"] = False | |
| dtidx = DatetimeIdxr | |
| """Shortcut for `DatetimeIdxr`.""" | |
| __pdoc__["dtidx"] = False | |
| dtcidx = DTCIdxr | |
| """Shortcut for `DTCIdxr`.""" | |
| __pdoc__["dtcidx"] = False | |
| pointidx = PointIdxr | |
| """Shortcut for `PointIdxr`.""" | |
| __pdoc__["pointidx"] = False | |
| rangeidx = RangeIdxr | |
| """Shortcut for `RangeIdxr`.""" | |
| __pdoc__["rangeidx"] = False | |
| autoidx = AutoIdxr | |
| """Shortcut for `AutoIdxr`.""" | |
| __pdoc__["autoidx"] = False | |
| rowidx = RowIdxr | |
| """Shortcut for `RowIdxr`.""" | |
| __pdoc__["rowidx"] = False | |
| colidx = ColIdxr | |
| """Shortcut for `ColIdxr`.""" | |
| __pdoc__["colidx"] = False | |
| idx = Idxr | |
| """Shortcut for `Idxr`.""" | |
| __pdoc__["idx"] = False | |
| </file> | |
| <file path="base/merging.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Functions for merging arrays.""" | |
| from functools import partial | |
| import numpy as np | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.base.indexes import stack_indexes, concat_indexes, clean_index | |
| from vectorbtpro.base.reshaping import to_1d_array, to_2d_array | |
| from vectorbtpro.base.wrapping import ArrayWrapper, Wrapping | |
| from vectorbtpro.utils import checks | |
| from vectorbtpro.utils.config import resolve_dict, merge_dicts, HybridConfig | |
| from vectorbtpro.utils.execution import NoResult, NoResultsException, filter_out_no_results | |
| from vectorbtpro.utils.merging import MergeFunc | |
| __all__ = [ | |
| "concat_arrays", | |
| "row_stack_arrays", | |
| "column_stack_arrays", | |
| "concat_merge", | |
| "row_stack_merge", | |
| "column_stack_merge", | |
| "imageio_merge", | |
| "mixed_merge", | |
| ] | |
| __pdoc__ = {} | |
| def concat_arrays(*arrs: tp.MaybeSequence[tp.AnyArray]) -> tp.Array1d: | |
| """Concatenate arrays.""" | |
| if len(arrs) == 1: | |
| arrs = arrs[0] | |
| arrs = list(arrs) | |
| arrs = list(map(to_1d_array, arrs)) | |
| return np.concatenate(arrs) | |
| def row_stack_arrays(*arrs: tp.MaybeSequence[tp.AnyArray], expand_axis: int = 1) -> tp.Array2d: | |
| """Stack arrays along rows.""" | |
| if len(arrs) == 1: | |
| arrs = arrs[0] | |
| arrs = list(arrs) | |
| arrs = list(map(partial(to_2d_array, expand_axis=expand_axis), arrs)) | |
| return np.concatenate(arrs, axis=0) | |
| def column_stack_arrays(*arrs: tp.MaybeSequence[tp.AnyArray], expand_axis: int = 1) -> tp.Array2d: | |
| """Stack arrays along columns.""" | |
| if len(arrs) == 1: | |
| arrs = arrs[0] | |
| arrs = list(arrs) | |
| arrs = list(map(partial(to_2d_array, expand_axis=expand_axis), arrs)) | |
| common_shape = None | |
| can_concatenate = True | |
| for arr in arrs: | |
| if common_shape is None: | |
| common_shape = arr.shape | |
| if arr.shape != common_shape: | |
| can_concatenate = False | |
| continue | |
| if not (arr.ndim == 1 or (arr.ndim == 2 and arr.shape[1] == 1)): | |
| can_concatenate = False | |
| continue | |
| if can_concatenate: | |
| return np.concatenate(arrs, axis=0).reshape((len(arrs), common_shape[0])).T | |
| return np.concatenate(arrs, axis=1) | |
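| # A quick, illustrative sketch of the three array helpers above; note that 1-d inputs | |
| # are treated as columns, so `row_stack_arrays` is shown with 2-d inputs: | |
| # >>> import numpy as np | |
| # >>> concat_arrays(np.array([1, 2]), np.array([3]))                   # -> array([1, 2, 3]) | |
| # >>> row_stack_arrays(np.array([[1, 2]]), np.array([[3, 4]]))         # -> 2x2, rows stacked | |
| # >>> column_stack_arrays(np.array([1.0, 2.0]), np.array([3.0, 4.0]))  # -> 2x2, columns side by side | |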
| def concat_merge( | |
| *objs, | |
| keys: tp.Optional[tp.Index] = None, | |
| filter_results: bool = True, | |
| raise_no_results: bool = True, | |
| wrap: tp.Optional[bool] = None, | |
| wrapper: tp.Optional[ArrayWrapper] = None, | |
| wrap_kwargs: tp.KwargsLikeSequence = None, | |
| clean_index_kwargs: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> tp.MaybeTuple[tp.AnyArray]: | |
| """Merge multiple array-like objects through concatenation. | |
| Supports a sequence of tuples. | |
| If `wrap` is None, it will become True if the objects are already Pandas objects, or if `wrapper`, `keys`, or `wrap_kwargs` are provided. | |
| If `wrap` is True, each array will be wrapped with Pandas Series and merged using `pd.concat`. | |
| Otherwise, arrays will be kept as-is and merged using `concat_arrays`. | |
| `wrap_kwargs` can be a dictionary or a list of dictionaries. | |
| If `wrapper` is provided, will use `vectorbtpro.base.wrapping.ArrayWrapper.wrap_reduced`. | |
| Keyword arguments `**kwargs` are passed to `pd.concat` only. | |
| !!! note | |
| All arrays are assumed to have the same type and dimensionality.""" | |
| if len(objs) == 1: | |
| objs = objs[0] | |
| objs = list(objs) | |
| if len(objs) == 0: | |
| raise ValueError("No objects to be merged") | |
| if isinstance(objs[0], tuple): | |
| if len(objs[0]) == 1: | |
| out_tuple = ( | |
| concat_merge( | |
| list(map(lambda x: x[0], objs)), | |
| keys=keys, | |
| wrap=wrap, | |
| wrapper=wrapper, | |
| wrap_kwargs=wrap_kwargs, | |
| **kwargs, | |
| ), | |
| ) | |
| else: | |
| out_tuple = tuple( | |
| map( | |
| lambda x: concat_merge( | |
| x, | |
| keys=keys, | |
| wrap=wrap, | |
| wrapper=wrapper, | |
| wrap_kwargs=wrap_kwargs, | |
| **kwargs, | |
| ), | |
| zip(*objs), | |
| ) | |
| ) | |
| if checks.is_namedtuple(objs[0]): | |
| return type(objs[0])(*out_tuple) | |
| return type(objs[0])(out_tuple) | |
| if filter_results: | |
| try: | |
| objs, keys = filter_out_no_results(objs, keys=keys) | |
| except NoResultsException as e: | |
| if raise_no_results: | |
| raise e | |
| return NoResult | |
| if isinstance(objs[0], Wrapping): | |
| raise TypeError("Concatenating Wrapping instances is not supported") | |
| if wrap_kwargs is None: | |
| wrap_kwargs = {} | |
| if wrap is None: | |
| wrap = isinstance(objs[0], pd.Series) or wrapper is not None or keys is not None or len(wrap_kwargs) > 0 | |
| if not checks.is_complex_iterable(objs[0]): | |
| if wrap: | |
| if keys is not None and isinstance(keys[0], pd.Index): | |
| if len(keys) == 1: | |
| keys = keys[0] | |
| else: | |
| keys = concat_indexes( | |
| *keys, | |
| index_concat_method="append", | |
| clean_index_kwargs=clean_index_kwargs, | |
| verify_integrity=False, | |
| axis=0, | |
| ) | |
| wrap_kwargs = merge_dicts(dict(index=keys), wrap_kwargs) | |
| return pd.Series(objs, **wrap_kwargs) | |
| return np.asarray(objs) | |
| if isinstance(objs[0], pd.Index): | |
| objs = list(map(lambda x: x.to_series(), objs)) | |
| default_index = True | |
| if not isinstance(objs[0], pd.Series): | |
| if isinstance(objs[0], pd.DataFrame): | |
| raise ValueError("Use row stacking for concatenating DataFrames") | |
| if wrap: | |
| new_objs = [] | |
| for i, obj in enumerate(objs): | |
| _wrap_kwargs = resolve_dict(wrap_kwargs, i) | |
| if wrapper is not None: | |
| if "force_1d" not in _wrap_kwargs: | |
| _wrap_kwargs["force_1d"] = True | |
| new_objs.append(wrapper.wrap_reduced(obj, **_wrap_kwargs)) | |
| else: | |
| new_objs.append(pd.Series(obj, **_wrap_kwargs)) | |
| if default_index and not checks.is_default_index(new_objs[-1].index, check_names=True): | |
| default_index = False | |
| objs = new_objs | |
| if not wrap: | |
| return concat_arrays(objs) | |
| if keys is not None and isinstance(keys[0], pd.Index): | |
| new_obj = pd.concat(objs, axis=0, **kwargs) | |
| if len(keys) == 1: | |
| keys = keys[0] | |
| else: | |
| keys = concat_indexes( | |
| *keys, | |
| index_concat_method="append", | |
| verify_integrity=False, | |
| axis=0, | |
| ) | |
| if default_index: | |
| new_obj.index = keys | |
| else: | |
| new_obj.index = stack_indexes((keys, new_obj.index)) | |
| else: | |
| new_obj = pd.concat(objs, axis=0, keys=keys, **kwargs) | |
| if clean_index_kwargs is None: | |
| clean_index_kwargs = {} | |
| new_obj.index = clean_index(new_obj.index, **clean_index_kwargs) | |
| return new_obj | |
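| # Sketch of `concat_merge` (values illustrative): without wrapping, plain arrays are | |
| # concatenated; with `keys`, each object is wrapped and an outer index level is added: | |
| # >>> import numpy as np, pandas as pd | |
| # >>> concat_merge([np.array([1, 2]), np.array([3, 4])])  # -> array([1, 2, 3, 4]) | |
| # >>> concat_merge([np.array([1, 2]), np.array([3, 4])], keys=pd.Index(["x", "y"], name="key")) | |
| # # -> Series with an outer level taken from `keys` stacked on the element positions | |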
| def row_stack_merge( | |
| *objs, | |
| keys: tp.Optional[tp.Index] = None, | |
| filter_results: bool = True, | |
| raise_no_results: bool = True, | |
| wrap: tp.Union[None, str, bool] = None, | |
| wrapper: tp.Optional[ArrayWrapper] = None, | |
| wrap_kwargs: tp.KwargsLikeSequence = None, | |
| clean_index_kwargs: tp.KwargsLikeSequence = None, | |
| **kwargs, | |
| ) -> tp.MaybeTuple[tp.AnyArray]: | |
| """Merge multiple array-like or `vectorbtpro.base.wrapping.Wrapping` objects through row stacking. | |
| Supports a sequence of tuples. | |
| Argument `wrap` supports the following options: | |
| * None: will become True if the objects are already Pandas objects, or if `wrapper`, `keys`, or `wrap_kwargs` are provided | |
| * True: each array will be wrapped with Pandas Series/DataFrame (depending on dimensions) | |
| * 'sr', 'series': each array will be wrapped with Pandas Series | |
| * 'df', 'frame', 'dataframe': each array will be wrapped with Pandas DataFrame | |
| Without wrapping, arrays will be kept as-is and merged using `row_stack_arrays`. | |
| Argument `wrap_kwargs` can be a dictionary or a list of dictionaries. | |
| If `wrapper` is provided, will use `vectorbtpro.base.wrapping.ArrayWrapper.wrap`. | |
| Keyword arguments `**kwargs` are passed to `pd.concat` and | |
| `vectorbtpro.base.wrapping.Wrapping.row_stack` only. | |
| !!! note | |
| All arrays are assumed to have the same type and dimensionality.""" | |
| if len(objs) == 1: | |
| objs = objs[0] | |
| objs = list(objs) | |
| if len(objs) == 0: | |
| raise ValueError("No objects to be merged") | |
| if isinstance(objs[0], tuple): | |
| if len(objs[0]) == 1: | |
| out_tuple = ( | |
| row_stack_merge( | |
| list(map(lambda x: x[0], objs)), | |
| keys=keys, | |
| wrap=wrap, | |
| wrapper=wrapper, | |
| wrap_kwargs=wrap_kwargs, | |
| **kwargs, | |
| ), | |
| ) | |
| else: | |
| out_tuple = tuple( | |
| map( | |
| lambda x: row_stack_merge( | |
| x, | |
| keys=keys, | |
| wrap=wrap, | |
| wrapper=wrapper, | |
| wrap_kwargs=wrap_kwargs, | |
| **kwargs, | |
| ), | |
| zip(*objs), | |
| ) | |
| ) | |
| if checks.is_namedtuple(objs[0]): | |
| return type(objs[0])(*out_tuple) | |
| return type(objs[0])(out_tuple) | |
| if filter_results: | |
| try: | |
| objs, keys = filter_out_no_results(objs, keys=keys) | |
| except NoResultsException as e: | |
| if raise_no_results: | |
| raise e | |
| return NoResult | |
| if isinstance(objs[0], Wrapping): | |
| kwargs = merge_dicts(dict(wrapper_kwargs=dict(keys=keys)), kwargs) | |
| return type(objs[0]).row_stack(objs, **kwargs) | |
| if wrap_kwargs is None: | |
| wrap_kwargs = {} | |
| if wrap is None: | |
| wrap = ( | |
| isinstance(objs[0], (pd.Series, pd.DataFrame)) | |
| or wrapper is not None | |
| or keys is not None | |
| or len(wrap_kwargs) > 0 | |
| ) | |
| if isinstance(objs[0], pd.Index): | |
| objs = list(map(lambda x: x.to_series(), objs)) | |
| default_index = True | |
| if not isinstance(objs[0], (pd.Series, pd.DataFrame)): | |
| if isinstance(wrap, str) or wrap: | |
| new_objs = [] | |
| for i, obj in enumerate(objs): | |
| _wrap_kwargs = resolve_dict(wrap_kwargs, i) | |
| if wrapper is not None: | |
| new_objs.append(wrapper.wrap(obj, **_wrap_kwargs)) | |
| else: | |
| if not isinstance(wrap, str): | |
| if isinstance(obj, np.ndarray): | |
| ndim = obj.ndim | |
| else: | |
| ndim = np.asarray(obj).ndim | |
| if ndim == 1: | |
| wrap = "series" | |
| else: | |
| wrap = "frame" | |
| if isinstance(wrap, str): | |
| if wrap.lower() in ("sr", "series"): | |
| new_objs.append(pd.Series(obj, **_wrap_kwargs)) | |
| elif wrap.lower() in ("df", "frame", "dataframe"): | |
| new_objs.append(pd.DataFrame(obj, **_wrap_kwargs)) | |
| else: | |
| raise ValueError(f"Invalid wrapping option: '{wrap}'") | |
| if default_index and not checks.is_default_index(new_objs[-1].index, check_names=True): | |
| default_index = False | |
| objs = new_objs | |
| if not wrap: | |
| return row_stack_arrays(objs) | |
| if keys is not None and isinstance(keys[0], pd.Index): | |
| new_obj = pd.concat(objs, axis=0, **kwargs) | |
| if len(keys) == 1: | |
| keys = keys[0] | |
| else: | |
| keys = concat_indexes( | |
| *keys, | |
| index_concat_method="append", | |
| verify_integrity=False, | |
| axis=0, | |
| ) | |
| if default_index: | |
| new_obj.index = keys | |
| else: | |
| new_obj.index = stack_indexes((keys, new_obj.index)) | |
| else: | |
| new_obj = pd.concat(objs, axis=0, keys=keys, **kwargs) | |
| if clean_index_kwargs is None: | |
| clean_index_kwargs = {} | |
| new_obj.index = clean_index(new_obj.index, **clean_index_kwargs) | |
| return new_obj | |
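| # Sketch of `row_stack_merge` on DataFrames (contents illustrative): rows are appended | |
| # and `keys` adds an outer index level: | |
| # >>> import pandas as pd | |
| # >>> df1 = pd.DataFrame({"a": [1, 2]}) | |
| # >>> df2 = pd.DataFrame({"a": [3, 4]}) | |
| # >>> row_stack_merge([df1, df2], keys=pd.Index(["x", "y"], name="split")) | |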
| def column_stack_merge( | |
| *objs, | |
| reset_index: tp.Union[None, bool, str] = None, | |
| fill_value: tp.Scalar = np.nan, | |
| keys: tp.Optional[tp.Index] = None, | |
| filter_results: bool = True, | |
| raise_no_results: bool = True, | |
| wrap: tp.Union[None, str, bool] = None, | |
| wrapper: tp.Optional[ArrayWrapper] = None, | |
| wrap_kwargs: tp.KwargsLikeSequence = None, | |
| clean_index_kwargs: tp.KwargsLikeSequence = None, | |
| **kwargs, | |
| ) -> tp.MaybeTuple[tp.AnyArray]: | |
| """Merge multiple array-like or `vectorbtpro.base.wrapping.Wrapping` objects through column stacking. | |
| Supports a sequence of tuples. | |
| Argument `wrap` supports the following options: | |
| * None: will become True if the objects are already Pandas objects, or if `wrapper`, `keys`, or `wrap_kwargs` are provided | |
| * True: each array will be wrapped with Pandas Series/DataFrame (depending on dimensions) | |
| * 'sr', 'series': each array will be wrapped with Pandas Series | |
| * 'df', 'frame', 'dataframe': each array will be wrapped with Pandas DataFrame | |
| Without wrapping, arrays will be kept as-is and merged using `column_stack_arrays`. | |
| Argument `wrap_kwargs` can be a dictionary or a list of dictionaries. | |
| If `wrapper` is provided, will use `vectorbtpro.base.wrapping.ArrayWrapper.wrap`. | |
| Keyword arguments `**kwargs` are passed to `pd.concat` and | |
| `vectorbtpro.base.wrapping.Wrapping.column_stack` only. | |
| Argument `reset_index` supports the following options: | |
| * False or None: Keep original index of each object | |
| * True or 'from_start': Reset index of each object and align them at start | |
| * 'from_end': Reset index of each object and align them at end | |
| Options above work on Pandas, NumPy, and `vectorbtpro.base.wrapping.Wrapping` instances. | |
| !!! note | |
| All arrays are assumed to have the same type and dimensionality.""" | |
| if len(objs) == 1: | |
| objs = objs[0] | |
| objs = list(objs) | |
| if len(objs) == 0: | |
| raise ValueError("No objects to be merged") | |
| if isinstance(reset_index, bool): | |
| if reset_index: | |
| reset_index = "from_start" | |
| else: | |
| reset_index = None | |
| if isinstance(objs[0], tuple): | |
| if len(objs[0]) == 1: | |
| out_tuple = ( | |
| column_stack_merge( | |
| list(map(lambda x: x[0], objs)), | |
| reset_index=reset_index, | |
| keys=keys, | |
| wrap=wrap, | |
| wrapper=wrapper, | |
| wrap_kwargs=wrap_kwargs, | |
| **kwargs, | |
| ), | |
| ) | |
| else: | |
| out_tuple = tuple( | |
| map( | |
| lambda x: column_stack_merge( | |
| x, | |
| reset_index=reset_index, | |
| keys=keys, | |
| wrap=wrap, | |
| wrapper=wrapper, | |
| wrap_kwargs=wrap_kwargs, | |
| **kwargs, | |
| ), | |
| zip(*objs), | |
| ) | |
| ) | |
| if checks.is_namedtuple(objs[0]): | |
| return type(objs[0])(*out_tuple) | |
| return type(objs[0])(out_tuple) | |
| if filter_results: | |
| try: | |
| objs, keys = filter_out_no_results(objs, keys=keys) | |
| except NoResultsException as e: | |
| if raise_no_results: | |
| raise e | |
| return NoResult | |
| if isinstance(objs[0], Wrapping): | |
| if reset_index is not None: | |
| max_length = max(map(lambda x: x.wrapper.shape[0], objs)) | |
| new_objs = [] | |
| for obj in objs: | |
| if isinstance(reset_index, str) and reset_index.lower() == "from_start": | |
| new_index = pd.RangeIndex(stop=obj.wrapper.shape[0]) | |
| new_obj = obj.replace(wrapper=obj.wrapper.replace(index=new_index)) | |
| elif isinstance(reset_index, str) and reset_index.lower() == "from_end": | |
| new_index = pd.RangeIndex(start=max_length - obj.wrapper.shape[0], stop=max_length) | |
| new_obj = obj.replace(wrapper=obj.wrapper.replace(index=new_index)) | |
| else: | |
| raise ValueError(f"Invalid index resetting option: '{reset_index}'") | |
| new_objs.append(new_obj) | |
| objs = new_objs | |
| kwargs = merge_dicts(dict(wrapper_kwargs=dict(keys=keys)), kwargs) | |
| return type(objs[0]).column_stack(objs, **kwargs) | |
| if wrap_kwargs is None: | |
| wrap_kwargs = {} | |
| if wrap is None: | |
| wrap = ( | |
| isinstance(objs[0], (pd.Series, pd.DataFrame)) | |
| or wrapper is not None | |
| or keys is not None | |
| or len(wrap_kwargs) > 0 | |
| ) | |
| if isinstance(objs[0], pd.Index): | |
| objs = list(map(lambda x: x.to_series(), objs)) | |
| default_columns = True | |
| if not isinstance(objs[0], (pd.Series, pd.DataFrame)): | |
| if isinstance(wrap, str) or wrap: | |
| new_objs = [] | |
| for i, obj in enumerate(objs): | |
| _wrap_kwargs = resolve_dict(wrap_kwargs, i) | |
| if wrapper is not None: | |
| new_objs.append(wrapper.wrap(obj, **_wrap_kwargs)) | |
| else: | |
| if not isinstance(wrap, str): | |
| if isinstance(obj, np.ndarray): | |
| ndim = obj.ndim | |
| else: | |
| ndim = np.asarray(obj).ndim | |
| if ndim == 1: | |
| wrap = "series" | |
| else: | |
| wrap = "frame" | |
| if isinstance(wrap, str): | |
| if wrap.lower() in ("sr", "series"): | |
| new_objs.append(pd.Series(obj, **_wrap_kwargs)) | |
| elif wrap.lower() in ("df", "frame", "dataframe"): | |
| new_objs.append(pd.DataFrame(obj, **_wrap_kwargs)) | |
| else: | |
| raise ValueError(f"Invalid wrapping option: '{wrap}'") | |
| if ( | |
| default_columns | |
| and isinstance(new_objs[-1], pd.DataFrame) | |
| and not checks.is_default_index(new_objs[-1].columns, check_names=True) | |
| ): | |
| default_columns = False | |
| objs = new_objs | |
| if not wrap: | |
| if reset_index is not None: | |
| min_n_rows = None | |
| max_n_rows = None | |
| n_cols = 0 | |
| new_objs = [] | |
| for obj in objs: | |
| new_obj = to_2d_array(obj) | |
| new_objs.append(new_obj) | |
| if min_n_rows is None or new_obj.shape[0] < min_n_rows: | |
| min_n_rows = new_obj.shape[0] | |
| if max_n_rows is None or new_obj.shape[0] > max_n_rows: | |
| max_n_rows = new_obj.shape[0] | |
| n_cols += new_obj.shape[1] | |
| if min_n_rows == max_n_rows: | |
| return column_stack_arrays(new_objs) | |
| new_obj = np.full((max_n_rows, n_cols), fill_value) | |
| start_col = 0 | |
| for obj in new_objs: | |
| end_col = start_col + obj.shape[1] | |
| if isinstance(reset_index, str) and reset_index.lower() == "from_start": | |
| new_obj[: len(obj), start_col:end_col] = obj | |
| elif isinstance(reset_index, str) and reset_index.lower() == "from_end": | |
| new_obj[-len(obj) :, start_col:end_col] = obj | |
| else: | |
| raise ValueError(f"Invalid index resetting option: '{reset_index}'") | |
| start_col = end_col | |
| return new_obj | |
| return column_stack_arrays(objs) | |
| if reset_index is not None: | |
| max_length = max(map(len, objs)) | |
| new_objs = [] | |
| for obj in objs: | |
| new_obj = obj.copy(deep=False) | |
| if isinstance(reset_index, str) and reset_index.lower() == "from_start": | |
| new_obj.index = pd.RangeIndex(stop=len(new_obj)) | |
| elif isinstance(reset_index, str) and reset_index.lower() == "from_end": | |
| new_obj.index = pd.RangeIndex(start=max_length - len(new_obj), stop=max_length) | |
| else: | |
| raise ValueError(f"Invalid index resetting option: '{reset_index}'") | |
| new_objs.append(new_obj) | |
| objs = new_objs | |
| kwargs = merge_dicts(dict(sort=True), kwargs) | |
| if keys is not None and isinstance(keys[0], pd.Index): | |
| new_obj = pd.concat(objs, axis=1, **kwargs) | |
| if len(keys) == 1: | |
| keys = keys[0] | |
| else: | |
| keys = concat_indexes( | |
| *keys, | |
| index_concat_method="append", | |
| verify_integrity=False, | |
| axis=1, | |
| ) | |
| if default_columns: | |
| new_obj.columns = keys | |
| else: | |
| new_obj.columns = stack_indexes((keys, new_obj.columns)) | |
| else: | |
| new_obj = pd.concat(objs, axis=1, keys=keys, **kwargs) | |
| if clean_index_kwargs is None: | |
| clean_index_kwargs = {} | |
| new_obj.columns = clean_index(new_obj.columns, **clean_index_kwargs) | |
| return new_obj | |
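| # Sketch of `column_stack_merge` with index resetting (values illustrative): arrays of | |
| # different lengths are placed side by side and padded with `fill_value` (NaN by default): | |
| # >>> import numpy as np | |
| # >>> a, b = np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0]) | |
| # >>> column_stack_merge([a, b], reset_index="from_start")  # b padded at the end | |
| # >>> column_stack_merge([a, b], reset_index="from_end")    # b padded at the start | |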
| def imageio_merge( | |
| *objs, | |
| keys: tp.Optional[tp.Index] = None, | |
| filter_results: bool = True, | |
| raise_no_results: bool = True, | |
| to_image_kwargs: tp.KwargsLike = None, | |
| imread_kwargs: tp.KwargsLike = None, | |
| **imwrite_kwargs, | |
| ) -> tp.MaybeTuple[tp.Union[None, bytes]]: | |
| """Merge multiple figure-like objects by writing them with `imageio`. | |
| Keyword arguments `to_image_kwargs` are passed to `plotly.graph_objects.Figure.to_image`. | |
| Keyword arguments `imread_kwargs` and `**imwrite_kwargs` are passed to | |
| `imageio.imread` and `imageio.imwrite` respectively. | |
| Keys are not used in any way.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("plotly") | |
| import plotly.graph_objects as go | |
| import imageio.v3 as iio | |
| if len(objs) == 1: | |
| objs = objs[0] | |
| objs = list(objs) | |
| if len(objs) == 0: | |
| raise ValueError("No objects to be merged") | |
| if isinstance(objs[0], tuple): | |
| if len(objs[0]) == 1: | |
| out_tuple = ( | |
| imageio_merge( | |
| list(map(lambda x: x[0], objs)), | |
| keys=keys, | |
| imread_kwargs=imread_kwargs, | |
| to_image_kwargs=to_image_kwargs, | |
| **imwrite_kwargs, | |
| ), | |
| ) | |
| else: | |
| out_tuple = tuple( | |
| map( | |
| lambda x: imageio_merge( | |
| x, | |
| keys=keys, | |
| imread_kwargs=imread_kwargs, | |
| to_image_kwargs=to_image_kwargs, | |
| **imwrite_kwargs, | |
| ), | |
| zip(*objs), | |
| ) | |
| ) | |
| if checks.is_namedtuple(objs[0]): | |
| return type(objs[0])(*out_tuple) | |
| return type(objs[0])(out_tuple) | |
| if filter_results: | |
| try: | |
| objs, keys = filter_out_no_results(objs, keys=keys) | |
| except NoResultsException as e: | |
| if raise_no_results: | |
| raise e | |
| return NoResult | |
| if imread_kwargs is None: | |
| imread_kwargs = {} | |
| if to_image_kwargs is None: | |
| to_image_kwargs = {} | |
| frames = [] | |
| for obj in objs: | |
| if obj is None: | |
| continue | |
| if isinstance(obj, (go.Figure, go.FigureWidget)): | |
| obj = obj.to_image(**to_image_kwargs) | |
| if not isinstance(obj, np.ndarray): | |
| obj = iio.imread(obj, **imread_kwargs) | |
| frames.append(obj) | |
| return iio.imwrite(image=frames, **imwrite_kwargs) | |
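| # Hedged sketch of writing a GIF with `imageio_merge`; the file name and the `duration` | |
| # option are assumptions for illustration, the remaining keyword arguments are simply | |
| # forwarded to `imageio.imwrite`: | |
| # >>> figs = [...]  # Plotly figures produced elsewhere | |
| # >>> imageio_merge(figs, uri="merged.gif", extension=".gif", duration=500) | |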
| def mixed_merge( | |
| *objs, | |
| merge_funcs: tp.Optional[tp.MergeFuncLike] = None, | |
| mixed_kwargs: tp.Optional[tp.Sequence[tp.KwargsLike]] = None, | |
| **kwargs, | |
| ) -> tp.MaybeTuple[tp.AnyArray]: | |
| """Merge objects of mixed types.""" | |
| if len(objs) == 1: | |
| objs = objs[0] | |
| objs = list(objs) | |
| if len(objs) == 0: | |
| raise ValueError("No objects to be merged") | |
| if merge_funcs is None: | |
| raise ValueError("Merging functions or their names are required") | |
| if not isinstance(objs[0], tuple): | |
| raise ValueError("Mixed merging must be applied on tuples") | |
| outputs = [] | |
| for i, output_objs in enumerate(zip(*objs)): | |
| output_objs = list(output_objs) | |
| merge_func = resolve_merge_func(merge_funcs[i]) | |
| if merge_func is None: | |
| outputs.append(output_objs) | |
| else: | |
| if mixed_kwargs is None: | |
| _kwargs = kwargs | |
| else: | |
| _kwargs = merge_dicts(kwargs, mixed_kwargs[i]) | |
| output = merge_func(output_objs, **_kwargs) | |
| outputs.append(output) | |
| return tuple(outputs) | |
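| # Sketch of `mixed_merge`: each position of the result tuples gets its own merging | |
| # function (arrays illustrative): | |
| # >>> import numpy as np | |
| # >>> objs = [(np.array([1, 2]), np.array([[1.0], [2.0]])), | |
| # ...         (np.array([3, 4]), np.array([[3.0], [4.0]]))] | |
| # >>> mixed_merge(objs, merge_funcs=("concat", "column_stack")) | |
| # # -> (array([1, 2, 3, 4]), a 2x2 column-stacked array) | |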
| merge_func_config = HybridConfig( | |
| dict( | |
| concat=concat_merge, | |
| row_stack=row_stack_merge, | |
| column_stack=column_stack_merge, | |
| reset_column_stack=partial(column_stack_merge, reset_index=True), | |
| from_start_column_stack=partial(column_stack_merge, reset_index="from_start"), | |
| from_end_column_stack=partial(column_stack_merge, reset_index="from_end"), | |
| imageio=imageio_merge, | |
| ) | |
| ) | |
| """_""" | |
| __pdoc__[ | |
| "merge_func_config" | |
| ] = f"""Config for merging functions. | |
| ```python | |
| {merge_func_config.prettify()} | |
| ``` | |
| """ | |
| def resolve_merge_func(merge_func: tp.MergeFuncLike) -> tp.Optional[tp.Callable]: | |
| """Resolve a merging function into a callable. | |
| If a string, looks up into `merge_func_config`. If a sequence, uses `mixed_merge` with | |
| `merge_funcs=merge_func`. If an instance of `vectorbtpro.utils.merging.MergeFunc`, calls | |
| `vectorbtpro.utils.merging.MergeFunc.resolve_merge_func` to get the actual callable.""" | |
| if merge_func is None: | |
| return None | |
| if isinstance(merge_func, str): | |
| if merge_func.lower() not in merge_func_config: | |
| raise ValueError(f"Invalid merging function name: '{merge_func}'") | |
| return merge_func_config[merge_func.lower()] | |
| if checks.is_sequence(merge_func): | |
| return partial(mixed_merge, merge_funcs=merge_func) | |
| if isinstance(merge_func, MergeFunc): | |
| return merge_func.resolve_merge_func() | |
| return merge_func | |
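| # Sketch: resolving merging functions by name through `merge_func_config`: | |
| # >>> resolve_merge_func("concat")                     # -> concat_merge | |
| # >>> resolve_merge_func(("concat", "column_stack"))   # -> partial of `mixed_merge` | |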
| def is_merge_func_from_config(merge_func: tp.MergeFuncLike) -> bool: | |
| """Return whether the merging function can be found in `merge_func_config`.""" | |
| if merge_func is None: | |
| return False | |
| if isinstance(merge_func, str): | |
| return merge_func.lower() in merge_func_config | |
| if checks.is_sequence(merge_func): | |
| return all(map(is_merge_func_from_config, merge_func)) | |
| if isinstance(merge_func, MergeFunc): | |
| return is_merge_func_from_config(merge_func.merge_func) | |
| return False | |
| </file> | |
| <file path="base/preparing.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Classes for preparing arguments.""" | |
| import inspect | |
| import string | |
| from collections import defaultdict | |
| from datetime import timedelta, time | |
| from functools import cached_property as cachedproperty | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.base.decorators import override_arg_config, attach_arg_properties | |
| from vectorbtpro.base.indexes import repeat_index | |
| from vectorbtpro.base.indexing import index_dict, IdxSetter, IdxSetterFactory, IdxRecords | |
| from vectorbtpro.base.merging import concat_arrays, column_stack_arrays | |
| from vectorbtpro.base.resampling.base import Resampler | |
| from vectorbtpro.base.reshaping import BCO, Default, Ref, broadcast | |
| from vectorbtpro.base.wrapping import ArrayWrapper | |
| from vectorbtpro.utils import checks, datetime_ as dt | |
| from vectorbtpro.utils.attr_ import get_dict_attr | |
| from vectorbtpro.utils.config import Configured | |
| from vectorbtpro.utils.config import merge_dicts, Config, ReadonlyConfig, HybridConfig | |
| from vectorbtpro.utils.cutting import suggest_module_path, cut_and_save_func | |
| from vectorbtpro.utils.enum_ import map_enum_fields | |
| from vectorbtpro.utils.module_ import import_module_from_path | |
| from vectorbtpro.utils.params import Param | |
| from vectorbtpro.utils.parsing import get_func_arg_names | |
| from vectorbtpro.utils.path_ import remove_dir | |
| from vectorbtpro.utils.random_ import set_seed | |
| from vectorbtpro.utils.template import CustomTemplate, RepFunc, substitute_templates | |
| __all__ = [ | |
| "BasePreparer", | |
| ] | |
| __pdoc__ = {} | |
| base_arg_config = ReadonlyConfig( | |
| dict( | |
| broadcast_named_args=dict(is_dict=True), | |
| broadcast_kwargs=dict(is_dict=True), | |
| template_context=dict(is_dict=True), | |
| seed=dict(), | |
| jitted=dict(), | |
| chunked=dict(), | |
| staticized=dict(), | |
| records=dict(), | |
| ) | |
| ) | |
| """_""" | |
| __pdoc__[ | |
| "base_arg_config" | |
| ] = f"""Argument config for `BasePreparer`. | |
| ```python | |
| {base_arg_config.prettify()} | |
| ``` | |
| """ | |
| class MetaBasePreparer(type(Configured)): | |
| """Metaclass for `BasePreparer`.""" | |
| @property | |
| def arg_config(cls) -> Config: | |
| """Argument config.""" | |
| return cls._arg_config | |
| @attach_arg_properties | |
| @override_arg_config(base_arg_config) | |
| class BasePreparer(Configured, metaclass=MetaBasePreparer): | |
| """Base class for preparing target functions and arguments. | |
| !!! warning | |
| Most properties are force-cached - create a new instance to override any attribute.""" | |
| _expected_keys_mode: tp.ExpectedKeysMode = "disable" | |
| _writeable_attrs: tp.WriteableAttrs = {"_arg_config"} | |
| _settings_path: tp.SettingsPath = None | |
| def __init__(self, arg_config: tp.KwargsLike = None, **kwargs) -> None: | |
| Configured.__init__(self, arg_config=arg_config, **kwargs) | |
| # Copy writeable attrs | |
| self._arg_config = type(self)._arg_config.copy() | |
| if arg_config is not None: | |
| self._arg_config = merge_dicts(self._arg_config, arg_config) | |
| _arg_config: tp.ClassVar[Config] = HybridConfig() | |
| @property | |
| def arg_config(self) -> Config: | |
| """Argument config of `${cls_name}`. | |
| ```python | |
| ${arg_config} | |
| ``` | |
| """ | |
| return self._arg_config | |
| @classmethod | |
| def map_enum_value(cls, value: tp.ArrayLike, look_for_type: tp.Optional[type] = None, **kwargs) -> tp.ArrayLike: | |
| """Map enumerated value(s).""" | |
| if look_for_type is not None: | |
| if isinstance(value, look_for_type): | |
| return map_enum_fields(value, **kwargs) | |
| return value | |
| if isinstance(value, (CustomTemplate, Ref)): | |
| return value | |
| if isinstance(value, (Param, BCO, Default)): | |
| attr_dct = value.asdict() | |
| if isinstance(value, Param) and attr_dct["map_template"] is None: | |
| attr_dct["map_template"] = RepFunc(lambda values: cls.map_enum_value(values, **kwargs)) | |
| elif not isinstance(value, Param): | |
| attr_dct["value"] = cls.map_enum_value(attr_dct["value"], **kwargs) | |
| return type(value)(**attr_dct) | |
| if isinstance(value, index_dict): | |
| return index_dict({k: cls.map_enum_value(v, **kwargs) for k, v in value.items()}) | |
| if isinstance(value, IdxSetterFactory): | |
| value = value.get() | |
| if not isinstance(value, IdxSetter): | |
| raise ValueError("Index setter factory must return exactly one index setter") | |
| if isinstance(value, IdxSetter): | |
| return IdxSetter([(k, cls.map_enum_value(v, **kwargs)) for k, v in value.idx_items]) | |
| return map_enum_fields(value, **kwargs) | |
| @classmethod | |
| def prepare_td_obj(cls, td_obj: object, old_as_keys: bool = True) -> object: | |
| """Prepare a timedelta object for broadcasting.""" | |
| if isinstance(td_obj, Param): | |
| return td_obj.map_value(cls.prepare_td_obj, old_as_keys=old_as_keys) | |
| if isinstance(td_obj, (str, timedelta, pd.DateOffset, pd.Timedelta)): | |
| td_obj = dt.to_timedelta64(td_obj) | |
| elif isinstance(td_obj, pd.Index): | |
| td_obj = td_obj.values | |
| return td_obj | |
| @classmethod | |
| def prepare_dt_obj( | |
| cls, | |
| dt_obj: object, | |
| old_as_keys: bool = True, | |
| last_before: tp.Optional[bool] = None, | |
| ) -> object: | |
| """Prepare a datetime object for broadcasting.""" | |
| if isinstance(dt_obj, Param): | |
| return dt_obj.map_value(cls.prepare_dt_obj, old_as_keys=old_as_keys) | |
| if isinstance(dt_obj, (str, time, timedelta, pd.DateOffset, pd.Timedelta)): | |
| def _apply_last_before(source_index, target_index, source_freq): | |
| resampler = Resampler(source_index, target_index, source_freq=source_freq) | |
| last_indices = resampler.last_before_target_index(incl_source=False) | |
| source_rbound_ns = resampler.source_rbound_index.vbt.to_ns() | |
| return np.where(last_indices != -1, source_rbound_ns[last_indices], -1) | |
| def _to_dt(wrapper, _dt_obj=dt_obj, _last_before=last_before): | |
| if _last_before is None: | |
| _last_before = False | |
| _dt_obj = dt.try_align_dt_to_index(_dt_obj, wrapper.index) | |
| source_index = wrapper.index[wrapper.index < _dt_obj] | |
| target_index = repeat_index(pd.Index([_dt_obj]), len(source_index)) | |
| if _last_before: | |
| target_ns = _apply_last_before(source_index, target_index, wrapper.freq) | |
| else: | |
| target_ns = target_index.vbt.to_ns() | |
| if len(target_ns) < len(wrapper.index): | |
| target_ns = concat_arrays((target_ns, np.full(len(wrapper.index) - len(target_ns), -1))) | |
| return target_ns | |
| def _to_td(wrapper, _dt_obj=dt_obj, _last_before=last_before): | |
| if _last_before is None: | |
| _last_before = True | |
| target_index = wrapper.index.vbt.to_period_ts(dt.to_freq(_dt_obj), shift=True) | |
| if _last_before: | |
| return _apply_last_before(wrapper.index, target_index, wrapper.freq) | |
| return target_index.vbt.to_ns() | |
| def _to_time(wrapper, _dt_obj=dt_obj, _last_before=last_before): | |
| if _last_before is None: | |
| _last_before = False | |
| floor_index = wrapper.index.floor("1d") + dt.time_to_timedelta(_dt_obj) | |
| target_index = floor_index.where(wrapper.index < floor_index, floor_index + pd.Timedelta(days=1)) | |
| if _last_before: | |
| return _apply_last_before(wrapper.index, target_index, wrapper.freq) | |
| return target_index.vbt.to_ns() | |
| dt_obj_dt_template = RepFunc(_to_dt) | |
| dt_obj_td_template = RepFunc(_to_td) | |
| dt_obj_time_template = RepFunc(_to_time) | |
| if isinstance(dt_obj, str): | |
| try: | |
| time.fromisoformat(dt_obj) | |
| dt_obj = dt_obj_time_template | |
| except Exception as e: | |
| try: | |
| dt.to_freq(dt_obj) | |
| dt_obj = dt_obj_td_template | |
| except Exception as e: | |
| dt_obj = dt_obj_dt_template | |
| elif isinstance(dt_obj, time): | |
| dt_obj = dt_obj_time_template | |
| else: | |
| dt_obj = dt_obj_td_template | |
| elif isinstance(dt_obj, pd.Index): | |
| dt_obj = dt_obj.values | |
| return dt_obj | |
| def get_raw_arg_default(self, arg_name: str, is_dict: bool = False) -> tp.Any: | |
| """Get raw argument default.""" | |
| if self._settings_path is None: | |
| if is_dict: | |
| return {} | |
| return None | |
| value = self.get_setting(arg_name) | |
| if is_dict and value is None: | |
| return {} | |
| return value | |
| def get_raw_arg(self, arg_name: str, is_dict: bool = False, has_default: bool = True) -> tp.Any: | |
| """Get raw argument.""" | |
| value = self.config.get(arg_name, None) | |
| if is_dict: | |
| if has_default: | |
| return merge_dicts(self.get_raw_arg_default(arg_name), value) | |
| if value is None: | |
| return {} | |
| return value | |
| if value is None and has_default: | |
| return self.get_raw_arg_default(arg_name) | |
| return value | |
| @cachedproperty | |
| def idx_setters(self) -> tp.Optional[tp.Dict[tp.Label, IdxSetter]]: | |
| """Index setters from resolving the argument `records`.""" | |
| arg_config = self.arg_config["records"] | |
| records = self.get_raw_arg( | |
| "records", | |
| is_dict=arg_config.get("is_dict", False), | |
| has_default=arg_config.get("has_default", True), | |
| ) | |
| if records is None: | |
| return None | |
| if not isinstance(records, IdxRecords): | |
| records = IdxRecords(records) | |
| idx_setters = records.get() | |
| for k in idx_setters: | |
| if k in self.arg_config and not self.arg_config[k].get("broadcast", False): | |
| raise ValueError(f"Field {k} is not broadcastable and cannot be included in records") | |
| rename_fields = arg_config.get("rename_fields", {}) | |
| new_idx_setters = {} | |
| for k, v in idx_setters.items(): | |
| if k in rename_fields: | |
| k = rename_fields[k] | |
| new_idx_setters[k] = v | |
| return new_idx_setters | |
| def get_arg_default(self, arg_name: str) -> tp.Any: | |
| """Get argument default according to the argument config.""" | |
| arg_config = self.arg_config[arg_name] | |
| arg = self.get_raw_arg_default( | |
| arg_name, | |
| is_dict=arg_config.get("is_dict", False), | |
| ) | |
| if arg is not None: | |
| if len(arg_config.get("map_enum_kwargs", {})) > 0: | |
| arg = self.map_enum_value(arg, **arg_config["map_enum_kwargs"]) | |
| if arg_config.get("is_td", False): | |
| arg = self.prepare_td_obj( | |
| arg, | |
| old_as_keys=arg_config.get("old_as_keys", True), | |
| ) | |
| if arg_config.get("is_dt", False): | |
| arg = self.prepare_dt_obj( | |
| arg, | |
| old_as_keys=arg_config.get("old_as_keys", True), | |
| last_before=arg_config.get("last_before", None), | |
| ) | |
| return arg | |
| def get_arg(self, arg_name: str, use_idx_setter: bool = True, use_default: bool = True) -> tp.Any: | |
| """Get mapped argument according to the argument config.""" | |
| arg_config = self.arg_config[arg_name] | |
| if use_idx_setter and self.idx_setters is not None and arg_name in self.idx_setters: | |
| arg = self.idx_setters[arg_name] | |
| else: | |
| arg = self.get_raw_arg( | |
| arg_name, | |
| is_dict=arg_config.get("is_dict", False), | |
| has_default=arg_config.get("has_default", True) if use_default else False, | |
| ) | |
| if arg is not None: | |
| if len(arg_config.get("map_enum_kwargs", {})) > 0: | |
| arg = self.map_enum_value(arg, **arg_config["map_enum_kwargs"]) | |
| if arg_config.get("is_td", False): | |
| arg = self.prepare_td_obj(arg) | |
| if arg_config.get("is_dt", False): | |
| arg = self.prepare_dt_obj(arg, last_before=arg_config.get("last_before", None)) | |
| return arg | |
| def __getitem__(self, arg_name) -> tp.Any: | |
| return self.get_arg(arg_name) | |
| def __iter__(self): | |
| raise TypeError(f"'{type(self).__name__}' object is not iterable") | |
| @classmethod | |
| def prepare_td_arr(cls, td_arr: tp.ArrayLike) -> tp.ArrayLike: | |
| """Prepare a timedelta array.""" | |
| if td_arr.dtype == object: | |
| if td_arr.ndim in (0, 1): | |
| td_arr = pd.to_timedelta(td_arr) | |
| if isinstance(td_arr, pd.Timedelta): | |
| td_arr = td_arr.to_timedelta64() | |
| else: | |
| td_arr = td_arr.values | |
| else: | |
| td_arr_cols = [] | |
| for col in range(td_arr.shape[1]): | |
| td_arr_col = pd.to_timedelta(td_arr[:, col]) | |
| td_arr_cols.append(td_arr_col.values) | |
| td_arr = column_stack_arrays(td_arr_cols) | |
| return td_arr | |
| @classmethod | |
| def prepare_dt_arr(cls, dt_arr: tp.ArrayLike) -> tp.ArrayLike: | |
| """Prepare a datetime array.""" | |
| if dt_arr.dtype == object: | |
| if dt_arr.ndim in (0, 1): | |
| dt_arr = pd.to_datetime(dt_arr).tz_localize(None) | |
| if isinstance(dt_arr, pd.Timestamp): | |
| dt_arr = dt_arr.to_datetime64() | |
| else: | |
| dt_arr = dt_arr.values | |
| else: | |
| dt_arr_cols = [] | |
| for col in range(dt_arr.shape[1]): | |
| dt_arr_col = pd.to_datetime(dt_arr[:, col]).tz_localize(None) | |
| dt_arr_cols.append(dt_arr_col.values) | |
| dt_arr = column_stack_arrays(dt_arr_cols) | |
| return dt_arr | |
| @classmethod | |
| def td_arr_to_ns(cls, td_arr: tp.ArrayLike) -> tp.ArrayLike: | |
| """Prepare a timedelta array and convert it to nanoseconds.""" | |
| return dt.to_ns(cls.prepare_td_arr(td_arr)) | |
| @classmethod | |
| def dt_arr_to_ns(cls, dt_arr: tp.ArrayLike) -> tp.ArrayLike: | |
| """Prepare a datetime array and convert it to nanoseconds.""" | |
| return dt.to_ns(cls.prepare_dt_arr(dt_arr)) | |
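| # Sketch of the array preparation helpers above; both convert to integer nanoseconds | |
| # (the inputs are illustrative): | |
| # >>> import numpy as np | |
| # >>> BasePreparer.td_arr_to_ns(np.array(["1 days", "2 days"], dtype=object)) | |
| # array([ 86400000000000, 172800000000000]) | |
| # >>> BasePreparer.dt_arr_to_ns(np.array(["2020-01-01", "2020-01-02"], dtype=object)) | |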
| def prepare_post_arg(self, arg_name: str, value: tp.Optional[tp.ArrayLike] = None) -> object: | |
| """Prepare an argument after broadcasting and/or template substitution.""" | |
| if value is None: | |
| if arg_name in self.post_args: | |
| arg = self.post_args[arg_name] | |
| else: | |
| arg = getattr(self, "_pre_" + arg_name) | |
| else: | |
| arg = value | |
| if arg is not None: | |
| arg_config = self.arg_config[arg_name] | |
| if arg_config.get("substitute_templates", False): | |
| arg = substitute_templates(arg, self.template_context, eval_id=arg_name) | |
| if "map_enum_kwargs" in arg_config: | |
| arg = map_enum_fields(arg, **arg_config["map_enum_kwargs"]) | |
| if arg_config.get("is_td", False): | |
| arg = self.td_arr_to_ns(arg) | |
| if arg_config.get("is_dt", False): | |
| arg = self.dt_arr_to_ns(arg) | |
| if "type" in arg_config: | |
| checks.assert_instance_of(arg, arg_config["type"], arg_name=arg_name) | |
| if "subdtype" in arg_config: | |
| checks.assert_subdtype(arg, arg_config["subdtype"], arg_name=arg_name) | |
| return arg | |
| @classmethod | |
| def adapt_staticized_to_udf(cls, staticized: tp.Kwargs, func: tp.Union[str, tp.Callable], func_name: str) -> None: | |
| """Adapt `staticized` dictionary to a UDF.""" | |
| target_func_module = inspect.getmodule(staticized["func"]) | |
| if isinstance(func, tuple): | |
| func, actual_func_name = func | |
| else: | |
| actual_func_name = None | |
| if isinstance(func, (str, Path)): | |
| if actual_func_name is None: | |
| actual_func_name = func_name | |
| if isinstance(func, str) and not func.endswith(".py") and hasattr(target_func_module, func): | |
| staticized[f"{func_name}_block"] = func | |
| return None | |
| func = Path(func) | |
| module_path = func.resolve() | |
| else: | |
| if actual_func_name is None: | |
| actual_func_name = func.__name__ | |
| if inspect.getmodule(func) == target_func_module: | |
| staticized[f"{func_name}_block"] = actual_func_name | |
| return None | |
| module = inspect.getmodule(func) | |
| if not hasattr(module, "__file__"): | |
| raise TypeError(f"{func_name} must be defined in a Python file") | |
| module_path = Path(module.__file__).resolve() | |
| if "import_lines" not in staticized: | |
| staticized["import_lines"] = [] | |
| reload = staticized.get("reload", False) | |
| staticized["import_lines"].extend( | |
| [ | |
| f'{func_name}_path = r"{module_path}"', | |
| f"globals().update(vbt.import_module_from_path({func_name}_path, reload={reload}).__dict__)", | |
| ] | |
| ) | |
| if actual_func_name != func_name: | |
| staticized["import_lines"].append(f"{func_name} = {actual_func_name}") | |
| @classmethod | |
| def find_target_func(cls, target_func_name: str) -> tp.Callable: | |
| """Find target function by its name.""" | |
| raise NotImplementedError | |
| @classmethod | |
| def resolve_dynamic_target_func(cls, target_func_name: str, staticized: tp.KwargsLike) -> tp.Callable: | |
| """Resolve a dynamic target function.""" | |
| if staticized is None: | |
| func = cls.find_target_func(target_func_name) | |
| else: | |
| if isinstance(staticized, dict): | |
| staticized = dict(staticized) | |
| module_path = suggest_module_path( | |
| staticized.get("suggest_fname", target_func_name), | |
| path=staticized.pop("path", None), | |
| mkdir_kwargs=staticized.get("mkdir_kwargs", None), | |
| ) | |
| if "new_func_name" not in staticized: | |
| staticized["new_func_name"] = target_func_name | |
| if staticized.pop("override", False) or not module_path.exists(): | |
| if "skip_func" not in staticized: | |
| def _skip_func(out_lines, func_name): | |
| to_skip = lambda x: f"def {func_name}" in x or x.startswith(f"{func_name}_path =") | |
| return any(map(to_skip, out_lines)) | |
| staticized["skip_func"] = _skip_func | |
| module_path = cut_and_save_func(path=module_path, **staticized) | |
| if staticized.get("clear_cache", True): | |
| remove_dir(module_path.parent / "__pycache__", with_contents=True, missing_ok=True) | |
| reload = staticized.pop("reload", False) | |
| module = import_module_from_path(module_path, reload=reload) | |
| func = getattr(module, staticized["new_func_name"]) | |
| else: | |
| func = staticized | |
| return func | |
| def set_seed(self) -> None: | |
| """Set seed.""" | |
| seed = self.seed | |
| if seed is not None: | |
| set_seed(seed) | |
| # ############# Before broadcasting ############# # | |
| @cachedproperty | |
| def _pre_template_context(self) -> tp.Kwargs: | |
| """Argument `template_context` before broadcasting.""" | |
| return merge_dicts(dict(preparer=self), self["template_context"]) | |
| # ############# Broadcasting ############# # | |
| @cachedproperty | |
| def pre_args(self) -> tp.Kwargs: | |
| """Arguments before broadcasting.""" | |
| pre_args = dict() | |
| for k, v in self.arg_config.items(): | |
| if v.get("broadcast", False): | |
| pre_args[k] = getattr(self, "_pre_" + k) | |
| return pre_args | |
| @cachedproperty | |
| def args_to_broadcast(self) -> dict: | |
| """Arguments to broadcast.""" | |
| return merge_dicts(self.idx_setters, self.pre_args, self.broadcast_named_args) | |
| @cachedproperty | |
| def def_broadcast_kwargs(self) -> tp.Kwargs: | |
| """Default keyword arguments for broadcasting.""" | |
| return dict( | |
| to_pd=False, | |
| keep_flex=dict(cash_earnings=self.keep_inout_flex, _def=True), | |
| wrapper_kwargs=dict( | |
| freq=self._pre_freq, | |
| group_by=self.group_by, | |
| ), | |
| return_wrapper=True, | |
| template_context=self._pre_template_context, | |
| ) | |
| @cachedproperty | |
| def broadcast_kwargs(self) -> tp.Kwargs: | |
| """Argument `broadcast_kwargs`.""" | |
| arg_broadcast_kwargs = defaultdict(dict) | |
| for k, v in self.arg_config.items(): | |
| if v.get("broadcast", False): | |
| broadcast_kwargs = v.get("broadcast_kwargs", None) | |
| if broadcast_kwargs is None: | |
| broadcast_kwargs = {} | |
| for k2, v2 in broadcast_kwargs.items(): | |
| arg_broadcast_kwargs[k2][k] = v2 | |
| for k in self.args_to_broadcast: | |
| new_fill_value = None | |
| if k in self.pre_args: | |
| fill_default = self.arg_config[k].get("fill_default", True) | |
| if self.idx_setters is not None and k in self.idx_setters: | |
| new_fill_value = self.get_arg(k, use_idx_setter=False, use_default=fill_default) | |
| elif fill_default and self.arg_config[k].get("has_default", True): | |
| new_fill_value = self.get_arg_default(k) | |
| elif k in self.broadcast_named_args: | |
| if self.idx_setters is not None and k in self.idx_setters: | |
| new_fill_value = self.broadcast_named_args[k] | |
| if new_fill_value is not None: | |
| if not np.isscalar(new_fill_value): | |
| raise TypeError(f"Argument '{k}' (and its default) must be a scalar when also provided via records") | |
| if "reindex_kwargs" not in arg_broadcast_kwargs: | |
| arg_broadcast_kwargs["reindex_kwargs"] = {} | |
| if k not in arg_broadcast_kwargs["reindex_kwargs"]: | |
| arg_broadcast_kwargs["reindex_kwargs"][k] = {} | |
| arg_broadcast_kwargs["reindex_kwargs"][k]["fill_value"] = new_fill_value | |
| return merge_dicts( | |
| self.def_broadcast_kwargs, | |
| dict(arg_broadcast_kwargs), | |
| self["broadcast_kwargs"], | |
| ) | |
| @cachedproperty | |
| def broadcast_result(self) -> tp.Any: | |
| """Result of broadcasting.""" | |
| return broadcast(self.args_to_broadcast, **self.broadcast_kwargs) | |
| @cachedproperty | |
| def post_args(self) -> tp.Kwargs: | |
| """Arguments after broadcasting.""" | |
| return self.broadcast_result[0] | |
| @cachedproperty | |
| def post_broadcast_named_args(self) -> tp.Kwargs: | |
| """Custom arguments after broadcasting.""" | |
| if self.broadcast_named_args is None: | |
| return dict() | |
| post_broadcast_named_args = dict() | |
| for k, v in self.post_args.items(): | |
| if k in self.broadcast_named_args: | |
| post_broadcast_named_args[k] = v | |
| elif self.idx_setters is not None and k in self.idx_setters and k not in self.pre_args: | |
| post_broadcast_named_args[k] = v | |
| return post_broadcast_named_args | |
| @cachedproperty | |
| def wrapper(self) -> ArrayWrapper: | |
| """Array wrapper.""" | |
| return self.broadcast_result[1] | |
| @cachedproperty | |
| def target_shape(self) -> tp.Shape: | |
| """Target shape.""" | |
| return self.wrapper.shape_2d | |
| @cachedproperty | |
| def index(self) -> tp.Array1d: | |
| """Index in nanosecond format.""" | |
| return self.wrapper.ns_index | |
| @cachedproperty | |
| def freq(self) -> int: | |
| """Frequency in nanosecond format.""" | |
| return self.wrapper.ns_freq | |
| # ############# Template substitution ############# # | |
| @cachedproperty | |
| def template_context(self) -> tp.Kwargs: | |
| """Argument `template_context`.""" | |
| builtin_args = {} | |
| for k, v in self.arg_config.items(): | |
| if v.get("broadcast", False): | |
| builtin_args[k] = getattr(self, k) | |
| return merge_dicts( | |
| dict( | |
| wrapper=self.wrapper, | |
| target_shape=self.target_shape, | |
| index=self.index, | |
| freq=self.freq, | |
| ), | |
| builtin_args, | |
| self.post_broadcast_named_args, | |
| self._pre_template_context, | |
| ) | |
| # ############# Result ############# # | |
| @cachedproperty | |
| def target_func(self) -> tp.Optional[tp.Callable]: | |
| """Target function.""" | |
| return None | |
| @cachedproperty | |
| def target_arg_map(self) -> tp.Kwargs: | |
| """Map of the target arguments to the preparer attributes.""" | |
| return dict() | |
| @cachedproperty | |
| def target_args(self) -> tp.Optional[tp.Kwargs]: | |
| """Arguments to be passed to the target function.""" | |
| if self.target_func is not None: | |
| target_arg_map = self.target_arg_map | |
| func_arg_names = get_func_arg_names(self.target_func) | |
| target_args = {} | |
| for k in func_arg_names: | |
| arg_attr = target_arg_map.get(k, k) | |
| if arg_attr is not None and hasattr(self, arg_attr): | |
| target_args[k] = getattr(self, arg_attr) | |
| return target_args | |
| return None | |
| # ############# Docs ############# # | |
| @classmethod | |
| def build_arg_config_doc(cls, source_cls: tp.Optional[type] = None) -> str: | |
| """Build argument config documentation.""" | |
| if source_cls is None: | |
| source_cls = BasePreparer | |
| return string.Template(inspect.cleandoc(get_dict_attr(source_cls, "arg_config").__doc__)).substitute( | |
| {"arg_config": cls.arg_config.prettify(), "cls_name": cls.__name__}, | |
| ) | |
| @classmethod | |
| def override_arg_config_doc(cls, __pdoc__: dict, source_cls: tp.Optional[type] = None) -> None: | |
| """Call this method on each subclass that overrides `BasePreparer.arg_config`.""" | |
| __pdoc__[cls.__name__ + ".arg_config"] = cls.build_arg_config_doc(source_cls=source_cls) | |
| BasePreparer.override_arg_config_doc(__pdoc__) | |
| </file> | |
| <file path="base/reshaping.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Functions for reshaping arrays. | |
| Reshape functions transform a Pandas object/NumPy array in some way.""" | |
| import functools | |
| import itertools | |
| import numpy as np | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.base import indexes, wrapping, indexing | |
| from vectorbtpro.registries.jit_registry import register_jitted | |
| from vectorbtpro.utils import checks | |
| from vectorbtpro.utils.attr_ import DefineMixin, define | |
| from vectorbtpro.utils.config import resolve_dict, merge_dicts | |
| from vectorbtpro.utils.params import combine_params, Param | |
| from vectorbtpro.utils.parsing import get_func_arg_names | |
| from vectorbtpro.utils.template import CustomTemplate | |
| __all__ = [ | |
| "to_1d_shape", | |
| "to_2d_shape", | |
| "repeat_shape", | |
| "tile_shape", | |
| "to_1d_array", | |
| "to_2d_array", | |
| "to_2d_pr_array", | |
| "to_2d_pc_array", | |
| "to_1d_array_nb", | |
| "to_2d_array_nb", | |
| "to_2d_pr_array_nb", | |
| "to_2d_pc_array_nb", | |
| "broadcast_shapes", | |
| "broadcast_array_to", | |
| "broadcast_arrays", | |
| "repeat", | |
| "tile", | |
| "align_pd_arrays", | |
| "BCO", | |
| "Default", | |
| "Ref", | |
| "broadcast", | |
| "broadcast_to", | |
| ] | |
| def to_tuple_shape(shape: tp.ShapeLike) -> tp.Shape: | |
| """Convert a shape-like object to a tuple.""" | |
| if checks.is_int(shape): | |
| return (int(shape),) | |
| return tuple(shape) | |
| def to_1d_shape(shape: tp.ShapeLike) -> tp.Shape: | |
| """Convert a shape-like object to a 1-dim shape.""" | |
| shape = to_tuple_shape(shape) | |
| if len(shape) == 0: | |
| return (1,) | |
| if len(shape) == 1: | |
| return shape | |
| if len(shape) == 2 and shape[1] == 1: | |
| return (shape[0],) | |
| raise ValueError(f"Cannot reshape a {len(shape)}-dimensional shape to 1 dimension") | |
| def to_2d_shape(shape: tp.ShapeLike, expand_axis: int = 1) -> tp.Shape: | |
| """Convert a shape-like object to a 2-dim shape.""" | |
| shape = to_tuple_shape(shape) | |
| if len(shape) == 0: | |
| return 1, 1 | |
| if len(shape) == 1: | |
| if expand_axis == 1: | |
| return shape[0], 1 | |
| else: | |
| return 1, shape[0] | |
| if len(shape) == 2: | |
| return shape | |
| raise ValueError(f"Cannot reshape a {len(shape)}-dimensional shape to 2 dimensions") | |
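| # Editor's usage sketch (not part of the original module): illustrates the shape helpers | |
| # above; the helper name `_example_shape_conversion` is hypothetical. | |
| def _example_shape_conversion() -> None: | |
|     assert to_1d_shape((3, 1)) == (3,)  # a single-column 2-dim shape collapses | |
|     assert to_1d_shape(5) == (5,)  # an integer is treated as a length | |
|     assert to_2d_shape((3,)) == (3, 1)  # a 1-dim shape expands along axis 1 by default | |
|     assert to_2d_shape(()) == (1, 1)  # an empty shape becomes a unit shape | |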
| def repeat_shape(shape: tp.ShapeLike, n: int, axis: int = 1) -> tp.Shape: | |
| """Repeat shape `n` times along the specified axis.""" | |
| shape = to_tuple_shape(shape) | |
| if len(shape) <= axis: | |
| shape = tuple([shape[i] if i < len(shape) else 1 for i in range(axis + 1)]) | |
| return *shape[:axis], shape[axis] * n, *shape[axis + 1 :] | |
| def tile_shape(shape: tp.ShapeLike, n: int, axis: int = 1) -> tp.Shape: | |
| """Tile shape `n` times along the specified axis. | |
| Identical to `repeat_shape`. Exists purely for naming consistency.""" | |
| return repeat_shape(shape, n, axis=axis) | |
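| # Editor's usage sketch (not part of the original module): shows how `repeat_shape` grows | |
| # a shape along either axis; the helper name `_example_repeat_shape` is hypothetical. | |
| def _example_repeat_shape() -> None: | |
|     assert repeat_shape((3,), 2) == (3, 2)  # a missing column axis is added first | |
|     assert repeat_shape((3, 4), 2, axis=0) == (6, 4)  # rows repeated twice | |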
| def index_to_series(obj: tp.Index, reset_index: bool = False) -> tp.Series: | |
| """Convert Index to Series.""" | |
| if reset_index: | |
| return obj.to_series(index=pd.RangeIndex(stop=len(obj))) | |
| return obj.to_series() | |
| def index_to_frame(obj: tp.Index, reset_index: bool = False) -> tp.Frame: | |
| """Convert Index to DataFrame.""" | |
| if not isinstance(obj, pd.MultiIndex): | |
| return index_to_series(obj, reset_index=reset_index).to_frame() | |
| return obj.to_frame(index=not reset_index) | |
| def mapping_to_series(obj: tp.MappingLike) -> tp.Series: | |
| """Convert a mapping-like object to Series.""" | |
| if checks.is_namedtuple(obj): | |
| obj = obj._asdict() | |
| return pd.Series(obj) | |
| def to_any_array(obj: tp.ArrayLike, raw: bool = False, convert_index: bool = True) -> tp.AnyArray: | |
| """Convert any array-like object to an array. | |
| Pandas objects are kept as-is unless `raw` is True.""" | |
| from vectorbtpro.indicators.factory import IndicatorBase | |
| if isinstance(obj, IndicatorBase): | |
| obj = obj.main_output | |
| if not raw: | |
| if checks.is_any_array(obj): | |
| if convert_index and checks.is_index(obj): | |
| return index_to_series(obj) | |
| return obj | |
| if checks.is_mapping_like(obj): | |
| return mapping_to_series(obj) | |
| return np.asarray(obj) | |
| def to_pd_array(obj: tp.ArrayLike, convert_index: bool = True) -> tp.PandasArray: | |
| """Convert any array-like object to a Pandas object.""" | |
| from vectorbtpro.indicators.factory import IndicatorBase | |
| if isinstance(obj, IndicatorBase): | |
| obj = obj.main_output | |
| if checks.is_pandas(obj): | |
| if convert_index and checks.is_index(obj): | |
| return index_to_series(obj) | |
| return obj | |
| if checks.is_mapping_like(obj): | |
| return mapping_to_series(obj) | |
| obj = np.asarray(obj) | |
| if obj.ndim == 0: | |
| obj = obj[None] | |
| if obj.ndim == 1: | |
| return pd.Series(obj) | |
| if obj.ndim == 2: | |
| return pd.DataFrame(obj) | |
| raise ValueError("Wrong number of dimensions: cannot convert to Series or DataFrame") | |
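| # Editor's usage sketch (not part of the original module): `to_pd_array` wraps plain | |
| # sequences into Pandas objects; the helper name `_example_to_pd_array` is hypothetical. | |
| def _example_to_pd_array() -> None: | |
|     assert isinstance(to_pd_array([1, 2, 3]), pd.Series)  # 1-dim becomes a Series | |
|     assert isinstance(to_pd_array([[1, 2], [3, 4]]), pd.DataFrame)  # 2-dim becomes a DataFrame | |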
| def soft_to_ndim(obj: tp.ArrayLike, ndim: int, raw: bool = False) -> tp.AnyArray: | |
| """Try to softly bring `obj` to the specified number of dimensions `ndim` (max 2).""" | |
| obj = to_any_array(obj, raw=raw) | |
| if ndim == 1: | |
| if obj.ndim == 2: | |
| if obj.shape[1] == 1: | |
| if checks.is_frame(obj): | |
| return obj.iloc[:, 0] | |
| return obj[:, 0] # downgrade | |
| if ndim == 2: | |
| if obj.ndim == 1: | |
| if checks.is_series(obj): | |
| return obj.to_frame() | |
| return obj[:, None] # upgrade | |
| return obj # do nothing | |
| def to_1d(obj: tp.ArrayLike, raw: bool = False) -> tp.AnyArray1d: | |
| """Reshape argument to one dimension. | |
| If `raw` is True, returns NumPy array. | |
| If 2-dim, will collapse along axis 1 (i.e., DataFrame with one column to Series).""" | |
| obj = to_any_array(obj, raw=raw) | |
| if obj.ndim == 2: | |
| if obj.shape[1] == 1: | |
| if checks.is_frame(obj): | |
| return obj.iloc[:, 0] | |
| return obj[:, 0] | |
| if obj.ndim == 1: | |
| return obj | |
| elif obj.ndim == 0: | |
| return obj.reshape((1,)) | |
| raise ValueError(f"Cannot reshape a {obj.ndim}-dimensional array to 1 dimension") | |
| to_1d_array = functools.partial(to_1d, raw=True) | |
| """`to_1d` with `raw` enabled.""" | |
| def to_2d(obj: tp.ArrayLike, raw: bool = False, expand_axis: int = 1) -> tp.AnyArray2d: | |
| """Reshape argument to two dimensions. | |
| If `raw` is True, returns NumPy array. | |
| If 1-dim, will expand along axis 1 (i.e., Series to DataFrame with one column).""" | |
| obj = to_any_array(obj, raw=raw) | |
| if obj.ndim == 2: | |
| return obj | |
| elif obj.ndim == 1: | |
| if checks.is_series(obj): | |
| if expand_axis == 0: | |
| return pd.DataFrame(obj.values[None, :], columns=obj.index) | |
| elif expand_axis == 1: | |
| return obj.to_frame() | |
| return np.expand_dims(obj, expand_axis) | |
| elif obj.ndim == 0: | |
| return obj.reshape((1, 1)) | |
| raise ValueError(f"Cannot reshape a {obj.ndim}-dimensional array to 2 dimensions") | |
| to_2d_array = functools.partial(to_2d, raw=True) | |
| """`to_2d` with `raw` enabled.""" | |
| to_2d_pr_array = functools.partial(to_2d_array, expand_axis=1) | |
| """`to_2d_array` with `expand_axis=1`.""" | |
| to_2d_pc_array = functools.partial(to_2d_array, expand_axis=0) | |
| """`to_2d_array` with `expand_axis=0`.""" | |
| @register_jitted(cache=True) | |
| def to_1d_array_nb(obj: tp.Array) -> tp.Array1d: | |
| """Resize array to one dimension.""" | |
| if obj.ndim == 0: | |
| return np.expand_dims(obj, axis=0) | |
| if obj.ndim == 1: | |
| return obj | |
| if obj.ndim == 2 and obj.shape[1] == 1: | |
| return obj[:, 0] | |
| raise ValueError("Array cannot be resized to one dimension") | |
| @register_jitted(cache=True) | |
| def to_2d_array_nb(obj: tp.Array, expand_axis: int = 1) -> tp.Array2d: | |
| """Resize array to two dimensions.""" | |
| if obj.ndim == 0: | |
| return np.expand_dims(np.expand_dims(obj, axis=0), axis=0) | |
| if obj.ndim == 1: | |
| return np.expand_dims(obj, axis=expand_axis) | |
| if obj.ndim == 2: | |
| return obj | |
| raise ValueError("Array cannot be resized to two dimensions") | |
| @register_jitted(cache=True) | |
| def to_2d_pr_array_nb(obj: tp.Array) -> tp.Array2d: | |
| """`to_2d_array_nb` with `expand_axis=1`.""" | |
| return to_2d_array_nb(obj, expand_axis=1) | |
| @register_jitted(cache=True) | |
| def to_2d_pc_array_nb(obj: tp.Array) -> tp.Array2d: | |
| """`to_2d_array_nb` with `expand_axis=0`.""" | |
| return to_2d_array_nb(obj, expand_axis=0) | |
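| # Editor's usage sketch (not part of the original module): the raw conversion shortcuts | |
| # defined above; the helper name `_example_to_1d_to_2d` is hypothetical. | |
| def _example_to_1d_to_2d() -> None: | |
|     sr = pd.Series([1, 2, 3]) | |
|     assert to_1d_array(sr).shape == (3,)  # Series to a raw 1-dim array | |
|     assert to_2d_array(sr).shape == (3, 1)  # Series to a raw column array | |
|     assert isinstance(to_2d(sr), pd.DataFrame)  # Series to a one-column DataFrame | |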
| def to_dict(obj: tp.ArrayLike, orient: str = "dict") -> dict: | |
| """Convert object to dict.""" | |
| obj = to_pd_array(obj) | |
| if orient == "index_series": | |
| return {obj.index[i]: obj.iloc[i] for i in range(len(obj.index))} | |
| return obj.to_dict(orient) | |
| def repeat( | |
| obj: tp.ArrayLike, | |
| n: int, | |
| axis: int = 1, | |
| raw: bool = False, | |
| ignore_ranges: tp.Optional[bool] = None, | |
| ) -> tp.AnyArray: | |
| """Repeat `obj` `n` times along the specified axis.""" | |
| obj = to_any_array(obj, raw=raw) | |
| if axis == 0: | |
| if checks.is_pandas(obj): | |
| new_index = indexes.repeat_index(obj.index, n, ignore_ranges=ignore_ranges) | |
| return wrapping.ArrayWrapper.from_obj(obj).wrap(np.repeat(obj.values, n, axis=0), index=new_index) | |
| return np.repeat(obj, n, axis=0) | |
| elif axis == 1: | |
| obj = to_2d(obj) | |
| if checks.is_pandas(obj): | |
| new_columns = indexes.repeat_index(obj.columns, n, ignore_ranges=ignore_ranges) | |
| return wrapping.ArrayWrapper.from_obj(obj).wrap(np.repeat(obj.values, n, axis=1), columns=new_columns) | |
| return np.repeat(obj, n, axis=1) | |
| else: | |
| raise ValueError(f"Only axes 0 and 1 are supported, not {axis}") | |
| def tile( | |
| obj: tp.ArrayLike, | |
| n: int, | |
| axis: int = 1, | |
| raw: bool = False, | |
| ignore_ranges: tp.Optional[bool] = None, | |
| ) -> tp.AnyArray: | |
| """Tile `obj` `n` times along the specified axis.""" | |
| obj = to_any_array(obj, raw=raw) | |
| if axis == 0: | |
| if obj.ndim == 2: | |
| if checks.is_pandas(obj): | |
| new_index = indexes.tile_index(obj.index, n, ignore_ranges=ignore_ranges) | |
| return wrapping.ArrayWrapper.from_obj(obj).wrap(np.tile(obj.values, (n, 1)), index=new_index) | |
| return np.tile(obj, (n, 1)) | |
| if checks.is_pandas(obj): | |
| new_index = indexes.tile_index(obj.index, n, ignore_ranges=ignore_ranges) | |
| return wrapping.ArrayWrapper.from_obj(obj).wrap(np.tile(obj.values, n), index=new_index) | |
| return np.tile(obj, n) | |
| elif axis == 1: | |
| obj = to_2d(obj) | |
| if checks.is_pandas(obj): | |
| new_columns = indexes.tile_index(obj.columns, n, ignore_ranges=ignore_ranges) | |
| return wrapping.ArrayWrapper.from_obj(obj).wrap(np.tile(obj.values, (1, n)), columns=new_columns) | |
| return np.tile(obj, (1, n)) | |
| else: | |
| raise ValueError(f"Only axes 0 and 1 are supported, not {axis}") | |
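| # Editor's usage sketch (not part of the original module): `repeat` versus `tile` on a | |
| # small Series; the helper name `_example_repeat_tile` is hypothetical. | |
| def _example_repeat_tile() -> None: | |
|     sr = pd.Series([1, 2], index=pd.Index(["x", "y"])) | |
|     assert repeat(sr, 2, axis=1).shape == (2, 2)  # columns are repeated | |
|     assert list(tile(sr, 2, axis=0).values) == [1, 2, 1, 2]  # values tiled along the index | |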
| def broadcast_shapes( | |
| *shapes: tp.ArrayLike, | |
| axis: tp.Optional[tp.MaybeSequence[int]] = None, | |
| expand_axis: tp.Optional[tp.MaybeSequence[int]] = None, | |
| ) -> tp.Tuple[tp.Shape, ...]: | |
| """Broadcast shape-like objects using vectorbt's broadcasting rules.""" | |
| from vectorbtpro._settings import settings | |
| broadcasting_cfg = settings["broadcasting"] | |
| if expand_axis is None: | |
| expand_axis = broadcasting_cfg["expand_axis"] | |
| is_2d = False | |
| for i, shape in enumerate(shapes): | |
| shape = to_tuple_shape(shape) | |
| if len(shape) == 2: | |
| is_2d = True | |
| break | |
| new_shapes = [] | |
| for i, shape in enumerate(shapes): | |
| shape = to_tuple_shape(shape) | |
| if is_2d: | |
| if checks.is_sequence(expand_axis): | |
| _expand_axis = expand_axis[i] | |
| else: | |
| _expand_axis = expand_axis | |
| new_shape = to_2d_shape(shape, expand_axis=_expand_axis) | |
| else: | |
| new_shape = to_1d_shape(shape) | |
| if axis is not None: | |
| if checks.is_sequence(axis): | |
| _axis = axis[i] | |
| else: | |
| _axis = axis | |
| if _axis is not None: | |
| if _axis == 0: | |
| if is_2d: | |
| new_shape = (new_shape[0], 1) | |
| else: | |
| new_shape = (new_shape[0],) | |
| elif _axis == 1: | |
| if is_2d: | |
| new_shape = (1, new_shape[1]) | |
| else: | |
| new_shape = (1,) | |
| else: | |
| raise ValueError(f"Only axes 0 and 1 are supported, not {_axis}") | |
| new_shapes.append(new_shape) | |
| return tuple(np.broadcast_shapes(*new_shapes)) | |
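| # Editor's usage sketch (not part of the original module): with a 2-dim participant, 1-dim | |
| # shapes broadcast against the row axis (assuming the default `expand_axis` of 1 in the | |
| # broadcasting settings); the helper name `_example_broadcast_shapes` is hypothetical. | |
| def _example_broadcast_shapes() -> None: | |
|     assert broadcast_shapes((3,), (3, 4)) == (3, 4)  # (3,) is treated as (3, 1) | |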
| def broadcast_array_to( | |
| arr: tp.ArrayLike, | |
| target_shape: tp.ShapeLike, | |
| axis: tp.Optional[int] = None, | |
| expand_axis: tp.Optional[int] = None, | |
| ) -> tp.Array: | |
| """Broadcast an array-like object to a target shape using vectorbt's broadcasting rules.""" | |
| from vectorbtpro._settings import settings | |
| broadcasting_cfg = settings["broadcasting"] | |
| if expand_axis is None: | |
| expand_axis = broadcasting_cfg["expand_axis"] | |
| arr = np.asarray(arr) | |
| target_shape = to_tuple_shape(target_shape) | |
| if len(target_shape) not in (1, 2): | |
| raise ValueError(f"Target shape must have either 1 or 2 dimensions, not {len(target_shape)}") | |
| if len(target_shape) == 2: | |
| new_arr = to_2d_array(arr, expand_axis=expand_axis) | |
| else: | |
| new_arr = to_1d_array(arr) | |
| if axis is not None: | |
| if axis == 0: | |
| if len(target_shape) == 2: | |
| target_shape = (target_shape[0], new_arr.shape[1]) | |
| else: | |
| target_shape = (target_shape[0],) | |
| elif axis == 1: | |
| target_shape = (new_arr.shape[0], target_shape[1]) | |
| else: | |
| raise ValueError(f"Only axes 0 and 1 are supported, not {axis}") | |
| return np.broadcast_to(new_arr, target_shape) | |
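| # Editor's usage sketch (not part of the original module): a 1-dim array becomes a column | |
| # and is then broadcast across columns (assuming the default `expand_axis` of 1 in the | |
| # broadcasting settings); the helper name `_example_broadcast_array_to` is hypothetical. | |
| def _example_broadcast_array_to() -> None: | |
|     out = broadcast_array_to([1, 2, 3], (3, 4)) | |
|     assert out.shape == (3, 4) | |
|     assert list(out[:, 0]) == [1, 2, 3]  # every column holds the original values | |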
| def broadcast_arrays( | |
| *arrs: tp.ArrayLike, | |
| target_shape: tp.Optional[tp.ShapeLike] = None, | |
| axis: tp.Optional[tp.MaybeSequence[int]] = None, | |
| expand_axis: tp.Optional[tp.MaybeSequence[int]] = None, | |
| ) -> tp.Tuple[tp.Array, ...]: | |
| """Broadcast array-like objects using vectorbt's broadcasting rules. | |
| Optionally to a target shape.""" | |
| if target_shape is None: | |
| shapes = [] | |
| for arr in arrs: | |
| shapes.append(np.asarray(arr).shape) | |
| target_shape = broadcast_shapes(*shapes, axis=axis, expand_axis=expand_axis) | |
| new_arrs = [] | |
| for i, arr in enumerate(arrs): | |
| if axis is not None: | |
| if checks.is_sequence(axis): | |
| _axis = axis[i] | |
| else: | |
| _axis = axis | |
| else: | |
| _axis = None | |
| if expand_axis is not None: | |
| if checks.is_sequence(expand_axis): | |
| _expand_axis = expand_axis[i] | |
| else: | |
| _expand_axis = expand_axis | |
| else: | |
| _expand_axis = None | |
| new_arr = broadcast_array_to(arr, target_shape, axis=_axis, expand_axis=_expand_axis) | |
| new_arrs.append(new_arr) | |
| return tuple(new_arrs) | |
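| # Editor's usage sketch (not part of the original module): broadcasting a 1-dim array (rows) | |
| # against a row vector (columns), assuming the default broadcasting settings; the helper | |
| # name `_example_broadcast_arrays` is hypothetical. | |
| def _example_broadcast_arrays() -> None: | |
|     a, b = broadcast_arrays(np.array([1, 2, 3]), np.array([[10, 20]])) | |
|     assert a.shape == (3, 2) and b.shape == (3, 2) | |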
| IndexFromLike = tp.Union[None, str, int, tp.Any] | |
| """Any object that can be coerced into an `index_from` argument.""" | |
| def broadcast_index( | |
| objs: tp.Sequence[tp.AnyArray], | |
| to_shape: tp.Shape, | |
| index_from: IndexFromLike = None, | |
| axis: int = 0, | |
| ignore_sr_names: tp.Optional[bool] = None, | |
| ignore_ranges: tp.Optional[bool] = None, | |
| check_index_names: tp.Optional[bool] = None, | |
| **clean_index_kwargs, | |
| ) -> tp.Optional[tp.Index]: | |
| """Produce a broadcast index/columns. | |
| Args: | |
| objs (iterable of array_like): Array-like objects. | |
| to_shape (tuple of int): Target shape. | |
| index_from (any): Broadcasting rule for this index/these columns. | |
| Accepts the following values: | |
| * 'keep' or None - keep the original index/columns of the objects in `objs` | |
| * 'stack' - stack different indexes/columns using `vectorbtpro.base.indexes.stack_indexes` | |
| * 'strict' - ensure that all Pandas objects have the same index/columns | |
| * 'reset' - reset any index/columns (they become a simple range) | |
| * integer - use the index/columns of the i-th object in `objs` | |
| * everything else will be converted to `pd.Index` | |
| axis (int): Set to 0 for index and 1 for columns. | |
| ignore_sr_names (bool): Whether to ignore Series names if they are in conflict. | |
| Conflicting Series names are those that are different but not None. | |
| ignore_ranges (bool): Whether to ignore indexes of type `pd.RangeIndex`. | |
| check_index_names (bool): See `vectorbtpro.utils.checks.is_index_equal`. | |
| **clean_index_kwargs: Keyword arguments passed to `vectorbtpro.base.indexes.clean_index`. | |
| For defaults, see `vectorbtpro._settings.broadcasting`. | |
| !!! note | |
| Series names are treated as columns with a single element but without a name. | |
| If a column level without a name loses its meaning, it is better to convert Series to DataFrames | |
| with one column prior to broadcasting. If the name of a Series is not that important, | |
| it is better to drop it altogether by setting it to None. | |
| """ | |
| from vectorbtpro._settings import settings | |
| broadcasting_cfg = settings["broadcasting"] | |
| if ignore_sr_names is None: | |
| ignore_sr_names = broadcasting_cfg["ignore_sr_names"] | |
| if check_index_names is None: | |
| check_index_names = broadcasting_cfg["check_index_names"] | |
| index_str = "columns" if axis == 1 else "index" | |
| to_shape_2d = (to_shape[0], 1) if len(to_shape) == 1 else to_shape | |
| maxlen = to_shape_2d[1] if axis == 1 else to_shape_2d[0] | |
| new_index = None | |
| objs = list(objs) | |
| if index_from is None or (isinstance(index_from, str) and index_from.lower() == "keep"): | |
| return None | |
| if isinstance(index_from, int): | |
| if not checks.is_pandas(objs[index_from]): | |
| raise TypeError(f"Argument under index {index_from} must be a pandas object") | |
| new_index = indexes.get_index(objs[index_from], axis) | |
| elif isinstance(index_from, str): | |
| if index_from.lower() == "reset": | |
| new_index = pd.RangeIndex(start=0, stop=maxlen, step=1) | |
| elif index_from.lower() in ("stack", "strict"): | |
| last_index = None | |
| index_conflict = False | |
| for obj in objs: | |
| if checks.is_pandas(obj): | |
| index = indexes.get_index(obj, axis) | |
| if last_index is not None: | |
| if not checks.is_index_equal(index, last_index, check_names=check_index_names): | |
| index_conflict = True | |
| last_index = index | |
| continue | |
| if not index_conflict: | |
| new_index = last_index | |
| else: | |
| for obj in objs: | |
| if checks.is_pandas(obj): | |
| index = indexes.get_index(obj, axis) | |
| if axis == 1 and checks.is_series(obj) and ignore_sr_names: | |
| continue | |
| if checks.is_default_index(index): | |
| continue | |
| if new_index is None: | |
| new_index = index | |
| else: | |
| if checks.is_index_equal(index, new_index, check_names=check_index_names): | |
| continue | |
| if index_from.lower() == "strict": | |
| raise ValueError( | |
| f"Arrays have different index. Broadcasting {index_str} " | |
| f"is not allowed when {index_str}_from=strict" | |
| ) | |
| if len(index) != len(new_index): | |
| if len(index) > 1 and len(new_index) > 1: | |
| raise ValueError("Indexes could not be broadcast together") | |
| if len(index) > len(new_index): | |
| new_index = indexes.repeat_index(new_index, len(index), ignore_ranges=ignore_ranges) | |
| elif len(index) < len(new_index): | |
| index = indexes.repeat_index(index, len(new_index), ignore_ranges=ignore_ranges) | |
| new_index = indexes.stack_indexes([new_index, index], **clean_index_kwargs) | |
| else: | |
| raise ValueError(f"Invalid value '{index_from}' for {'columns' if axis == 1 else 'index'}_from") | |
| else: | |
| if not isinstance(index_from, pd.Index): | |
| index_from = pd.Index(index_from) | |
| new_index = index_from | |
| if new_index is not None: | |
| if maxlen > len(new_index): | |
| if isinstance(index_from, str) and index_from.lower() == "strict": | |
| raise ValueError(f"Broadcasting {index_str} is not allowed when {index_str}_from=strict") | |
| if maxlen > 1 and len(new_index) > 1: | |
| raise ValueError("Indexes could not be broadcast together") | |
| new_index = indexes.repeat_index(new_index, maxlen, ignore_ranges=ignore_ranges) | |
| else: | |
| new_index = pd.RangeIndex(start=0, stop=maxlen, step=1) | |
| return new_index | |
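| # Editor's usage sketch (not part of the original module): under the 'stack' rule the only | |
| # Pandas index present is used as-is; the helper name `_example_broadcast_index` is hypothetical. | |
| def _example_broadcast_index() -> None: | |
|     sr = pd.Series([1, 2, 3], index=pd.Index(["x", "y", "z"])) | |
|     idx = broadcast_index([sr, np.array([1, 2, 3])], (3,), index_from="stack", axis=0) | |
|     assert list(idx) == ["x", "y", "z"] | |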
| def wrap_broadcasted( | |
| new_obj: tp.Array, | |
| old_obj: tp.Optional[tp.AnyArray] = None, | |
| axis: tp.Optional[int] = None, | |
| is_pd: bool = False, | |
| new_index: tp.Optional[tp.Index] = None, | |
| new_columns: tp.Optional[tp.Index] = None, | |
| ignore_ranges: tp.Optional[bool] = None, | |
| ) -> tp.AnyArray: | |
| """If the newly broadcast array was originally a Pandas object, make it a Pandas object again | |
| and assign it the newly broadcast index/columns.""" | |
| if is_pd: | |
| if axis == 0: | |
| new_columns = None | |
| elif axis == 1: | |
| new_index = None | |
| if old_obj is not None and checks.is_pandas(old_obj): | |
| if new_index is None: | |
| old_index = indexes.get_index(old_obj, 0) | |
| if old_obj.shape[0] == new_obj.shape[0]: | |
| new_index = old_index | |
| else: | |
| new_index = indexes.repeat_index(old_index, new_obj.shape[0], ignore_ranges=ignore_ranges) | |
| if new_columns is None: | |
| old_columns = indexes.get_index(old_obj, 1) | |
| new_ncols = new_obj.shape[1] if new_obj.ndim == 2 else 1 | |
| if len(old_columns) == new_ncols: | |
| new_columns = old_columns | |
| else: | |
| new_columns = indexes.repeat_index(old_columns, new_ncols, ignore_ranges=ignore_ranges) | |
| if new_obj.ndim == 2: | |
| return pd.DataFrame(new_obj, index=new_index, columns=new_columns) | |
| if new_columns is not None and len(new_columns) == 1: | |
| name = new_columns[0] | |
| if name == 0: | |
| name = None | |
| else: | |
| name = None | |
| return pd.Series(new_obj, index=new_index, name=name) | |
| return new_obj | |
| def align_pd_arrays( | |
| *objs: tp.AnyArray, | |
| align_index: bool = True, | |
| align_columns: bool = True, | |
| to_index: tp.Optional[tp.Index] = None, | |
| to_columns: tp.Optional[tp.Index] = None, | |
| axis: tp.Optional[tp.MaybeSequence[int]] = None, | |
| reindex_kwargs: tp.KwargsLikeSequence = None, | |
| ) -> tp.MaybeTuple[tp.ArrayLike]: | |
| """Align Pandas arrays against common index and/or column levels using reindexing | |
| and `vectorbtpro.base.indexes.align_indexes` respectively.""" | |
| objs = list(objs) | |
| if align_index: | |
| indexes_to_align = [] | |
| for i in range(len(objs)): | |
| if axis is not None: | |
| if checks.is_sequence(axis): | |
| _axis = axis[i] | |
| else: | |
| _axis = axis | |
| else: | |
| _axis = None | |
| if _axis in (None, 0): | |
| if checks.is_pandas(objs[i]): | |
| if not checks.is_default_index(objs[i].index): | |
| indexes_to_align.append(i) | |
| if (len(indexes_to_align) > 0 and to_index is not None) or len(indexes_to_align) > 1: | |
| if to_index is None: | |
| new_index = None | |
| index_changed = False | |
| for i in indexes_to_align: | |
| arg_index = objs[i].index | |
| if new_index is None: | |
| new_index = arg_index | |
| else: | |
| if not checks.is_index_equal(new_index, arg_index): | |
| if new_index.dtype != arg_index.dtype: | |
| raise ValueError("Indexes to be aligned must have the same data type") | |
| new_index = new_index.union(arg_index) | |
| index_changed = True | |
| else: | |
| new_index = to_index | |
| index_changed = True | |
| if index_changed: | |
| for i in indexes_to_align: | |
| if to_index is None or not checks.is_index_equal(objs[i].index, to_index): | |
| if objs[i].index.has_duplicates: | |
| raise ValueError(f"Index at position {i} contains duplicates") | |
| if not objs[i].index.is_monotonic_increasing: | |
| raise ValueError(f"Index at position {i} is not monotonically increasing") | |
| _reindex_kwargs = resolve_dict(reindex_kwargs, i=i) | |
| was_bool = (isinstance(objs[i], pd.Series) and objs[i].dtype == "bool") or ( | |
| isinstance(objs[i], pd.DataFrame) and (objs[i].dtypes == "bool").all() | |
| ) | |
| objs[i] = objs[i].reindex(new_index, **_reindex_kwargs) | |
| is_object = (isinstance(objs[i], pd.Series) and objs[i].dtype == "object") or ( | |
| isinstance(objs[i], pd.DataFrame) and (objs[i].dtypes == "object").all() | |
| ) | |
| if was_bool and is_object: | |
| objs[i] = objs[i].astype(None) | |
| if align_columns: | |
| columns_to_align = [] | |
| for i in range(len(objs)): | |
| if axis is not None: | |
| if checks.is_sequence(axis): | |
| _axis = axis[i] | |
| else: | |
| _axis = axis | |
| else: | |
| _axis = None | |
| if _axis in (None, 1): | |
| if checks.is_frame(objs[i]) and len(objs[i].columns) > 1: | |
| if not checks.is_default_index(objs[i].columns): | |
| columns_to_align.append(i) | |
| if (len(columns_to_align) > 0 and to_columns is not None) or len(columns_to_align) > 1: | |
| indexes_ = [objs[i].columns for i in columns_to_align] | |
| if to_columns is not None: | |
| indexes_.append(to_columns) | |
| if len(set(map(len, indexes_))) > 1: | |
| col_indices = indexes.align_indexes(*indexes_) | |
| for i in columns_to_align: | |
| objs[i] = objs[i].iloc[:, col_indices[columns_to_align.index(i)]] | |
| if len(objs) == 1: | |
| return objs[0] | |
| return tuple(objs) | |
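| # Editor's usage sketch (not part of the original module): indexes are aligned on their | |
| # union before broadcasting; the helper name `_example_align_pd_arrays` is hypothetical. | |
| def _example_align_pd_arrays() -> None: | |
|     sr1 = pd.Series([1, 2], index=pd.Index(["a", "b"])) | |
|     sr2 = pd.Series([3, 4], index=pd.Index(["b", "c"])) | |
|     out1, out2 = align_pd_arrays(sr1, sr2) | |
|     assert list(out1.index) == ["a", "b", "c"]  # reindexed to the union; missing values become NaN | |
|     assert list(out2.index) == ["a", "b", "c"] | |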
| @define | |
| class BCO(DefineMixin): | |
| """Class that represents an object passed to `broadcast`. | |
| If any value is None, mostly defaults to the global value passed to `broadcast`.""" | |
| value: tp.Any = define.field() | |
| """Value of the object.""" | |
| axis: tp.Optional[int] = define.field(default=None) | |
| """Axis to broadcast. | |
| Set to None to broadcast all axes.""" | |
| to_pd: tp.Optional[bool] = define.field(default=None) | |
| """Whether to convert the output array to a Pandas object.""" | |
| keep_flex: tp.Optional[bool] = define.field(default=None) | |
| """Whether to keep the raw version of the output for flexible indexing. | |
| Only makes sure that the array can broadcast to the target shape.""" | |
| min_ndim: tp.Optional[int] = define.field(default=None) | |
| """Minimum number of dimensions.""" | |
| expand_axis: tp.Optional[int] = define.field(default=None) | |
| """Axis to expand if the array is 1-dim but the target shape is 2-dim.""" | |
| post_func: tp.Optional[tp.Callable] = define.field(default=None) | |
| """Function to post-process the output array.""" | |
| require_kwargs: tp.Optional[tp.Kwargs] = define.field(default=None) | |
| """Keyword arguments passed to `np.require`.""" | |
| reindex_kwargs: tp.Optional[tp.Kwargs] = define.field(default=None) | |
| """Keyword arguments passed to `pd.DataFrame.reindex`.""" | |
| merge_kwargs: tp.Optional[tp.Kwargs] = define.field(default=None) | |
| """Keyword arguments passed to `vectorbtpro.base.merging.column_stack_merge`.""" | |
| context: tp.KwargsLike = define.field(default=None) | |
| """Context used in evaluation of templates. | |
| Will be merged over `template_context`.""" | |
| @define | |
| class Default(DefineMixin): | |
| """Class for wrapping default values.""" | |
| value: tp.Any = define.field() | |
| """Default value.""" | |
| @define | |
| class Ref(DefineMixin): | |
| """Class for wrapping references to other values.""" | |
| key: tp.Hashable = define.field() | |
| """Reference to another key.""" | |
| def resolve_ref(dct: dict, k: tp.Hashable, inside_bco: bool = False, keep_wrap_default: bool = False) -> tp.Any: | |
| """Resolve a potential reference.""" | |
| v = dct[k] | |
| is_default = False | |
| if isinstance(v, Default): | |
| v = v.value | |
| is_default = True | |
| if isinstance(v, Ref): | |
| new_v = resolve_ref(dct, v.key, inside_bco=inside_bco) | |
| if keep_wrap_default and is_default: | |
| return Default(new_v) | |
| return new_v | |
| if isinstance(v, BCO) and inside_bco: | |
| v = v.value | |
| is_default = False | |
| if isinstance(v, Default): | |
| v = v.value | |
| is_default = True | |
| if isinstance(v, Ref): | |
| new_v = resolve_ref(dct, v.key, inside_bco=inside_bco) | |
| if keep_wrap_default and is_default: | |
| return Default(new_v) | |
| return new_v | |
| return v | |
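| # Editor's usage sketch (not part of the original module): `Ref` values resolve recursively | |
| # and `Default` wrappers are unwrapped; the helper name `_example_resolve_ref` is hypothetical. | |
| def _example_resolve_ref() -> None: | |
|     dct = {"a": 10, "b": Ref("a"), "c": Default(Ref("a"))} | |
|     assert resolve_ref(dct, "b") == 10 | |
|     assert resolve_ref(dct, "c") == 10 | |
|     assert resolve_ref(dct, "c", keep_wrap_default=True).value == 10  # stays wrapped in Default | |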
| def broadcast( | |
| *objs, | |
| to_shape: tp.Optional[tp.ShapeLike] = None, | |
| align_index: tp.Optional[bool] = None, | |
| align_columns: tp.Optional[bool] = None, | |
| index_from: tp.Optional[IndexFromLike] = None, | |
| columns_from: tp.Optional[IndexFromLike] = None, | |
| to_frame: tp.Optional[bool] = None, | |
| axis: tp.Optional[tp.MaybeMappingSequence[int]] = None, | |
| to_pd: tp.Optional[tp.MaybeMappingSequence[bool]] = None, | |
| keep_flex: tp.MaybeMappingSequence[tp.Optional[bool]] = None, | |
| min_ndim: tp.MaybeMappingSequence[tp.Optional[int]] = None, | |
| expand_axis: tp.MaybeMappingSequence[tp.Optional[int]] = None, | |
| post_func: tp.MaybeMappingSequence[tp.Optional[tp.Callable]] = None, | |
| require_kwargs: tp.MaybeMappingSequence[tp.Optional[tp.Kwargs]] = None, | |
| reindex_kwargs: tp.MaybeMappingSequence[tp.Optional[tp.Kwargs]] = None, | |
| merge_kwargs: tp.MaybeMappingSequence[tp.Optional[tp.Kwargs]] = None, | |
| tile: tp.Union[None, int, tp.IndexLike] = None, | |
| random_subset: tp.Optional[int] = None, | |
| seed: tp.Optional[int] = None, | |
| keep_wrap_default: tp.Optional[bool] = None, | |
| return_wrapper: bool = False, | |
| wrapper_kwargs: tp.KwargsLike = None, | |
| ignore_sr_names: tp.Optional[bool] = None, | |
| ignore_ranges: tp.Optional[bool] = None, | |
| check_index_names: tp.Optional[bool] = None, | |
| clean_index_kwargs: tp.KwargsLike = None, | |
| template_context: tp.KwargsLike = None, | |
| ) -> tp.Any: | |
| """Bring any array-like object in `objs` to the same shape by using NumPy-like broadcasting. | |
| See [Broadcasting](https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). | |
| !!! important | |
| The major difference to NumPy is that one-dimensional arrays will always broadcast against the row axis! | |
| Can broadcast Pandas objects by broadcasting their index/columns with `broadcast_index`. | |
| Args: | |
| *objs: Objects to broadcast. | |
| If the first and only argument is a mapping, will return a dict. | |
| Allows using `BCO`, `Ref`, `Default`, `vectorbtpro.utils.params.Param`, | |
| `vectorbtpro.base.indexing.index_dict`, `vectorbtpro.base.indexing.IdxSetter`, | |
| `vectorbtpro.base.indexing.IdxSetterFactory`, and templates. | |
| If an index dictionary, fills using `vectorbtpro.base.wrapping.ArrayWrapper.fill_and_set`. | |
| to_shape (tuple of int): Target shape. If set, will broadcast every object in `objs` to `to_shape`. | |
| align_index (bool): Whether to align index of Pandas objects using union. | |
| Pass None to use the default. | |
| align_columns (bool): Whether to align columns of Pandas objects using multi-index. | |
| Pass None to use the default. | |
| index_from (any): Broadcasting rule for index. | |
| Pass None to use the default. | |
| columns_from (any): Broadcasting rule for columns. | |
| Pass None to use the default. | |
| to_frame (bool): Whether to convert all Series to DataFrames. | |
| axis (int, sequence or mapping): See `BCO.axis`. | |
| to_pd (bool, sequence or mapping): See `BCO.to_pd`. | |
| If None, converts only if there is at least one Pandas object among them. | |
| keep_flex (bool, sequence or mapping): See `BCO.keep_flex`. | |
| min_ndim (int, sequence or mapping): See `BCO.min_ndim`. | |
| If None, becomes 2 if `keep_flex` is True, otherwise 1. | |
| expand_axis (int, sequence or mapping): See `BCO.expand_axis`. | |
| post_func (callable, sequence or mapping): See `BCO.post_func`. | |
| Applied only when `keep_flex` is False. | |
| require_kwargs (dict, sequence or mapping): See `BCO.require_kwargs`. | |
| This key will be merged with any argument-specific dict. If all keys in the mapping are | |
| arguments of `np.require`, it will be applied to all objects. | |
| reindex_kwargs (dict, sequence or mapping): See `BCO.reindex_kwargs`. | |
| This key will be merged with any argument-specific dict. If all keys in the mapping are | |
| arguments of `pd.DataFrame.reindex`, it will be applied to all objects. | |
| merge_kwargs (dict, sequence or mapping): See `BCO.merge_kwargs`. | |
| This key will be merged with any argument-specific dict. If all keys in the mapping are | |
| arguments of `pd.DataFrame.merge`, it will be applied to all objects. | |
| tile (int or index_like): Tile the final object the given number of times or by the given index. | |
| random_subset (int): Select a random subset of parameter values. | |
| Seed can be set using NumPy before calling this function. | |
| seed (int): Seed to make output deterministic. | |
| keep_wrap_default (bool): Whether to keep wrapping with `vectorbtpro.base.reshaping.Default`. | |
| return_wrapper (bool): Whether to also return the wrapper associated with the operation. | |
| wrapper_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.wrapping.ArrayWrapper`. | |
| ignore_sr_names (bool): See `broadcast_index`. | |
| ignore_ranges (bool): See `broadcast_index`. | |
| check_index_names (bool): See `broadcast_index`. | |
| clean_index_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.indexes.clean_index`. | |
| template_context (dict): Context used to substitute templates. | |
| For defaults, see `vectorbtpro._settings.broadcasting`. | |
| Any keyword argument that can be associated with an object can be passed as | |
| * a const that is applied to all objects, | |
| * a sequence with value per object, and | |
| * a mapping with value per object name and the special key `_def` denoting the default value. | |
| Additionally, any object can be passed wrapped with `BCO`, whose attributes will override | |
| any of the above arguments if not None. | |
| Usage: | |
| * Without broadcasting index and columns: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> v = 0 | |
| >>> a = np.array([1, 2, 3]) | |
| >>> sr = pd.Series([1, 2, 3], index=pd.Index(['x', 'y', 'z']), name='a') | |
| >>> df = pd.DataFrame( | |
| ... [[1, 2, 3], [4, 5, 6], [7, 8, 9]], | |
| ... index=pd.Index(['x2', 'y2', 'z2']), | |
| ... columns=pd.Index(['a2', 'b2', 'c2']), | |
| ... ) | |
| >>> for i in vbt.broadcast( | |
| ... v, a, sr, df, | |
| ... index_from='keep', | |
| ... columns_from='keep', | |
| ... align_index=False | |
| ... ): print(i) | |
| 0 1 2 | |
| 0 0 0 0 | |
| 1 0 0 0 | |
| 2 0 0 0 | |
| 0 1 2 | |
| 0 1 2 3 | |
| 1 1 2 3 | |
| 2 1 2 3 | |
| a a a | |
| x 1 1 1 | |
| y 2 2 2 | |
| z 3 3 3 | |
| a2 b2 c2 | |
| x2 1 2 3 | |
| y2 4 5 6 | |
| z2 7 8 9 | |
| ``` | |
| * Take index and columns from the argument at specific position: | |
| ```pycon | |
| >>> for i in vbt.broadcast( | |
| ... v, a, sr, df, | |
| ... index_from=2, | |
| ... columns_from=3, | |
| ... align_index=False | |
| ... ): print(i) | |
| a2 b2 c2 | |
| x 0 0 0 | |
| y 0 0 0 | |
| z 0 0 0 | |
| a2 b2 c2 | |
| x 1 2 3 | |
| y 1 2 3 | |
| z 1 2 3 | |
| a2 b2 c2 | |
| x 1 1 1 | |
| y 2 2 2 | |
| z 3 3 3 | |
| a2 b2 c2 | |
| x 1 2 3 | |
| y 4 5 6 | |
| z 7 8 9 | |
| ``` | |
| * Broadcast index and columns through stacking: | |
| ```pycon | |
| >>> for i in vbt.broadcast( | |
| ... v, a, sr, df, | |
| ... index_from='stack', | |
| ... columns_from='stack', | |
| ... align_index=False | |
| ... ): print(i) | |
| a2 b2 c2 | |
| x x2 0 0 0 | |
| y y2 0 0 0 | |
| z z2 0 0 0 | |
| a2 b2 c2 | |
| x x2 1 2 3 | |
| y y2 1 2 3 | |
| z z2 1 2 3 | |
| a2 b2 c2 | |
| x x2 1 1 1 | |
| y y2 2 2 2 | |
| z z2 3 3 3 | |
| a2 b2 c2 | |
| x x2 1 2 3 | |
| y y2 4 5 6 | |
| z z2 7 8 9 | |
| ``` | |
| * Set index and columns manually: | |
| ```pycon | |
| >>> for i in vbt.broadcast( | |
| ... v, a, sr, df, | |
| ... index_from=['a', 'b', 'c'], | |
| ... columns_from=['d', 'e', 'f'], | |
| ... align_index=False | |
| ... ): print(i) | |
| d e f | |
| a 0 0 0 | |
| b 0 0 0 | |
| c 0 0 0 | |
| d e f | |
| a 1 2 3 | |
| b 1 2 3 | |
| c 1 2 3 | |
| d e f | |
| a 1 1 1 | |
| b 2 2 2 | |
| c 3 3 3 | |
| d e f | |
| a 1 2 3 | |
| b 4 5 6 | |
| c 7 8 9 | |
| ``` | |
| * Passing arguments as a mapping returns a mapping: | |
| ```pycon | |
| >>> vbt.broadcast( | |
| ... dict(v=v, a=a, sr=sr, df=df), | |
| ... index_from='stack', | |
| ... align_index=False | |
| ... ) | |
| {'v': a2 b2 c2 | |
| x x2 0 0 0 | |
| y y2 0 0 0 | |
| z z2 0 0 0, | |
| 'a': a2 b2 c2 | |
| x x2 1 2 3 | |
| y y2 1 2 3 | |
| z z2 1 2 3, | |
| 'sr': a2 b2 c2 | |
| x x2 1 1 1 | |
| y y2 2 2 2 | |
| z z2 3 3 3, | |
| 'df': a2 b2 c2 | |
| x x2 1 2 3 | |
| y y2 4 5 6 | |
| z z2 7 8 9} | |
| ``` | |
| * Keep all results in a format suitable for flexible indexing apart from one: | |
| ```pycon | |
| >>> vbt.broadcast( | |
| ... dict(v=v, a=a, sr=sr, df=df), | |
| ... index_from='stack', | |
| ... keep_flex=dict(_def=True, df=False), | |
| ... require_kwargs=dict(df=dict(dtype=float)), | |
| ... align_index=False | |
| ... ) | |
| {'v': array([0]), | |
| 'a': array([1, 2, 3]), | |
| 'sr': array([[1], | |
| [2], | |
| [3]]), | |
| 'df': a2 b2 c2 | |
| x x2 1.0 2.0 3.0 | |
| y y2 4.0 5.0 6.0 | |
| z z2 7.0 8.0 9.0} | |
| ``` | |
| * Specify arguments per object using `BCO`: | |
| ```pycon | |
| >>> df_bco = vbt.BCO(df, keep_flex=False, require_kwargs=dict(dtype=float)) | |
| >>> vbt.broadcast( | |
| ... dict(v=v, a=a, sr=sr, df=df_bco), | |
| ... index_from='stack', | |
| ... keep_flex=True, | |
| ... align_index=False | |
| ... ) | |
| {'v': array([0]), | |
| 'a': array([1, 2, 3]), | |
| 'sr': array([[1], | |
| [2], | |
| [3]]), | |
| 'df': a2 b2 c2 | |
| x x2 1.0 2.0 3.0 | |
| y y2 4.0 5.0 6.0 | |
| z z2 7.0 8.0 9.0} | |
| ``` | |
| * Introduce a parameter that should build a Cartesian product of its values and other objects: | |
| ```pycon | |
| >>> df_bco = vbt.BCO(df, keep_flex=False, require_kwargs=dict(dtype=float)) | |
| >>> p_bco = vbt.BCO(vbt.Param([1, 2, 3], name='my_p')) | |
| >>> vbt.broadcast( | |
| ... dict(v=v, a=a, sr=sr, df=df_bco, p=p_bco), | |
| ... index_from='stack', | |
| ... keep_flex=True, | |
| ... align_index=False | |
| ... ) | |
| {'v': array([0]), | |
| 'a': array([1, 2, 3, 1, 2, 3, 1, 2, 3]), | |
| 'sr': array([[1], | |
| [2], | |
| [3]]), | |
| 'df': my_p 1 2 3 | |
| a2 b2 c2 a2 b2 c2 a2 b2 c2 | |
| x x2 1.0 2.0 3.0 1.0 2.0 3.0 1.0 2.0 3.0 | |
| y y2 4.0 5.0 6.0 4.0 5.0 6.0 4.0 5.0 6.0 | |
| z z2 7.0 8.0 9.0 7.0 8.0 9.0 7.0 8.0 9.0, | |
| 'p': array([[1, 1, 1, 2, 2, 2, 3, 3, 3], | |
| [1, 1, 1, 2, 2, 2, 3, 3, 3], | |
| [1, 1, 1, 2, 2, 2, 3, 3, 3]])} | |
| ``` | |
| * Build a Cartesian product of all parameters: | |
| ```pycon | |
| >>> vbt.broadcast( | |
| ... dict( | |
| ... a=vbt.Param([1, 2, 3]), | |
| ... b=vbt.Param(['x', 'y']), | |
| ... c=vbt.Param([False, True]) | |
| ... ) | |
| ... ) | |
| {'a': array([[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]]), | |
| 'b': array([['x', 'x', 'y', 'y', 'x', 'x', 'y', 'y', 'x', 'x', 'y', 'y']], dtype='<U1'), | |
| 'c': array([[False, True, False, True, False, True, False, True, False, True, False, True]])} | |
| ``` | |
| * Build a Cartesian product of two groups of parameters - (a, d) and (b, c): | |
| ```pycon | |
| >>> vbt.broadcast( | |
| ... dict( | |
| ... a=vbt.Param([1, 2, 3], level=0), | |
| ... b=vbt.Param(['x', 'y'], level=1), | |
| ... d=vbt.Param([100., 200., 300.], level=0), | |
| ... c=vbt.Param([False, True], level=1) | |
| ... ) | |
| ... ) | |
| {'a': array([[1, 1, 2, 2, 3, 3]]), | |
| 'b': array([['x', 'y', 'x', 'y', 'x', 'y']], dtype='<U1'), | |
| 'd': array([[100., 100., 200., 200., 300., 300.]]), | |
| 'c': array([[False, True, False, True, False, True]])} | |
| ``` | |
| * Select a random subset of parameter combinations: | |
| ```pycon | |
| >>> vbt.broadcast( | |
| ... dict( | |
| ... a=vbt.Param([1, 2, 3]), | |
| ... b=vbt.Param(['x', 'y']), | |
| ... c=vbt.Param([False, True]) | |
| ... ), | |
| ... random_subset=5, | |
| ... seed=42 | |
| ... ) | |
| {'a': array([[1, 2, 3, 3, 3]]), | |
| 'b': array([['x', 'x', 'x', 'x', 'y']], dtype='<U1'), | |
| 'c': array([[False, True, False, True, False]])} | |
| ``` | |
| """ | |
| # Get defaults | |
| from vectorbtpro._settings import settings | |
| broadcasting_cfg = settings["broadcasting"] | |
| if align_index is None: | |
| align_index = broadcasting_cfg["align_index"] | |
| if align_columns is None: | |
| align_columns = broadcasting_cfg["align_columns"] | |
| if index_from is None: | |
| index_from = broadcasting_cfg["index_from"] | |
| if columns_from is None: | |
| columns_from = broadcasting_cfg["columns_from"] | |
| if keep_wrap_default is None: | |
| keep_wrap_default = broadcasting_cfg["keep_wrap_default"] | |
| require_kwargs_per_obj = True | |
| if require_kwargs is not None and checks.is_mapping(require_kwargs): | |
| require_arg_names = get_func_arg_names(np.require) | |
| if set(require_kwargs) <= set(require_arg_names): | |
| require_kwargs_per_obj = False | |
| reindex_kwargs_per_obj = True | |
| if reindex_kwargs is not None and checks.is_mapping(reindex_kwargs): | |
| reindex_arg_names = get_func_arg_names(pd.DataFrame.reindex) | |
| if set(reindex_kwargs) <= set(reindex_arg_names): | |
| reindex_kwargs_per_obj = False | |
| merge_kwargs_per_obj = True | |
| if merge_kwargs is not None and checks.is_mapping(merge_kwargs): | |
| merge_arg_names = get_func_arg_names(pd.DataFrame.merge) | |
| if set(merge_kwargs) <= set(merge_arg_names): | |
| merge_kwargs_per_obj = False | |
| if clean_index_kwargs is None: | |
| clean_index_kwargs = {} | |
| if checks.is_mapping(objs[0]) and not isinstance(objs[0], indexing.index_dict): | |
| if len(objs) > 1: | |
| raise ValueError("Only one argument is allowed when passing a mapping") | |
| all_keys = list(dict(objs[0]).keys()) | |
| objs = list(objs[0].values()) | |
| return_dict = True | |
| else: | |
| objs = list(objs) | |
| all_keys = list(range(len(objs))) | |
| return_dict = False | |
| def _resolve_arg(obj: tp.Any, arg_name: str, global_value: tp.Any, default_value: tp.Any) -> tp.Any: | |
| if isinstance(obj, BCO) and getattr(obj, arg_name) is not None: | |
| return getattr(obj, arg_name) | |
| if checks.is_mapping(global_value): | |
| return global_value.get(k, global_value.get("_def", default_value)) | |
| if checks.is_sequence(global_value): | |
| return global_value[i] | |
| return global_value | |
| # Build BCO instances | |
| none_keys = set() | |
| default_keys = set() | |
| param_keys = set() | |
| special_keys = set() | |
| bco_instances = {} | |
| pool = dict(zip(all_keys, objs)) | |
| for i, k in enumerate(all_keys): | |
| obj = objs[i] | |
| if isinstance(obj, Default): | |
| obj = obj.value | |
| default_keys.add(k) | |
| if isinstance(obj, Ref): | |
| obj = resolve_ref(pool, k) | |
| if isinstance(obj, BCO): | |
| value = obj.value | |
| else: | |
| value = obj | |
| if isinstance(value, Default): | |
| value = value.value | |
| default_keys.add(k) | |
| if isinstance(value, Ref): | |
| value = resolve_ref(pool, k, inside_bco=True) | |
| if value is None: | |
| none_keys.add(k) | |
| continue | |
| _axis = _resolve_arg(obj, "axis", axis, None) | |
| _to_pd = _resolve_arg(obj, "to_pd", to_pd, None) | |
| _keep_flex = _resolve_arg(obj, "keep_flex", keep_flex, None) | |
| if _keep_flex is None: | |
| _keep_flex = broadcasting_cfg["keep_flex"] | |
| _min_ndim = _resolve_arg(obj, "min_ndim", min_ndim, None) | |
| if _min_ndim is None: | |
| _min_ndim = broadcasting_cfg["min_ndim"] | |
| _expand_axis = _resolve_arg(obj, "expand_axis", expand_axis, None) | |
| if _expand_axis is None: | |
| _expand_axis = broadcasting_cfg["expand_axis"] | |
| _post_func = _resolve_arg(obj, "post_func", post_func, None) | |
| if isinstance(obj, BCO) and obj.require_kwargs is not None: | |
| _require_kwargs = obj.require_kwargs | |
| else: | |
| _require_kwargs = None | |
| if checks.is_mapping(require_kwargs) and require_kwargs_per_obj: | |
| _require_kwargs = merge_dicts( | |
| require_kwargs.get("_def", None), | |
| require_kwargs.get(k, None), | |
| _require_kwargs, | |
| ) | |
| elif checks.is_sequence(require_kwargs) and require_kwargs_per_obj: | |
| _require_kwargs = merge_dicts(require_kwargs[i], _require_kwargs) | |
| else: | |
| _require_kwargs = merge_dicts(require_kwargs, _require_kwargs) | |
| if isinstance(obj, BCO) and obj.reindex_kwargs is not None: | |
| _reindex_kwargs = obj.reindex_kwargs | |
| else: | |
| _reindex_kwargs = None | |
| if checks.is_mapping(reindex_kwargs) and reindex_kwargs_per_obj: | |
| _reindex_kwargs = merge_dicts( | |
| reindex_kwargs.get("_def", None), | |
| reindex_kwargs.get(k, None), | |
| _reindex_kwargs, | |
| ) | |
| elif checks.is_sequence(reindex_kwargs) and reindex_kwargs_per_obj: | |
| _reindex_kwargs = merge_dicts(reindex_kwargs[i], _reindex_kwargs) | |
| else: | |
| _reindex_kwargs = merge_dicts(reindex_kwargs, _reindex_kwargs) | |
| if isinstance(obj, BCO) and obj.merge_kwargs is not None: | |
| _merge_kwargs = obj.merge_kwargs | |
| else: | |
| _merge_kwargs = None | |
| if checks.is_mapping(merge_kwargs) and merge_kwargs_per_obj: | |
| _merge_kwargs = merge_dicts( | |
| merge_kwargs.get("_def", None), | |
| merge_kwargs.get(k, None), | |
| _merge_kwargs, | |
| ) | |
| elif checks.is_sequence(merge_kwargs) and merge_kwargs_per_obj: | |
| _merge_kwargs = merge_dicts(merge_kwargs[i], _merge_kwargs) | |
| else: | |
| _merge_kwargs = merge_dicts(merge_kwargs, _merge_kwargs) | |
| if isinstance(obj, BCO): | |
| _context = merge_dicts(template_context, obj.context) | |
| else: | |
| _context = template_context | |
| if isinstance(value, Param): | |
| param_keys.add(k) | |
| elif isinstance(value, (indexing.index_dict, indexing.IdxSetter, indexing.IdxSetterFactory, CustomTemplate)): | |
| special_keys.add(k) | |
| else: | |
| value = to_any_array(value) | |
| bco_instances[k] = BCO( | |
| value, | |
| axis=_axis, | |
| to_pd=_to_pd, | |
| keep_flex=_keep_flex, | |
| min_ndim=_min_ndim, | |
| expand_axis=_expand_axis, | |
| post_func=_post_func, | |
| require_kwargs=_require_kwargs, | |
| reindex_kwargs=_reindex_kwargs, | |
| merge_kwargs=_merge_kwargs, | |
| context=_context, | |
| ) | |
| # Check whether we should broadcast Pandas metadata and work on 2-dim data | |
| is_pd = False | |
| is_2d = False | |
| old_objs = {} | |
| obj_axis = {} | |
| obj_reindex_kwargs = {} | |
| for k, bco_obj in bco_instances.items(): | |
| if k in none_keys or k in param_keys or k in special_keys: | |
| continue | |
| obj = bco_obj.value | |
| if obj.ndim > 1: | |
| is_2d = True | |
| if checks.is_pandas(obj): | |
| is_pd = True | |
| if bco_obj.to_pd is not None and bco_obj.to_pd: | |
| is_pd = True | |
| old_objs[k] = obj | |
| obj_axis[k] = bco_obj.axis | |
| obj_reindex_kwargs[k] = bco_obj.reindex_kwargs | |
| if to_shape is not None: | |
| if isinstance(to_shape, int): | |
| to_shape = (to_shape,) | |
| if len(to_shape) > 1: | |
| is_2d = True | |
| if to_frame is not None: | |
| is_2d = to_frame | |
| if to_pd is not None: | |
| is_pd = to_pd or (return_wrapper and is_pd) | |
| # Align pandas arrays | |
| if index_from is not None and not isinstance(index_from, (int, str, pd.Index)): | |
| index_from = pd.Index(index_from) | |
| if columns_from is not None and not isinstance(columns_from, (int, str, pd.Index)): | |
| columns_from = pd.Index(columns_from) | |
| aligned_objs = align_pd_arrays( | |
| *old_objs.values(), | |
| align_index=align_index, | |
| align_columns=align_columns, | |
| to_index=index_from if isinstance(index_from, pd.Index) else None, | |
| to_columns=columns_from if isinstance(columns_from, pd.Index) else None, | |
| axis=list(obj_axis.values()), | |
| reindex_kwargs=list(obj_reindex_kwargs.values()), | |
| ) | |
| if not isinstance(aligned_objs, tuple): | |
| aligned_objs = (aligned_objs,) | |
| aligned_objs = dict(zip(old_objs.keys(), aligned_objs)) | |
| # Convert to NumPy | |
| ready_objs = {} | |
| for k, obj in aligned_objs.items(): | |
| _expand_axis = bco_instances[k].expand_axis | |
| new_obj = np.asarray(obj) | |
| if is_2d and new_obj.ndim == 1: | |
| if isinstance(obj, pd.Series): | |
| new_obj = new_obj[:, None] | |
| else: | |
| new_obj = np.expand_dims(new_obj, _expand_axis) | |
| ready_objs[k] = new_obj | |
| # Get final shape | |
| if to_shape is None: | |
| try: | |
| to_shape = broadcast_shapes( | |
| *map(lambda x: x.shape, ready_objs.values()), | |
| axis=list(obj_axis.values()), | |
| ) | |
| except ValueError: | |
| arr_shapes = {} | |
| for i, k in enumerate(bco_instances): | |
| if k in none_keys or k in param_keys or k in special_keys: | |
| continue | |
| if len(ready_objs[k].shape) > 0: | |
| arr_shapes[k] = ready_objs[k].shape | |
| raise ValueError("Could not broadcast shapes: %s" % str(arr_shapes)) | |
| if not isinstance(to_shape, tuple): | |
| to_shape = (to_shape,) | |
| if len(to_shape) == 0: | |
| to_shape = (1,) | |
| to_shape_2d = to_shape if len(to_shape) > 1 else (*to_shape, 1) | |
| if is_pd: | |
| # Decide on index and columns | |
| # NOTE: Important to pass aligned_objs, not ready_objs, to preserve original shape info | |
| new_index = broadcast_index( | |
| [v for k, v in aligned_objs.items() if obj_axis[k] in (None, 0)], | |
| to_shape, | |
| index_from=index_from, | |
| axis=0, | |
| ignore_sr_names=ignore_sr_names, | |
| ignore_ranges=ignore_ranges, | |
| check_index_names=check_index_names, | |
| **clean_index_kwargs, | |
| ) | |
| new_columns = broadcast_index( | |
| [v for k, v in aligned_objs.items() if obj_axis[k] in (None, 1)], | |
| to_shape, | |
| index_from=columns_from, | |
| axis=1, | |
| ignore_sr_names=ignore_sr_names, | |
| ignore_ranges=ignore_ranges, | |
| check_index_names=check_index_names, | |
| **clean_index_kwargs, | |
| ) | |
| else: | |
| new_index = pd.RangeIndex(stop=to_shape_2d[0]) | |
| new_columns = pd.RangeIndex(stop=to_shape_2d[1]) | |
| # Build a product | |
| param_product = None | |
| param_columns = None | |
| n_params = 0 | |
| if len(param_keys) > 0: | |
| # Combine parameters | |
| param_dct = {} | |
| for k, bco_obj in bco_instances.items(): | |
| if k not in param_keys: | |
| continue | |
| param_dct[k] = bco_obj.value | |
| param_product, param_columns = combine_params( | |
| param_dct, | |
| random_subset=random_subset, | |
| seed=seed, | |
| clean_index_kwargs=clean_index_kwargs, | |
| ) | |
| n_params = len(param_columns) | |
| # Combine parameter columns with new columns | |
| if param_columns is not None and new_columns is not None: | |
| new_columns = indexes.combine_indexes([param_columns, new_columns], **clean_index_kwargs) | |
| # Tile | |
| if tile is not None: | |
| if isinstance(tile, int): | |
| if new_columns is not None: | |
| new_columns = indexes.tile_index(new_columns, tile) | |
| else: | |
| if new_columns is not None: | |
| new_columns = indexes.combine_indexes([tile, new_columns], **clean_index_kwargs) | |
| tile = len(tile) | |
| n_params = max(n_params, 1) * tile | |
| # Build wrapper | |
| if n_params == 0: | |
| new_shape = to_shape | |
| else: | |
| new_shape = (to_shape_2d[0], to_shape_2d[1] * n_params) | |
| wrapper = wrapping.ArrayWrapper.from_shape( | |
| new_shape, | |
| **merge_dicts( | |
| dict( | |
| index=new_index, | |
| columns=new_columns, | |
| ), | |
| wrapper_kwargs, | |
| ), | |
| ) | |
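| # Helper: ensure each broadcast result has at least `min_ndim` dimensions, expanding | |
| # 1-dim arrays into a column (1-dim target shape) or along `expand_axis` (2-dim target shape) | |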
| def _adjust_dims(new_obj, _keep_flex, _min_ndim, _expand_axis): | |
| if _min_ndim is None: | |
| if _keep_flex: | |
| _min_ndim = 2 | |
| else: | |
| _min_ndim = 1 | |
| if _min_ndim not in (1, 2): | |
| raise ValueError("Argument min_ndim must be either 1 or 2") | |
| if _min_ndim in (1, 2) and new_obj.ndim == 0: | |
| new_obj = new_obj[None] | |
| if _min_ndim == 2 and new_obj.ndim == 1: | |
| if len(to_shape) == 1: | |
| new_obj = new_obj[:, None] | |
| else: | |
| new_obj = np.expand_dims(new_obj, _expand_axis) | |
| return new_obj | |
| # Perform broadcasting | |
| aligned_objs2 = {} | |
| new_objs = {} | |
| for i, k in enumerate(all_keys): | |
| if k in none_keys or k in special_keys: | |
| continue | |
| _keep_flex = bco_instances[k].keep_flex | |
| _min_ndim = bco_instances[k].min_ndim | |
| _axis = bco_instances[k].axis | |
| _expand_axis = bco_instances[k].expand_axis | |
| _merge_kwargs = bco_instances[k].merge_kwargs | |
| _context = bco_instances[k].context | |
| must_reset_index = _merge_kwargs.get("reset_index", None) not in (None, False) | |
| _reindex_kwargs = resolve_dict(bco_instances[k].reindex_kwargs) | |
| _fill_value = _reindex_kwargs.get("fill_value", np.nan) | |
| if k in param_keys: | |
| # Broadcast parameters | |
| from vectorbtpro.base.merging import column_stack_merge | |
| if _axis == 0: | |
| raise ValueError("Parameters do not support broadcasting with axis=0") | |
| obj = param_product[k] | |
| new_obj = [] | |
| any_needs_broadcasting = False | |
| all_forced_broadcast = True | |
| for o in obj: | |
| if isinstance(o, (indexing.index_dict, indexing.IdxSetter, indexing.IdxSetterFactory)): | |
| o = wrapper.fill_and_set( | |
| o, | |
| fill_value=_fill_value, | |
| keep_flex=_keep_flex, | |
| ) | |
| elif isinstance(o, CustomTemplate): | |
| context = merge_dicts( | |
| dict( | |
| bco_instances=bco_instances, | |
| wrapper=wrapper, | |
| obj_name=k, | |
| bco=bco_instances[k], | |
| ), | |
| _context, | |
| ) | |
| o = o.substitute(context, eval_id="broadcast") | |
| o = to_2d_array(o) | |
| if not _keep_flex: | |
| needs_broadcasting = True | |
| elif o.shape[0] > 1: | |
| needs_broadcasting = True | |
| elif o.shape[1] > 1 and o.shape[1] != to_shape_2d[1]: | |
| needs_broadcasting = True | |
| else: | |
| needs_broadcasting = False | |
| if needs_broadcasting: | |
| any_needs_broadcasting = True | |
| o = broadcast_array_to(o, to_shape_2d, axis=_axis) | |
| elif o.size == 1: | |
| all_forced_broadcast = False | |
| o = np.repeat(o, to_shape_2d[1], axis=1) | |
| else: | |
| all_forced_broadcast = False | |
| new_obj.append(o) | |
| if any_needs_broadcasting and not all_forced_broadcast: | |
| new_obj2 = [] | |
| for o in new_obj: | |
| if o.shape[1] != to_shape_2d[1] or (not must_reset_index and o.shape[0] != to_shape_2d[0]): | |
| o = broadcast_array_to(o, to_shape_2d, axis=_axis) | |
| new_obj2.append(o) | |
| new_obj = new_obj2 | |
| obj = column_stack_merge(new_obj, **_merge_kwargs) | |
| if tile is not None: | |
| obj = np.tile(obj, (1, tile)) | |
| old_obj = obj | |
| new_obj = obj | |
| else: | |
| # Broadcast regular objects | |
| old_obj = aligned_objs[k] | |
| new_obj = ready_objs[k] | |
| if _axis in (None, 0) and new_obj.ndim >= 1 and new_obj.shape[0] > 1 and new_obj.shape[0] != to_shape[0]: | |
| raise ValueError(f"Could not broadcast argument {k} of shape {new_obj.shape} to {to_shape}") | |
| if _axis in (None, 1) and new_obj.ndim == 2 and new_obj.shape[1] > 1 and new_obj.shape[1] != to_shape[1]: | |
| raise ValueError(f"Could not broadcast argument {k} of shape {new_obj.shape} to {to_shape}") | |
| if _keep_flex: | |
| if n_params > 0 and _axis in (None, 1): | |
| if len(to_shape) == 1: | |
| if new_obj.ndim == 1 and new_obj.shape[0] > 1: | |
| new_obj = new_obj[:, None] # product changes is_2d behavior | |
| else: | |
| if new_obj.ndim == 1 and new_obj.shape[0] > 1: | |
| new_obj = np.tile(new_obj, n_params) | |
| elif new_obj.ndim == 2 and new_obj.shape[1] > 1: | |
| new_obj = np.tile(new_obj, (1, n_params)) | |
| else: | |
| new_obj = broadcast_array_to(new_obj, to_shape, axis=_axis) | |
| if n_params > 0 and _axis in (None, 1): | |
| if new_obj.ndim == 1: | |
| new_obj = new_obj[:, None] # product changes is_2d behavior | |
| new_obj = np.tile(new_obj, (1, n_params)) | |
| new_obj = _adjust_dims(new_obj, _keep_flex, _min_ndim, _expand_axis) | |
| aligned_objs2[k] = old_obj | |
| new_objs[k] = new_obj | |
| # Resolve special objects | |
| new_objs2 = {} | |
| for i, k in enumerate(all_keys): | |
| if k in none_keys: | |
| continue | |
| if k in special_keys: | |
| bco = bco_instances[k] | |
| if isinstance(bco.value, (indexing.index_dict, indexing.IdxSetter, indexing.IdxSetterFactory)): | |
| _is_pd = bco.to_pd | |
| if _is_pd is None: | |
| _is_pd = is_pd | |
| _keep_flex = bco.keep_flex | |
| _min_ndim = bco.min_ndim | |
| _expand_axis = bco.expand_axis | |
| _reindex_kwargs = resolve_dict(bco.reindex_kwargs) | |
| _fill_value = _reindex_kwargs.get("fill_value", np.nan) | |
| new_obj = wrapper.fill_and_set( | |
| bco.value, | |
| fill_value=_fill_value, | |
| keep_flex=_keep_flex, | |
| ) | |
| if not _is_pd and not _keep_flex: | |
| new_obj = new_obj.values | |
| new_obj = _adjust_dims(new_obj, _keep_flex, _min_ndim, _expand_axis) | |
| elif isinstance(bco.value, CustomTemplate): | |
| context = merge_dicts( | |
| dict( | |
| bco_instances=bco_instances, | |
| new_objs=new_objs, | |
| wrapper=wrapper, | |
| obj_name=k, | |
| bco=bco, | |
| ), | |
| bco.context, | |
| ) | |
| new_obj = bco.value.substitute(context, eval_id="broadcast") | |
| else: | |
| raise TypeError(f"Special type {type(bco.value)} is not supported") | |
| else: | |
| new_obj = new_objs[k] | |
| # Force to match requirements | |
| new_obj = np.require(new_obj, **resolve_dict(bco_instances[k].require_kwargs)) | |
| new_objs2[k] = new_obj | |
| # Perform wrapping and post-processing | |
| new_objs3 = {} | |
| for i, k in enumerate(all_keys): | |
| if k in none_keys: | |
| continue | |
| new_obj = new_objs2[k] | |
| _axis = bco_instances[k].axis | |
| _keep_flex = bco_instances[k].keep_flex | |
| if not _keep_flex: | |
| # Wrap array | |
| _is_pd = bco_instances[k].to_pd | |
| if _is_pd is None: | |
| _is_pd = is_pd | |
| new_obj = wrap_broadcasted( | |
| new_obj, | |
| old_obj=aligned_objs2[k] if k not in special_keys else None, | |
| axis=_axis, | |
| is_pd=_is_pd, | |
| new_index=new_index, | |
| new_columns=new_columns, | |
| ignore_ranges=ignore_ranges, | |
| ) | |
| # Post-process array | |
| _post_func = bco_instances[k].post_func | |
| if _post_func is not None: | |
| new_obj = _post_func(new_obj) | |
| new_objs3[k] = new_obj | |
| # Prepare outputs | |
| return_objs = [] | |
| for k in all_keys: | |
| if k not in none_keys: | |
| if k in default_keys and keep_wrap_default: | |
| return_objs.append(Default(new_objs3[k])) | |
| else: | |
| return_objs.append(new_objs3[k]) | |
| else: | |
| if k in default_keys and keep_wrap_default: | |
| return_objs.append(Default(None)) | |
| else: | |
| return_objs.append(None) | |
| if return_dict: | |
| return_objs = dict(zip(all_keys, return_objs)) | |
| else: | |
| return_objs = tuple(return_objs) | |
| if len(return_objs) > 1 or return_dict: | |
| if return_wrapper: | |
| return return_objs, wrapper | |
| return return_objs | |
| if return_wrapper: | |
| return return_objs[0], wrapper | |
| return return_objs[0] | |
| def broadcast_to( | |
| arg1: tp.ArrayLike, | |
| arg2: tp.Union[tp.ArrayLike, tp.ShapeLike, wrapping.ArrayWrapper], | |
| to_pd: tp.Optional[bool] = None, | |
| index_from: tp.Optional[IndexFromLike] = None, | |
| columns_from: tp.Optional[IndexFromLike] = None, | |
| **kwargs, | |
| ) -> tp.Any: | |
| """Broadcast `arg1` to `arg2`. | |
| Argument `arg2` can be a shape, an instance of `vectorbtpro.base.wrapping.ArrayWrapper`, | |
| or any array-like object. | |
| Pass None to `index_from`/`columns_from` to use index/columns of the second argument. | |
| Keyword arguments `**kwargs` are passed to `broadcast`. | |
| Usage: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> from vectorbtpro.base.reshaping import broadcast_to | |
| >>> a = np.array([1, 2, 3]) | |
| >>> sr = pd.Series([4, 5, 6], index=pd.Index(['x', 'y', 'z']), name='a') | |
| >>> broadcast_to(a, sr) | |
| x 1 | |
| y 2 | |
| z 3 | |
| Name: a, dtype: int64 | |
| >>> broadcast_to(sr, a) | |
| array([4, 5, 6]) | |
| ``` | |
| """ | |
| if checks.is_int(arg2) or isinstance(arg2, tuple): | |
| arg2 = to_tuple_shape(arg2) | |
| if isinstance(arg2, tuple): | |
| to_shape = arg2 | |
| elif isinstance(arg2, wrapping.ArrayWrapper): | |
| to_pd = True | |
| if index_from is None: | |
| index_from = arg2.index | |
| if columns_from is None: | |
| columns_from = arg2.columns | |
| to_shape = arg2.shape | |
| else: | |
| arg2 = to_any_array(arg2) | |
| if to_pd is None: | |
| to_pd = checks.is_pandas(arg2) | |
| if to_pd: | |
| # Take index and columns from arg2 | |
| if index_from is None: | |
| index_from = indexes.get_index(arg2, 0) | |
| if columns_from is None: | |
| columns_from = indexes.get_index(arg2, 1) | |
| to_shape = arg2.shape | |
| return broadcast( | |
| arg1, | |
| to_shape=to_shape, | |
| to_pd=to_pd, | |
| index_from=index_from, | |
| columns_from=columns_from, | |
| **kwargs, | |
| ) | |
| def broadcast_to_array_of(arg1: tp.ArrayLike, arg2: tp.ArrayLike) -> tp.Array: | |
| """Broadcast `arg1` to the shape `(1, *arg2.shape)`. | |
| `arg1` must be either a scalar, a 1-dim array, or have 1 dimension more than `arg2`. | |
| Usage: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> from vectorbtpro.base.reshaping import broadcast_to_array_of | |
| >>> broadcast_to_array_of([0.1, 0.2], np.empty((2, 2))) | |
| [[[0.1 0.1] | |
| [0.1 0.1]] | |
| [[0.2 0.2] | |
| [0.2 0.2]]] | |
| ``` | |
| """ | |
| arg1 = np.asarray(arg1) | |
| arg2 = np.asarray(arg2) | |
| if arg1.ndim == arg2.ndim + 1: | |
| if arg1.shape[1:] == arg2.shape: | |
| return arg1 | |
| # From here on arg1 can be only a 1-dim array | |
| if arg1.ndim == 0: | |
| arg1 = to_1d(arg1) | |
| checks.assert_ndim(arg1, 1) | |
| if arg2.ndim == 0: | |
| return arg1 | |
| for i in range(arg2.ndim): | |
| arg1 = np.expand_dims(arg1, axis=-1) | |
| return np.tile(arg1, (1, *arg2.shape)) | |
| def broadcast_to_axis_of( | |
| arg1: tp.ArrayLike, | |
| arg2: tp.ArrayLike, | |
| axis: int, | |
| require_kwargs: tp.KwargsLike = None, | |
| ) -> tp.Array: | |
| """Broadcast `arg1` to an axis of `arg2`. | |
| If `arg2` has fewer dimensions than requested, `arg1` will be broadcast to a single number. | |
| For other keyword arguments, see `broadcast`.""" | |
| if require_kwargs is None: | |
| require_kwargs = {} | |
| arg2 = to_any_array(arg2) | |
| if arg2.ndim < axis + 1: | |
| return broadcast_array_to(arg1, (1,))[0] # to a single number | |
| arg1 = broadcast_array_to(arg1, (arg2.shape[axis],)) | |
| arg1 = np.require(arg1, **require_kwargs) | |
| return arg1 | |
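| # Illustrative usage (not part of the library source): broadcasting a scalar to | |
| # the column axis of a 2-dim target, assuming plain NumPy inputs: | |
| # >>> broadcast_to_axis_of(0.1, np.empty((5, 3)), axis=1) | |
| # array([0.1, 0.1, 0.1]) | |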
| def broadcast_combs( | |
| *objs: tp.ArrayLike, | |
| axis: int = 1, | |
| comb_func: tp.Callable = itertools.product, | |
| **broadcast_kwargs, | |
| ) -> tp.Any: | |
| """Align an axis of each array using a combinatoric function and broadcast their indexes. | |
| Usage: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> from vectorbtpro.base.reshaping import broadcast_combs | |
| >>> df = pd.DataFrame([[1, 2, 3], [3, 4, 5]], columns=pd.Index(['a', 'b', 'c'], name='df_param')) | |
| >>> df2 = pd.DataFrame([[6, 7], [8, 9]], columns=pd.Index(['d', 'e'], name='df2_param')) | |
| >>> sr = pd.Series([10, 11], name='f') | |
| >>> new_df, new_df2, new_sr = broadcast_combs(df, df2, sr) | |
| >>> new_df | |
| df_param a b c | |
| df2_param d e d e d e | |
| 0 1 1 2 2 3 3 | |
| 1 3 3 4 4 5 5 | |
| >>> new_df2 | |
| df_param a b c | |
| df2_param d e d e d e | |
| 0 6 7 6 7 6 7 | |
| 1 8 9 8 9 8 9 | |
| >>> new_sr | |
| df_param a b c | |
| df2_param d e d e d e | |
| 0 10 10 10 10 10 10 | |
| 1 11 11 11 11 11 11 | |
| ``` | |
| """ | |
| if broadcast_kwargs is None: | |
| broadcast_kwargs = {} | |
| objs = list(objs) | |
| if len(objs) < 2: | |
| raise ValueError("At least two arguments are required") | |
| for i in range(len(objs)): | |
| obj = to_any_array(objs[i]) | |
| if axis == 1: | |
| obj = to_2d(obj) | |
| objs[i] = obj | |
| indices = [] | |
| for obj in objs: | |
| indices.append(np.arange(len(indexes.get_index(to_pd_array(obj), axis)))) | |
| new_indices = list(map(list, zip(*list(comb_func(*indices))))) | |
| results = [] | |
| for i, obj in enumerate(objs): | |
| if axis == 1: | |
| if checks.is_pandas(obj): | |
| results.append(obj.iloc[:, new_indices[i]]) | |
| else: | |
| results.append(obj[:, new_indices[i]]) | |
| else: | |
| if checks.is_pandas(obj): | |
| results.append(obj.iloc[new_indices[i]]) | |
| else: | |
| results.append(obj[new_indices[i]]) | |
| if axis == 1: | |
| broadcast_kwargs = merge_dicts(dict(columns_from="stack"), broadcast_kwargs) | |
| else: | |
| broadcast_kwargs = merge_dicts(dict(index_from="stack"), broadcast_kwargs) | |
| return broadcast(*results, **broadcast_kwargs) | |
| def get_multiindex_series(obj: tp.SeriesFrame) -> tp.Series: | |
| """Get Series with a multi-index. | |
| If DataFrame has been passed, must at maximum have one row or column.""" | |
| checks.assert_instance_of(obj, (pd.Series, pd.DataFrame)) | |
| if checks.is_frame(obj): | |
| if obj.shape[0] == 1: | |
| obj = obj.iloc[0, :] | |
| elif obj.shape[1] == 1: | |
| obj = obj.iloc[:, 0] | |
| else: | |
| raise ValueError("Supported are either Series or DataFrame with one column/row") | |
| checks.assert_instance_of(obj.index, pd.MultiIndex) | |
| return obj | |
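| # Illustrative usage (not part of the library source): a one-column DataFrame with | |
| # a multi-index is squeezed into a Series, assuming the following toy frame: | |
| # >>> index = pd.MultiIndex.from_tuples([(1, 'a'), (2, 'b')]) | |
| # >>> get_multiindex_series(pd.DataFrame({'x': [10, 20]}, index=index)) | |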
| def unstack_to_array( | |
| obj: tp.SeriesFrame, | |
| levels: tp.Optional[tp.MaybeLevelSequence] = None, | |
| sort: bool = True, | |
| return_indexes: bool = False, | |
| ) -> tp.Union[tp.Array, tp.Tuple[tp.Array, tp.List[tp.Index]]]: | |
| """Reshape `obj` based on its multi-index into a multi-dimensional array. | |
| Use `levels` to specify what index levels to unstack and in which order. | |
| Usage: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> from vectorbtpro.base.reshaping import unstack_to_array | |
| >>> index = pd.MultiIndex.from_arrays( | |
| ... [[1, 1, 2, 2], [3, 4, 3, 4], ['a', 'b', 'c', 'd']]) | |
| >>> sr = pd.Series([1, 2, 3, 4], index=index) | |
| >>> unstack_to_array(sr).shape | |
| (2, 2, 4) | |
| >>> unstack_to_array(sr) | |
| [[[ 1. nan nan nan] | |
| [nan 2. nan nan]] | |
| [[nan nan 3. nan] | |
| [nan nan nan 4.]]] | |
| >>> unstack_to_array(sr, levels=(2, 0)) | |
| [[ 1. nan] | |
| [ 2. nan] | |
| [nan 3.] | |
| [nan 4.]] | |
| ``` | |
| """ | |
| sr = get_multiindex_series(obj) | |
| if sr.index.duplicated().any(): | |
| raise ValueError("Index contains duplicate entries, cannot reshape") | |
| new_index_list = [] | |
| value_indices_list = [] | |
| if levels is None: | |
| levels = range(sr.index.nlevels) | |
| if isinstance(levels, (int, str)): | |
| levels = (levels,) | |
| for level in levels: | |
| level_values = indexes.select_levels(sr.index, level) | |
| new_index = level_values.unique() | |
| if sort: | |
| new_index = new_index.sort_values() | |
| new_index_list.append(new_index) | |
| index_map = pd.Series(range(len(new_index)), index=new_index) | |
| value_indices = index_map.loc[level_values] | |
| value_indices_list.append(value_indices) | |
| a = np.full(list(map(len, new_index_list)), np.nan) | |
| a[tuple(zip(value_indices_list))] = sr.values | |
| if return_indexes: | |
| return a, new_index_list | |
| return a | |
| def make_symmetric(obj: tp.SeriesFrame, sort: bool = True) -> tp.Frame: | |
| """Make `obj` symmetric. | |
| The index and columns of the resulting DataFrame will be identical. | |
| Requires the index and columns to have the same number of levels. | |
| Pass `sort=False` if the index and columns should not be sorted, but merely | |
| concatenated with duplicates removed. | |
| Usage: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> from vectorbtpro.base.reshaping import make_symmetric | |
| >>> df = pd.DataFrame([[1, 2], [3, 4]], index=['a', 'b'], columns=['c', 'd']) | |
| >>> make_symmetric(df) | |
| a b c d | |
| a NaN NaN 1.0 2.0 | |
| b NaN NaN 3.0 4.0 | |
| c 1.0 3.0 NaN NaN | |
| d 2.0 4.0 NaN NaN | |
| ``` | |
| """ | |
| from vectorbtpro.base.merging import concat_arrays | |
| checks.assert_instance_of(obj, (pd.Series, pd.DataFrame)) | |
| df = to_2d(obj) | |
| if isinstance(df.index, pd.MultiIndex) or isinstance(df.columns, pd.MultiIndex): | |
| checks.assert_instance_of(df.index, pd.MultiIndex) | |
| checks.assert_instance_of(df.columns, pd.MultiIndex) | |
| checks.assert_array_equal(df.index.nlevels, df.columns.nlevels) | |
| names1, names2 = tuple(df.index.names), tuple(df.columns.names) | |
| else: | |
| names1, names2 = df.index.name, df.columns.name | |
| if names1 == names2: | |
| new_name = names1 | |
| else: | |
| if isinstance(df.index, pd.MultiIndex): | |
| new_name = tuple(zip(*[names1, names2])) | |
| else: | |
| new_name = (names1, names2) | |
| if sort: | |
| idx_vals = np.unique(concat_arrays((df.index, df.columns))).tolist() | |
| else: | |
| idx_vals = list(dict.fromkeys(concat_arrays((df.index, df.columns)))) | |
| df_index = df.index.copy() | |
| df_columns = df.columns.copy() | |
| if isinstance(df.index, pd.MultiIndex): | |
| unique_index = pd.MultiIndex.from_tuples(idx_vals, names=new_name) | |
| df_index.names = new_name | |
| df_columns.names = new_name | |
| else: | |
| unique_index = pd.Index(idx_vals, name=new_name) | |
| df_index.name = new_name | |
| df_columns.name = new_name | |
| df = df.copy(deep=False) | |
| df.index = df_index | |
| df.columns = df_columns | |
| df_out_dtype = np.promote_types(df.values.dtype, np.min_scalar_type(np.nan)) | |
| df_out = pd.DataFrame(index=unique_index, columns=unique_index, dtype=df_out_dtype) | |
| df_out.loc[:, :] = df | |
| df_out[df_out.isnull()] = df.transpose() | |
| return df_out | |
| def unstack_to_df( | |
| obj: tp.SeriesFrame, | |
| index_levels: tp.Optional[tp.MaybeLevelSequence] = None, | |
| column_levels: tp.Optional[tp.MaybeLevelSequence] = None, | |
| symmetric: bool = False, | |
| sort: bool = True, | |
| ) -> tp.Frame: | |
| """Reshape `obj` based on its multi-index into a DataFrame. | |
| Use `index_levels` to specify which index levels will form the new index, and `column_levels` | |
| for the new columns. Set `symmetric` to True to make the DataFrame symmetric. | |
| Usage: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> from vectorbtpro.base.reshaping import unstack_to_df | |
| >>> index = pd.MultiIndex.from_arrays( | |
| ... [[1, 1, 2, 2], [3, 4, 3, 4], ['a', 'b', 'c', 'd']], | |
| ... names=['x', 'y', 'z']) | |
| >>> sr = pd.Series([1, 2, 3, 4], index=index) | |
| >>> unstack_to_df(sr, index_levels=(0, 1), column_levels=2) | |
| z a b c d | |
| x y | |
| 1 3 1.0 NaN NaN NaN | |
| 1 4 NaN 2.0 NaN NaN | |
| 2 3 NaN NaN 3.0 NaN | |
| 2 4 NaN NaN NaN 4.0 | |
| ``` | |
| """ | |
| sr = get_multiindex_series(obj) | |
| if sr.index.nlevels > 2: | |
| if index_levels is None: | |
| raise ValueError("index_levels must be specified") | |
| if column_levels is None: | |
| raise ValueError("column_levels must be specified") | |
| else: | |
| if index_levels is None: | |
| index_levels = 0 | |
| if column_levels is None: | |
| column_levels = 1 | |
| unstacked, (new_index, new_columns) = unstack_to_array( | |
| sr, | |
| levels=(index_levels, column_levels), | |
| sort=sort, | |
| return_indexes=True, | |
| ) | |
| df = pd.DataFrame(unstacked, index=new_index, columns=new_columns) | |
| if symmetric: | |
| return make_symmetric(df, sort=sort) | |
| return df | |
| </file> | |
| <file path="base/wrapping.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Classes for wrapping NumPy arrays into Series/DataFrames.""" | |
| import numpy as np | |
| import pandas as pd | |
| from pandas.core.groupby import GroupBy as PandasGroupBy | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro._dtypes import * | |
| from vectorbtpro.base import indexes, reshaping | |
| from vectorbtpro.base.grouping.base import Grouper | |
| from vectorbtpro.base.indexes import stack_indexes, concat_indexes, IndexApplier | |
| from vectorbtpro.base.indexing import IndexingError, ExtPandasIndexer, index_dict, IdxSetter, IdxSetterFactory, IdxDict | |
| from vectorbtpro.base.resampling.base import Resampler | |
| from vectorbtpro.utils import checks, datetime_ as dt | |
| from vectorbtpro.utils.array_ import is_range, cast_to_min_precision, cast_to_max_precision | |
| from vectorbtpro.utils.attr_ import AttrResolverMixin, AttrResolverMixinT | |
| from vectorbtpro.utils.chunking import ChunkMeta, iter_chunk_meta, get_chunk_meta_key, ArraySelector, ArraySlicer | |
| from vectorbtpro.utils.config import Configured, merge_dicts, resolve_dict | |
| from vectorbtpro.utils.decorators import hybrid_method, cached_method, cached_property | |
| from vectorbtpro.utils.execution import Task, execute | |
| from vectorbtpro.utils.params import ItemParamable | |
| from vectorbtpro.utils.parsing import get_func_arg_names | |
| from vectorbtpro.utils.warnings_ import warn | |
| if tp.TYPE_CHECKING: | |
| from vectorbtpro.base.accessors import BaseIDXAccessor as BaseIDXAccessorT | |
| from vectorbtpro.generic.splitting.base import Splitter as SplitterT | |
| else: | |
| BaseIDXAccessorT = "BaseIDXAccessor" | |
| SplitterT = "Splitter" | |
| __all__ = [ | |
| "ArrayWrapper", | |
| "Wrapping", | |
| ] | |
| HasWrapperT = tp.TypeVar("HasWrapperT", bound="HasWrapper") | |
| class HasWrapper(ExtPandasIndexer, ItemParamable): | |
| """Abstract class that manages a wrapper.""" | |
| @property | |
| def unwrapped(self) -> tp.Any: | |
| """Unwrapped object.""" | |
| raise NotImplementedError | |
| @hybrid_method | |
| def should_wrap(cls_or_self) -> bool: | |
| """Whether to wrap where applicable.""" | |
| return True | |
| @property | |
| def wrapper(self) -> "ArrayWrapper": | |
| """Array wrapper of the type `ArrayWrapper`.""" | |
| raise NotImplementedError | |
| @property | |
| def column_only_select(self) -> bool: | |
| """Whether to perform indexing on columns only.""" | |
| raise NotImplementedError | |
| @property | |
| def range_only_select(self) -> bool: | |
| """Whether to perform indexing on rows using slices only.""" | |
| raise NotImplementedError | |
| @property | |
| def group_select(self) -> bool: | |
| """Whether to allow indexing on groups.""" | |
| raise NotImplementedError | |
| def regroup(self: HasWrapperT, group_by: tp.GroupByLike, **kwargs) -> HasWrapperT: | |
| """Regroup this instance.""" | |
| raise NotImplementedError | |
| def ungroup(self: HasWrapperT, **kwargs) -> HasWrapperT: | |
| """Ungroup this instance.""" | |
| return self.regroup(False, **kwargs) | |
| # ############# Selection ############# # | |
| def select_col( | |
| self: HasWrapperT, | |
| column: tp.Any = None, | |
| group_by: tp.GroupByLike = None, | |
| **kwargs, | |
| ) -> HasWrapperT: | |
| """Select one column/group. | |
| `column` can be a label; if the label cannot be found, it is treated as an integer position.""" | |
| _self = self.regroup(group_by, **kwargs) | |
| def _check_out_dim(out: HasWrapperT) -> HasWrapperT: | |
| if out.wrapper.get_ndim() == 2: | |
| if out.wrapper.get_shape_2d()[1] == 1: | |
| if out.column_only_select: | |
| return out.iloc[0] | |
| return out.iloc[:, 0] | |
| if _self.wrapper.grouper.is_grouped(): | |
| raise TypeError("Could not select one group: multiple groups returned") | |
| else: | |
| raise TypeError("Could not select one column: multiple columns returned") | |
| return out | |
| if column is None: | |
| if _self.wrapper.get_ndim() == 2 and _self.wrapper.get_shape_2d()[1] == 1: | |
| column = 0 | |
| if column is not None: | |
| if _self.wrapper.grouper.is_grouped(): | |
| if _self.wrapper.grouped_ndim == 1: | |
| raise TypeError("This instance already contains one group of data") | |
| if column not in _self.wrapper.get_columns(): | |
| if isinstance(column, int): | |
| if _self.column_only_select: | |
| return _check_out_dim(_self.iloc[column]) | |
| return _check_out_dim(_self.iloc[:, column]) | |
| raise KeyError(f"Group '{column}' not found") | |
| else: | |
| if _self.wrapper.ndim == 1: | |
| raise TypeError("This instance already contains one column of data") | |
| if column not in _self.wrapper.columns: | |
| if isinstance(column, int): | |
| if _self.column_only_select: | |
| return _check_out_dim(_self.iloc[column]) | |
| return _check_out_dim(_self.iloc[:, column]) | |
| raise KeyError(f"Column '{column}' not found") | |
| return _check_out_dim(_self[column]) | |
| if _self.wrapper.grouper.is_grouped(): | |
| if _self.wrapper.grouped_ndim == 1: | |
| return _self | |
| raise TypeError("Only one group is allowed. Use indexing or column argument.") | |
| if _self.wrapper.ndim == 1: | |
| return _self | |
| raise TypeError("Only one column is allowed. Use indexing or column argument.") | |
| @hybrid_method | |
| def select_col_from_obj( | |
| cls_or_self, | |
| obj: tp.Optional[tp.SeriesFrame], | |
| column: tp.Any = None, | |
| obj_ungrouped: bool = False, | |
| group_by: tp.GroupByLike = None, | |
| wrapper: tp.Optional["ArrayWrapper"] = None, | |
| **kwargs, | |
| ) -> tp.MaybeSeries: | |
| """Select one column/group from a Pandas object. | |
| `column` can be a label; if the label cannot be found, it is treated as an integer position.""" | |
| if obj is None: | |
| return None | |
| if not isinstance(cls_or_self, type): | |
| if wrapper is None: | |
| wrapper = cls_or_self.wrapper | |
| else: | |
| checks.assert_not_none(wrapper, arg_name="wrapper") | |
| _wrapper = wrapper.regroup(group_by, **kwargs) | |
| def _check_out_dim(out: tp.SeriesFrame, from_df: bool) -> tp.Series: | |
| bad_shape = False | |
| if from_df and isinstance(out, pd.DataFrame): | |
| if len(out.columns) == 1: | |
| return out.iloc[:, 0] | |
| bad_shape = True | |
| if not from_df and isinstance(out, pd.Series): | |
| if len(out) == 1: | |
| return out.iloc[0] | |
| bad_shape = True | |
| if bad_shape: | |
| if _wrapper.grouper.is_grouped(): | |
| raise TypeError("Could not select one group: multiple groups returned") | |
| else: | |
| raise TypeError("Could not select one column: multiple columns returned") | |
| return out | |
| if column is None: | |
| if _wrapper.get_ndim() == 2 and _wrapper.get_shape_2d()[1] == 1: | |
| column = 0 | |
| if column is not None: | |
| if _wrapper.grouper.is_grouped(): | |
| if _wrapper.grouped_ndim == 1: | |
| raise TypeError("This instance already contains one group of data") | |
| if obj_ungrouped: | |
| mask = _wrapper.grouper.group_by == column | |
| if not mask.any(): | |
| raise KeyError(f"Group '{column}' not found") | |
| if isinstance(obj, pd.DataFrame): | |
| return obj.loc[:, mask] | |
| return obj.loc[mask] | |
| else: | |
| if column not in _wrapper.get_columns(): | |
| if isinstance(column, int): | |
| if isinstance(obj, pd.DataFrame): | |
| return _check_out_dim(obj.iloc[:, column], True) | |
| return _check_out_dim(obj.iloc[column], False) | |
| raise KeyError(f"Group '{column}' not found") | |
| else: | |
| if _wrapper.ndim == 1: | |
| raise TypeError("This instance already contains one column of data") | |
| if column not in _wrapper.columns: | |
| if isinstance(column, int): | |
| if isinstance(obj, pd.DataFrame): | |
| return _check_out_dim(obj.iloc[:, column], True) | |
| return _check_out_dim(obj.iloc[column], False) | |
| raise KeyError(f"Column '{column}' not found") | |
| if isinstance(obj, pd.DataFrame): | |
| return _check_out_dim(obj[column], True) | |
| return _check_out_dim(obj[column], False) | |
| if not _wrapper.grouper.is_grouped(): | |
| if _wrapper.ndim == 1: | |
| return obj | |
| raise TypeError("Only one column is allowed. Use indexing or column argument.") | |
| if _wrapper.grouped_ndim == 1: | |
| return obj | |
| raise TypeError("Only one group is allowed. Use indexing or column argument.") | |
| # ############# Splitting ############# # | |
| def split( | |
| self, | |
| *args, | |
| splitter_cls: tp.Optional[tp.Type[SplitterT]] = None, | |
| wrap: tp.Optional[bool] = None, | |
| **kwargs, | |
| ) -> tp.Any: | |
| """Split this instance. | |
| Uses `vectorbtpro.generic.splitting.base.Splitter.split_and_take`.""" | |
| from vectorbtpro.generic.splitting.base import Splitter | |
| if splitter_cls is None: | |
| splitter_cls = Splitter | |
| if wrap is None: | |
| wrap = self.should_wrap() | |
| wrapped_self = self if wrap else self.unwrapped | |
| return splitter_cls.split_and_take(self.wrapper.index, wrapped_self, *args, **kwargs) | |
| def split_apply( | |
| self, | |
| apply_func: tp.Union[str, tp.Callable], | |
| *args, | |
| splitter_cls: tp.Optional[tp.Type[SplitterT]] = None, | |
| wrap: tp.Optional[bool] = None, | |
| **kwargs, | |
| ) -> tp.Any: | |
| """Split this instance and apply a function to each split. | |
| Uses `vectorbtpro.generic.splitting.base.Splitter.split_and_apply`.""" | |
| from vectorbtpro.generic.splitting.base import Splitter, Takeable | |
| if isinstance(apply_func, str): | |
| apply_func = getattr(type(self), apply_func) | |
| if splitter_cls is None: | |
| splitter_cls = Splitter | |
| if wrap is None: | |
| wrap = self.should_wrap() | |
| wrapped_self = self if wrap else self.unwrapped | |
| return splitter_cls.split_and_apply(self.wrapper.index, apply_func, Takeable(wrapped_self), *args, **kwargs) | |
| # ############# Chunking ############# # | |
| def chunk( | |
| self: HasWrapperT, | |
| axis: tp.Optional[int] = None, | |
| min_size: tp.Optional[int] = None, | |
| n_chunks: tp.Union[None, int, str] = None, | |
| chunk_len: tp.Union[None, int, str] = None, | |
| chunk_meta: tp.Optional[tp.Iterable[ChunkMeta]] = None, | |
| select: bool = False, | |
| wrap: tp.Optional[bool] = None, | |
| return_chunk_meta: bool = False, | |
| ) -> tp.Iterator[tp.Union[HasWrapperT, tp.Tuple[ChunkMeta, HasWrapperT]]]: | |
| """Chunk this instance. | |
| If `axis` is None, it becomes 0 if the instance is one-dimensional and 1 otherwise. | |
| For arguments related to chunking meta, see `vectorbtpro.utils.chunking.iter_chunk_meta`.""" | |
| if axis is None: | |
| axis = 0 if self.wrapper.ndim == 1 else 1 | |
| if self.wrapper.ndim == 1 and axis == 1: | |
| raise TypeError("Axis 1 is not supported for one dimension") | |
| checks.assert_in(axis, (0, 1)) | |
| size = self.wrapper.shape_2d[axis] | |
| if wrap is None: | |
| wrap = self.should_wrap() | |
| wrapped_self = self if wrap else self.unwrapped | |
| if chunk_meta is None: | |
| chunk_meta = iter_chunk_meta( | |
| size=size, | |
| min_size=min_size, | |
| n_chunks=n_chunks, | |
| chunk_len=chunk_len, | |
| ) | |
| for _chunk_meta in chunk_meta: | |
| if select: | |
| array_taker = ArraySelector(axis=axis) | |
| else: | |
| array_taker = ArraySlicer(axis=axis) | |
| if return_chunk_meta: | |
| yield _chunk_meta, array_taker.take(wrapped_self, _chunk_meta) | |
| else: | |
| yield array_taker.take(wrapped_self, _chunk_meta) | |
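| # Illustrative usage (not part of the library source), assuming `obj` is any | |
| # HasWrapper instance with multiple columns: | |
| # >>> for chunk in obj.chunk(n_chunks=4): | |
| # ...     print(chunk.wrapper.shape) | |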
| def chunk_apply( | |
| self: HasWrapperT, | |
| apply_func: tp.Union[str, tp.Callable], | |
| *args, | |
| chunk_kwargs: tp.KwargsLike = None, | |
| execute_kwargs: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> tp.MergeableResults: | |
| """Chunk this instance and apply a function to each chunk. | |
| If `apply_func` is a string, it is resolved as a method of this instance's type. | |
| For arguments related to chunking, see `Wrapping.chunk`.""" | |
| if isinstance(apply_func, str): | |
| apply_func = getattr(type(self), apply_func) | |
| if chunk_kwargs is None: | |
| chunk_arg_names = set(get_func_arg_names(self.chunk)) | |
| chunk_kwargs = {} | |
| for k in list(kwargs.keys()): | |
| if k in chunk_arg_names: | |
| chunk_kwargs[k] = kwargs.pop(k) | |
| if execute_kwargs is None: | |
| execute_kwargs = {} | |
| chunks = self.chunk(return_chunk_meta=True, **chunk_kwargs) | |
| tasks = [] | |
| keys = [] | |
| for _chunk_meta, chunk in chunks: | |
| tasks.append(Task(apply_func, chunk, *args, **kwargs)) | |
| keys.append(get_chunk_meta_key(_chunk_meta)) | |
| keys = pd.Index(keys, name="chunk_indices") | |
| return execute(tasks, size=len(tasks), keys=keys, **execute_kwargs) | |
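| # Illustrative usage (not part of the library source), assuming `obj` is a | |
| # HasWrapper instance and "describe" is a method defined on its type: | |
| # >>> obj.chunk_apply("describe", n_chunks=2) | |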
| # ############# Iteration ############# # | |
| def get_item_keys(self, group_by: tp.GroupByLike = None) -> tp.Index: | |
| """Get keys for `Wrapping.items`.""" | |
| _self = self.regroup(group_by=group_by) | |
| if _self.group_select and _self.wrapper.grouper.is_grouped(): | |
| return _self.wrapper.get_columns() | |
| return _self.wrapper.columns | |
| def items( | |
| self, | |
| group_by: tp.GroupByLike = None, | |
| apply_group_by: bool = False, | |
| keep_2d: bool = False, | |
| key_as_index: bool = False, | |
| wrap: tp.Optional[bool] = None, | |
| ) -> tp.Items: | |
| """Iterate over columns or groups (if grouped and `Wrapping.group_select` is True). | |
| If `apply_group_by` is False, `group_by` becomes a grouping instruction for the iteration, | |
| not for the final object. In this case, an error is raised if the instance is already grouped | |
| and that grouping would have to be changed.""" | |
| if wrap is None: | |
| wrap = self.should_wrap() | |
| def _resolve_v(self): | |
| return self if wrap else self.unwrapped | |
| if group_by is None or apply_group_by: | |
| _self = self.regroup(group_by=group_by) | |
| if _self.group_select and _self.wrapper.grouper.is_grouped(): | |
| columns = _self.wrapper.get_columns() | |
| ndim = _self.wrapper.get_ndim() | |
| else: | |
| columns = _self.wrapper.columns | |
| ndim = _self.wrapper.ndim | |
| if ndim == 1: | |
| if key_as_index: | |
| yield columns, _resolve_v(_self) | |
| else: | |
| yield columns[0], _resolve_v(_self) | |
| else: | |
| for i in range(len(columns)): | |
| if key_as_index: | |
| key = columns[[i]] | |
| else: | |
| key = columns[i] | |
| if _self.column_only_select: | |
| if keep_2d: | |
| yield key, _resolve_v(_self.iloc[i : i + 1]) | |
| else: | |
| yield key, _resolve_v(_self.iloc[i]) | |
| else: | |
| if keep_2d: | |
| yield key, _resolve_v(_self.iloc[:, i : i + 1]) | |
| else: | |
| yield key, _resolve_v(_self.iloc[:, i]) | |
| else: | |
| if self.group_select and self.wrapper.grouper.is_grouped(): | |
| raise ValueError("Cannot change grouping") | |
| wrapper = self.wrapper.regroup(group_by=group_by) | |
| if wrapper.get_ndim() == 1: | |
| if key_as_index: | |
| yield wrapper.get_columns(), _resolve_v(self) | |
| else: | |
| yield wrapper.get_columns()[0], _resolve_v(self) | |
| else: | |
| for group, group_idxs in wrapper.grouper.iter_groups(key_as_index=key_as_index): | |
| if self.column_only_select: | |
| if keep_2d or len(group_idxs) > 1: | |
| yield group, _resolve_v(self.iloc[group_idxs]) | |
| else: | |
| yield group, _resolve_v(self.iloc[group_idxs[0]]) | |
| else: | |
| if keep_2d or len(group_idxs) > 1: | |
| yield group, _resolve_v(self.iloc[:, group_idxs]) | |
| else: | |
| yield group, _resolve_v(self.iloc[:, group_idxs[0]]) | |
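| # Illustrative usage (not part of the library source), assuming `obj` is any | |
| # HasWrapper instance; iterates over columns (or groups if grouped): | |
| # >>> for col, col_obj in obj.items(): | |
| # ...     print(col, col_obj.wrapper.shape) | |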
| ArrayWrapperT = tp.TypeVar("ArrayWrapperT", bound="ArrayWrapper") | |
| class ArrayWrapper(Configured, HasWrapper, IndexApplier): | |
| """Class that stores index, columns, and shape metadata for wrapping NumPy arrays. | |
| Tightly integrated with `vectorbtpro.base.grouping.base.Grouper` for grouping columns. | |
| If the underlying object is a Series, pass `[sr.name]` as `columns`. | |
| `**kwargs` are passed to `vectorbtpro.base.grouping.base.Grouper`. | |
| !!! note | |
| This class is meant to be immutable. To change any attribute, use `ArrayWrapper.replace`. | |
| Use methods that begin with `get_` to get group-aware results.""" | |
| @classmethod | |
| def from_obj(cls: tp.Type[ArrayWrapperT], obj: tp.ArrayLike, **kwargs) -> ArrayWrapperT: | |
| """Derive metadata from an object.""" | |
| from vectorbtpro.base.reshaping import to_pd_array | |
| from vectorbtpro.data.base import Data | |
| if isinstance(obj, Data): | |
| obj = obj.symbol_wrapper | |
| if isinstance(obj, Wrapping): | |
| obj = obj.wrapper | |
| if isinstance(obj, ArrayWrapper): | |
| return obj.replace(**kwargs) | |
| pd_obj = to_pd_array(obj) | |
| index = indexes.get_index(pd_obj, 0) | |
| columns = indexes.get_index(pd_obj, 1) | |
| ndim = pd_obj.ndim | |
| kwargs.pop("index", None) | |
| kwargs.pop("columns", None) | |
| kwargs.pop("ndim", None) | |
| return cls(index, columns, ndim, **kwargs) | |
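| # Illustrative usage (not part of the library source): deriving wrapper metadata | |
| # from a pandas object: | |
| # >>> ArrayWrapper.from_obj(pd.Series([1, 2, 3], name='a')).ndim | |
| # 1 | |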
| @classmethod | |
| def from_shape( | |
| cls: tp.Type[ArrayWrapperT], | |
| shape: tp.ShapeLike, | |
| index: tp.Optional[tp.IndexLike] = None, | |
| columns: tp.Optional[tp.IndexLike] = None, | |
| ndim: tp.Optional[int] = None, | |
| *args, | |
| **kwargs, | |
| ) -> ArrayWrapperT: | |
| """Derive metadata from shape.""" | |
| shape = reshaping.to_tuple_shape(shape) | |
| if index is None: | |
| index = pd.RangeIndex(stop=shape[0]) | |
| if columns is None: | |
| columns = pd.RangeIndex(stop=shape[1] if len(shape) > 1 else 1) | |
| if ndim is None: | |
| ndim = len(shape) | |
| return cls(index, columns, ndim, *args, **kwargs) | |
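| # Illustrative usage (not part of the library source): index and columns default | |
| # to a RangeIndex when derived from a plain shape: | |
| # >>> ArrayWrapper.from_shape((5, 2)).columns | |
| # RangeIndex(start=0, stop=2, step=1) | |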
| @staticmethod | |
| def extract_init_kwargs(**kwargs) -> tp.Tuple[tp.Kwargs, tp.Kwargs]: | |
| """Extract keyword arguments that can be passed to `ArrayWrapper` or `Grouper`.""" | |
| wrapper_arg_names = get_func_arg_names(ArrayWrapper.__init__) | |
| grouper_arg_names = get_func_arg_names(Grouper.__init__) | |
| init_kwargs = dict() | |
| for k in list(kwargs.keys()): | |
| if k in wrapper_arg_names or k in grouper_arg_names: | |
| init_kwargs[k] = kwargs.pop(k) | |
| return init_kwargs, kwargs | |
| @classmethod | |
| def resolve_stack_kwargs(cls, *wrappers: tp.MaybeTuple[ArrayWrapperT], **kwargs) -> tp.Kwargs: | |
| """Resolve keyword arguments for initializing `ArrayWrapper` after stacking.""" | |
| if len(wrappers) == 1: | |
| wrappers = wrappers[0] | |
| wrappers = list(wrappers) | |
| common_keys = set() | |
| for wrapper in wrappers: | |
| common_keys = common_keys.union(set(wrapper.config.keys())) | |
| if "grouper" not in kwargs: | |
| common_keys = common_keys.union(set(wrapper.grouper.config.keys())) | |
| common_keys.remove("grouper") | |
| init_wrapper = wrappers[0] | |
| for i in range(1, len(wrappers)): | |
| wrapper = wrappers[i] | |
| for k in common_keys: | |
| if k not in kwargs: | |
| same_k = True | |
| try: | |
| if k in wrapper.config: | |
| if not checks.is_deep_equal(init_wrapper.config[k], wrapper.config[k]): | |
| same_k = False | |
| elif "grouper" not in kwargs and k in wrapper.grouper.config: | |
| if not checks.is_deep_equal(init_wrapper.grouper.config[k], wrapper.grouper.config[k]): | |
| same_k = False | |
| else: | |
| same_k = False | |
| except KeyError as e: | |
| same_k = False | |
| if not same_k: | |
| raise ValueError(f"Objects to be merged must have compatible '{k}'. Pass to override.") | |
| for k in common_keys: | |
| if k not in kwargs: | |
| if k in init_wrapper.config: | |
| kwargs[k] = init_wrapper.config[k] | |
| elif "grouper" not in kwargs and k in init_wrapper.grouper.config: | |
| kwargs[k] = init_wrapper.grouper.config[k] | |
| else: | |
| raise ValueError(f"Objects to be merged must have compatible '{k}'. Pass to override.") | |
| return kwargs | |
| @hybrid_method | |
| def row_stack( | |
| cls_or_self: tp.MaybeType[ArrayWrapperT], | |
| *wrappers: tp.MaybeTuple[ArrayWrapperT], | |
| index: tp.Optional[tp.IndexLike] = None, | |
| columns: tp.Optional[tp.IndexLike] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| group_by: tp.GroupByLike = None, | |
| stack_columns: bool = True, | |
| index_concat_method: tp.MaybeTuple[tp.Union[str, tp.Callable]] = "append", | |
| keys: tp.Optional[tp.IndexLike] = None, | |
| clean_index_kwargs: tp.KwargsLike = None, | |
| verify_integrity: bool = True, | |
| **kwargs, | |
| ) -> ArrayWrapperT: | |
| """Stack multiple `ArrayWrapper` instances along rows. | |
| Concatenates indexes using `vectorbtpro.base.indexes.concat_indexes`. | |
| Frequency must be the same across all indexes. A custom frequency can be provided via `freq`. | |
| If column levels in some instances differ, they will be stacked upon each other. | |
| Custom columns can be provided via `columns`. | |
| If `group_by` is None, all instances must be either all grouped or all ungrouped, and they must | |
| contain the same group values and labels. | |
| All instances must contain the same keys and values in their configs and configs of their | |
| grouper instances, apart from those arguments provided explicitly via `kwargs`.""" | |
| if not isinstance(cls_or_self, type): | |
| wrappers = (cls_or_self, *wrappers) | |
| cls = type(cls_or_self) | |
| else: | |
| cls = cls_or_self | |
| if len(wrappers) == 1: | |
| wrappers = wrappers[0] | |
| wrappers = list(wrappers) | |
| for wrapper in wrappers: | |
| if not checks.is_instance_of(wrapper, ArrayWrapper): | |
| raise TypeError("Each object to be merged must be an instance of ArrayWrapper") | |
| if index is None: | |
| index = concat_indexes( | |
| [wrapper.index for wrapper in wrappers], | |
| index_concat_method=index_concat_method, | |
| keys=keys, | |
| clean_index_kwargs=clean_index_kwargs, | |
| verify_integrity=verify_integrity, | |
| axis=0, | |
| ) | |
| elif not isinstance(index, pd.Index): | |
| index = pd.Index(index) | |
| kwargs["index"] = index | |
| if freq is None: | |
| new_freq = None | |
| for wrapper in wrappers: | |
| if new_freq is None: | |
| new_freq = wrapper.freq | |
| else: | |
| if new_freq is not None and wrapper.freq is not None and new_freq != wrapper.freq: | |
| raise ValueError("Objects to be merged must have the same frequency") | |
| freq = new_freq | |
| kwargs["freq"] = freq | |
| if columns is None: | |
| new_columns = None | |
| for wrapper in wrappers: | |
| if new_columns is None: | |
| new_columns = wrapper.columns | |
| else: | |
| if not checks.is_index_equal(new_columns, wrapper.columns): | |
| if not stack_columns: | |
| raise ValueError("Objects to be merged must have the same columns") | |
| new_columns = stack_indexes( | |
| (new_columns, wrapper.columns), | |
| **resolve_dict(clean_index_kwargs), | |
| ) | |
| columns = new_columns | |
| elif not isinstance(columns, pd.Index): | |
| columns = pd.Index(columns) | |
| kwargs["columns"] = columns | |
| if "grouper" in kwargs: | |
| if not checks.is_index_equal(columns, kwargs["grouper"].index): | |
| raise ValueError("Columns and grouper index must match") | |
| if group_by is not None: | |
| kwargs["group_by"] = group_by | |
| else: | |
| if group_by is None: | |
| grouped = None | |
| for wrapper in wrappers: | |
| wrapper_grouped = wrapper.grouper.is_grouped() | |
| if grouped is None: | |
| grouped = wrapper_grouped | |
| else: | |
| if grouped is not wrapper_grouped: | |
| raise ValueError("Objects to be merged must be either grouped or not") | |
| if grouped: | |
| new_group_by = None | |
| for wrapper in wrappers: | |
| wrapper_groups, wrapper_grouped_index = wrapper.grouper.get_groups_and_index() | |
| wrapper_group_by = wrapper_grouped_index[wrapper_groups] | |
| if new_group_by is None: | |
| new_group_by = wrapper_group_by | |
| else: | |
| if not checks.is_index_equal(new_group_by, wrapper_group_by): | |
| raise ValueError("Objects to be merged must have the same groups") | |
| group_by = new_group_by | |
| else: | |
| group_by = False | |
| kwargs["group_by"] = group_by | |
| if "ndim" not in kwargs: | |
| ndim = None | |
| for wrapper in wrappers: | |
| if ndim is None or wrapper.ndim > 1: | |
| ndim = wrapper.ndim | |
| kwargs["ndim"] = ndim | |
| return cls(**ArrayWrapper.resolve_stack_kwargs(*wrappers, **kwargs)) | |
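| # Illustrative usage (not part of the library source): stacking two compatible | |
| # wrappers along rows: | |
| # >>> w1 = ArrayWrapper(pd.Index([0, 1, 2]), ['a', 'b'], 2) | |
| # >>> w2 = ArrayWrapper(pd.Index([3, 4]), ['a', 'b'], 2) | |
| # >>> ArrayWrapper.row_stack(w1, w2).shape | |
| # (5, 2) | |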
| @hybrid_method | |
| def column_stack( | |
| cls_or_self: tp.MaybeType[ArrayWrapperT], | |
| *wrappers: tp.MaybeTuple[ArrayWrapperT], | |
| index: tp.Optional[tp.IndexLike] = None, | |
| columns: tp.Optional[tp.IndexLike] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| group_by: tp.GroupByLike = None, | |
| union_index: bool = True, | |
| col_concat_method: tp.MaybeTuple[tp.Union[str, tp.Callable]] = "append", | |
| group_concat_method: tp.MaybeTuple[tp.Union[str, tp.Callable]] = ("append", "factorize_each"), | |
| keys: tp.Optional[tp.IndexLike] = None, | |
| clean_index_kwargs: tp.KwargsLike = None, | |
| verify_integrity: bool = True, | |
| **kwargs, | |
| ) -> ArrayWrapperT: | |
| """Stack multiple `ArrayWrapper` instances along columns. | |
| If the indexes of all wrappers are the same, that index is used. If indexes differ and | |
| `union_index` is True, they will be merged into a single one by the set union operation. | |
| Otherwise, an error will be raised. The merged index must have no duplicates or mixed data | |
| types, and must be monotonically increasing. A custom index can be provided via `index`. | |
| Frequency must be the same across all indexes. A custom frequency can be provided via `freq`. | |
| Concatenates columns and groups using `vectorbtpro.base.indexes.concat_indexes`. | |
| If any of the instances has `column_only_select` enabled, the final wrapper will also enable it. | |
| If any of the instances has `group_select` or other grouping-related flags disabled, the final | |
| wrapper will also disable them. | |
| All instances must contain the same keys and values in their configs and configs of their | |
| grouper instances, apart from those arguments provided explicitly via `kwargs`.""" | |
| if not isinstance(cls_or_self, type): | |
| wrappers = (cls_or_self, *wrappers) | |
| cls = type(cls_or_self) | |
| else: | |
| cls = cls_or_self | |
| if len(wrappers) == 1: | |
| wrappers = wrappers[0] | |
| wrappers = list(wrappers) | |
| for wrapper in wrappers: | |
| if not checks.is_instance_of(wrapper, ArrayWrapper): | |
| raise TypeError("Each object to be merged must be an instance of ArrayWrapper") | |
| for wrapper in wrappers: | |
| if wrapper.index.has_duplicates: | |
| raise ValueError("Index of some objects to be merged contains duplicates") | |
| if index is None: | |
| new_index = None | |
| for wrapper in wrappers: | |
| if new_index is None: | |
| new_index = wrapper.index | |
| else: | |
| if not checks.is_index_equal(new_index, wrapper.index): | |
| if not union_index: | |
| raise ValueError( | |
| "Objects to be merged must have the same index. " | |
| "Use union_index=True to merge index as well." | |
| ) | |
| else: | |
| if new_index.dtype != wrapper.index.dtype: | |
| raise ValueError("Indexes to be merged must have the same data type") | |
| new_index = new_index.union(wrapper.index) | |
| if not new_index.is_monotonic_increasing: | |
| raise ValueError("Merged index must be monotonically increasing") | |
| index = new_index | |
| elif not isinstance(index, pd.Index): | |
| index = pd.Index(index) | |
| kwargs["index"] = index | |
| if freq is None: | |
| new_freq = None | |
| for wrapper in wrappers: | |
| if new_freq is None: | |
| new_freq = wrapper.freq | |
| else: | |
| if new_freq is not None and wrapper.freq is not None and new_freq != wrapper.freq: | |
| raise ValueError("Objects to be merged must have the same frequency") | |
| freq = new_freq | |
| kwargs["freq"] = freq | |
| if columns is None: | |
| columns = concat_indexes( | |
| [wrapper.columns for wrapper in wrappers], | |
| index_concat_method=col_concat_method, | |
| keys=keys, | |
| clean_index_kwargs=clean_index_kwargs, | |
| verify_integrity=verify_integrity, | |
| axis=1, | |
| ) | |
| elif not isinstance(columns, pd.Index): | |
| columns = pd.Index(columns) | |
| kwargs["columns"] = columns | |
| if "grouper" in kwargs: | |
| if not checks.is_index_equal(columns, kwargs["grouper"].index): | |
| raise ValueError("Columns and grouper index must match") | |
| if group_by is not None: | |
| kwargs["group_by"] = group_by | |
| else: | |
| if group_by is None: | |
| any_grouped = False | |
| for wrapper in wrappers: | |
| if wrapper.grouper.is_grouped(): | |
| any_grouped = True | |
| break | |
| if any_grouped: | |
| group_by = concat_indexes( | |
| [wrapper.grouper.get_stretched_index() for wrapper in wrappers], | |
| index_concat_method=group_concat_method, | |
| keys=keys, | |
| clean_index_kwargs=clean_index_kwargs, | |
| verify_integrity=verify_integrity, | |
| axis=2, | |
| ) | |
| else: | |
| group_by = False | |
| kwargs["group_by"] = group_by | |
| if "ndim" not in kwargs: | |
| kwargs["ndim"] = 2 | |
| if "grouped_ndim" not in kwargs: | |
| kwargs["grouped_ndim"] = None | |
| if "column_only_select" not in kwargs: | |
| column_only_select = None | |
| for wrapper in wrappers: | |
| if column_only_select is None or wrapper.column_only_select: | |
| column_only_select = wrapper.column_only_select | |
| kwargs["column_only_select"] = column_only_select | |
| if "range_only_select" not in kwargs: | |
| range_only_select = None | |
| for wrapper in wrappers: | |
| if range_only_select is None or wrapper.range_only_select: | |
| range_only_select = wrapper.range_only_select | |
| kwargs["range_only_select"] = range_only_select | |
| if "group_select" not in kwargs: | |
| group_select = None | |
| for wrapper in wrappers: | |
| if group_select is None or not wrapper.group_select: | |
| group_select = wrapper.group_select | |
| kwargs["group_select"] = group_select | |
| if "grouper" not in kwargs: | |
| if "allow_enable" not in kwargs: | |
| allow_enable = None | |
| for wrapper in wrappers: | |
| if allow_enable is None or not wrapper.grouper.allow_enable: | |
| allow_enable = wrapper.grouper.allow_enable | |
| kwargs["allow_enable"] = allow_enable | |
| if "allow_disable" not in kwargs: | |
| allow_disable = None | |
| for wrapper in wrappers: | |
| if allow_disable is None or not wrapper.grouper.allow_disable: | |
| allow_disable = wrapper.grouper.allow_disable | |
| kwargs["allow_disable"] = allow_disable | |
| if "allow_modify" not in kwargs: | |
| allow_modify = None | |
| for wrapper in wrappers: | |
| if allow_modify is None or not wrapper.grouper.allow_modify: | |
| allow_modify = wrapper.grouper.allow_modify | |
| kwargs["allow_modify"] = allow_modify | |
| return cls(**ArrayWrapper.resolve_stack_kwargs(*wrappers, **kwargs)) | |
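| # Illustrative usage (not part of the library source): stacking two wrappers with | |
| # the same index along columns: | |
| # >>> w1 = ArrayWrapper(pd.Index([0, 1, 2]), ['a'], 1) | |
| # >>> w2 = ArrayWrapper(pd.Index([0, 1, 2]), ['b'], 1) | |
| # >>> ArrayWrapper.column_stack(w1, w2).shape | |
| # (3, 2) | |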
| def __init__( | |
| self, | |
| index: tp.IndexLike, | |
| columns: tp.Optional[tp.IndexLike] = None, | |
| ndim: tp.Optional[int] = None, | |
| freq: tp.Optional[tp.FrequencyLike] = None, | |
| parse_index: tp.Optional[bool] = None, | |
| column_only_select: tp.Optional[bool] = None, | |
| range_only_select: tp.Optional[bool] = None, | |
| group_select: tp.Optional[bool] = None, | |
| grouped_ndim: tp.Optional[int] = None, | |
| grouper: tp.Optional[Grouper] = None, | |
| **kwargs, | |
| ) -> None: | |
| checks.assert_not_none(index, arg_name="index") | |
| index = dt.prepare_dt_index(index, parse_index=parse_index) | |
| if columns is None: | |
| columns = [None] | |
| if not isinstance(columns, pd.Index): | |
| columns = pd.Index(columns) | |
| if ndim is None: | |
| if len(columns) == 1 and not isinstance(columns, pd.MultiIndex): | |
| ndim = 1 | |
| else: | |
| ndim = 2 | |
| else: | |
| if len(columns) > 1: | |
| ndim = 2 | |
| grouper_arg_names = get_func_arg_names(Grouper.__init__) | |
| grouper_kwargs = dict() | |
| for k in list(kwargs.keys()): | |
| if k in grouper_arg_names: | |
| grouper_kwargs[k] = kwargs.pop(k) | |
| if grouper is None: | |
| grouper = Grouper(columns, **grouper_kwargs) | |
| elif not checks.is_index_equal(columns, grouper.index) or len(grouper_kwargs) > 0: | |
| grouper = grouper.replace(index=columns, **grouper_kwargs) | |
| HasWrapper.__init__(self) | |
| Configured.__init__( | |
| self, | |
| index=index, | |
| columns=columns, | |
| ndim=ndim, | |
| freq=freq, | |
| parse_index=parse_index, | |
| column_only_select=column_only_select, | |
| range_only_select=range_only_select, | |
| group_select=group_select, | |
| grouped_ndim=grouped_ndim, | |
| grouper=grouper, | |
| **kwargs, | |
| ) | |
| self._index = index | |
| self._columns = columns | |
| self._ndim = ndim | |
| self._freq = freq | |
| self._parse_index = parse_index | |
| self._column_only_select = column_only_select | |
| self._range_only_select = range_only_select | |
| self._group_select = group_select | |
| self._grouper = grouper | |
| self._grouped_ndim = grouped_ndim | |
| def indexing_func_meta( | |
| self: ArrayWrapperT, | |
| pd_indexing_func: tp.PandasIndexingFunc, | |
| index: tp.Optional[tp.IndexLike] = None, | |
| columns: tp.Optional[tp.IndexLike] = None, | |
| column_only_select: tp.Optional[bool] = None, | |
| range_only_select: tp.Optional[bool] = None, | |
| group_select: tp.Optional[bool] = None, | |
| return_slices: bool = True, | |
| return_none_slices: bool = True, | |
| return_scalars: bool = True, | |
| group_by: tp.GroupByLike = None, | |
| wrapper_kwargs: tp.KwargsLike = None, | |
| ) -> dict: | |
| """Perform indexing on `ArrayWrapper` and also return metadata. | |
| Takes into account column grouping. | |
| Flipping rows and columns is not allowed. If one row is selected, the result will still be | |
| a Series when indexing a Series and a DataFrame when indexing a DataFrame. | |
| Set `column_only_select` to True to index the array wrapper as a Series of columns/groups. | |
| This way, selection of index (axis 0) can be avoided. Set `range_only_select` to True to | |
| allow selection of rows only using slices. Set `group_select` to True to allow selection of groups. | |
| Otherwise, indexing is performed on columns even when grouping is enabled. The `group_select` | |
| option takes effect only when grouping is enabled. | |
| Returns the new array wrapper, row indices, column indices, and group indices. | |
| If `return_slices` is True (default), indices will be returned as a slice if they were | |
| identified as a range. If `return_none_slices` is True (default), indices will be returned as a slice | |
| `(None, None, None)` if the axis hasn't been changed. | |
| !!! note | |
| If `column_only_select` is True, make sure to index the array wrapper | |
| as a Series of columns rather than a DataFrame. For example, the operation | |
| `.iloc[:, :2]` should become `.iloc[:2]`. Operations are not allowed if the | |
| object is already a Series and thus has only one column/group.""" | |
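| # [Editor's note] A minimal sketch, not part of the original source; the wrapper below is | |
| # hypothetical and outputs are indicative only: | |
| # >>> import numpy as np, pandas as pd | |
| # >>> import vectorbtpro as vbt | |
| # >>> wrapper = vbt.ArrayWrapper(pd.RangeIndex(5), pd.Index(["a", "b", "c"]), 2) | |
| # >>> meta = wrapper.indexing_func_meta(lambda x: x.iloc[:3, :2]) | |
| # >>> meta["new_wrapper"].columns  # Index(['a', 'b'], ...) | |
| # >>> meta["rows_changed"], meta["columns_changed"]  # (True, True) | |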
| if column_only_select is None: | |
| column_only_select = self.column_only_select | |
| if range_only_select is None: | |
| range_only_select = self.range_only_select | |
| if group_select is None: | |
| group_select = self.group_select | |
| if wrapper_kwargs is None: | |
| wrapper_kwargs = {} | |
| _self = self.regroup(group_by) | |
| group_select = group_select and _self.grouper.is_grouped() | |
| if index is None: | |
| index = _self.index | |
| if not isinstance(index, pd.Index): | |
| index = pd.Index(index) | |
| if columns is None: | |
| if group_select: | |
| columns = _self.get_columns() | |
| else: | |
| columns = _self.columns | |
| if not isinstance(columns, pd.Index): | |
| columns = pd.Index(columns) | |
| if group_select: | |
| # Groups as columns | |
| i_wrapper = ArrayWrapper(index, columns, _self.get_ndim()) | |
| else: | |
| # Columns as columns | |
| i_wrapper = ArrayWrapper(index, columns, _self.ndim) | |
| n_rows = len(index) | |
| n_cols = len(columns) | |
| def _resolve_arr(arr, n): | |
| if checks.is_np_array(arr) and is_range(arr): | |
| if arr[0] == 0 and arr[-1] == n - 1: | |
| if return_none_slices: | |
| return slice(None, None, None), False | |
| return arr, False | |
| if return_slices: | |
| return slice(arr[0], arr[-1] + 1, None), True | |
| return arr, True | |
| if isinstance(arr, np.integer): | |
| arr = arr.item() | |
| columns_changed = True | |
| if isinstance(arr, int): | |
| if arr == 0 and n == 1: | |
| columns_changed = False | |
| if not return_scalars: | |
| arr = np.array([arr]) | |
| return arr, columns_changed | |
| if column_only_select: | |
| if i_wrapper.ndim == 1: | |
| raise IndexingError("Columns only: This instance already contains one column of data") | |
| try: | |
| col_mapper = pd_indexing_func(i_wrapper.wrap_reduced(np.arange(n_cols), columns=columns)) | |
| except pd.core.indexing.IndexingError as e: | |
| warn("Columns only: Make sure to treat this instance as a Series of columns rather than a DataFrame") | |
| raise e | |
| if checks.is_series(col_mapper): | |
| new_columns = col_mapper.index | |
| col_idxs = col_mapper.values | |
| new_ndim = 2 | |
| else: | |
| new_columns = columns[[col_mapper]] | |
| col_idxs = col_mapper | |
| new_ndim = 1 | |
| new_index = index | |
| row_idxs = np.arange(len(index)) | |
| else: | |
| init_row_mapper_values = reshaping.broadcast_array_to(np.arange(n_rows)[:, None], (n_rows, n_cols)) | |
| init_row_mapper = i_wrapper.wrap(init_row_mapper_values, index=index, columns=columns) | |
| row_mapper = pd_indexing_func(init_row_mapper) | |
| if i_wrapper.ndim == 1: | |
| if not checks.is_series(row_mapper): | |
| row_idxs = np.array([row_mapper]) | |
| new_index = index[row_idxs] | |
| else: | |
| row_idxs = row_mapper.values | |
| new_index = indexes.get_index(row_mapper, 0) | |
| col_idxs = 0 | |
| new_columns = columns | |
| new_ndim = 1 | |
| else: | |
| init_col_mapper_values = reshaping.broadcast_array_to(np.arange(n_cols)[None], (n_rows, n_cols)) | |
| init_col_mapper = i_wrapper.wrap(init_col_mapper_values, index=index, columns=columns) | |
| col_mapper = pd_indexing_func(init_col_mapper) | |
| if checks.is_frame(col_mapper): | |
| # Multiple rows and columns selected | |
| row_idxs = row_mapper.values[:, 0] | |
| col_idxs = col_mapper.values[0] | |
| new_index = indexes.get_index(row_mapper, 0) | |
| new_columns = indexes.get_index(col_mapper, 1) | |
| new_ndim = 2 | |
| elif checks.is_series(col_mapper): | |
| multi_index = isinstance(index, pd.MultiIndex) | |
| multi_columns = isinstance(columns, pd.MultiIndex) | |
| multi_name = isinstance(col_mapper.name, tuple) | |
| if multi_index and multi_name and col_mapper.name in index: | |
| one_row = True | |
| elif not multi_index and not multi_name and col_mapper.name in index: | |
| one_row = True | |
| else: | |
| one_row = False | |
| if multi_columns and multi_name and col_mapper.name in columns: | |
| one_col = True | |
| elif not multi_columns and not multi_name and col_mapper.name in columns: | |
| one_col = True | |
| else: | |
| one_col = False | |
| if (one_row and one_col) or (not one_row and not one_col): | |
| one_row = np.all(row_mapper.values == row_mapper.values.item(0)) | |
| one_col = np.all(col_mapper.values == col_mapper.values.item(0)) | |
| if (one_row and one_col) or (not one_row and not one_col): | |
| raise IndexingError("Could not parse indexing operation") | |
| if one_row: | |
| # One row selected | |
| row_idxs = row_mapper.values[[0]] | |
| col_idxs = col_mapper.values | |
| new_index = index[row_idxs] | |
| new_columns = indexes.get_index(col_mapper, 0) | |
| new_ndim = 2 | |
| else: | |
| # One column selected | |
| row_idxs = row_mapper.values | |
| col_idxs = col_mapper.values[0] | |
| new_index = indexes.get_index(row_mapper, 0) | |
| new_columns = columns[[col_idxs]] | |
| new_ndim = 1 | |
| else: | |
| # One row and column selected | |
| row_idxs = np.array([row_mapper]) | |
| col_idxs = col_mapper | |
| new_index = index[row_idxs] | |
| new_columns = columns[[col_idxs]] | |
| new_ndim = 1 | |
| if _self.grouper.is_grouped(): | |
| # Grouping enabled | |
| if np.asarray(row_idxs).ndim == 0: | |
| raise IndexingError("Flipping index and columns is not allowed") | |
| if group_select: | |
| # Selection based on groups | |
| # Get indices of columns corresponding to selected groups | |
| group_idxs = col_idxs | |
| col_idxs, new_groups = _self.grouper.select_groups(group_idxs) | |
| ungrouped_columns = _self.columns[col_idxs] | |
| if new_ndim == 1 and len(ungrouped_columns) == 1: | |
| ungrouped_ndim = 1 | |
| col_idxs = col_idxs[0] | |
| else: | |
| ungrouped_ndim = 2 | |
| row_idxs, rows_changed = _resolve_arr(row_idxs, _self.shape[0]) | |
| if range_only_select and rows_changed: | |
| if not isinstance(row_idxs, slice): | |
| raise ValueError("Rows can be selected only by slicing") | |
| if row_idxs.step not in (1, None): | |
| raise ValueError("Slice for selecting rows must have a step of 1 or None") | |
| col_idxs, columns_changed = _resolve_arr(col_idxs, _self.shape_2d[1]) | |
| group_idxs, groups_changed = _resolve_arr(group_idxs, _self.get_shape_2d()[1]) | |
| return dict( | |
| new_wrapper=_self.replace( | |
| **merge_dicts( | |
| dict( | |
| index=new_index, | |
| columns=ungrouped_columns, | |
| ndim=ungrouped_ndim, | |
| grouped_ndim=new_ndim, | |
| group_by=new_columns[new_groups], | |
| ), | |
| wrapper_kwargs, | |
| ) | |
| ), | |
| row_idxs=row_idxs, | |
| rows_changed=rows_changed, | |
| col_idxs=col_idxs, | |
| columns_changed=columns_changed, | |
| group_idxs=group_idxs, | |
| groups_changed=groups_changed, | |
| ) | |
| # Selection based on columns | |
| group_idxs = _self.grouper.get_groups()[col_idxs] | |
| new_group_by = _self.grouper.group_by[reshaping.to_1d_array(col_idxs)] | |
| row_idxs, rows_changed = _resolve_arr(row_idxs, _self.shape[0]) | |
| if range_only_select and rows_changed: | |
| if not isinstance(row_idxs, slice): | |
| raise ValueError("Rows can be selected only by slicing") | |
| if row_idxs.step not in (1, None): | |
| raise ValueError("Slice for selecting rows must have a step of 1 or None") | |
| col_idxs, columns_changed = _resolve_arr(col_idxs, _self.shape_2d[1]) | |
| group_idxs, groups_changed = _resolve_arr(group_idxs, _self.get_shape_2d()[1]) | |
| return dict( | |
| new_wrapper=_self.replace( | |
| **merge_dicts( | |
| dict( | |
| index=new_index, | |
| columns=new_columns, | |
| ndim=new_ndim, | |
| grouped_ndim=None, | |
| group_by=new_group_by, | |
| ), | |
| wrapper_kwargs, | |
| ) | |
| ), | |
| row_idxs=row_idxs, | |
| rows_changed=rows_changed, | |
| col_idxs=col_idxs, | |
| columns_changed=columns_changed, | |
| group_idxs=group_idxs, | |
| groups_changed=groups_changed, | |
| ) | |
| # Grouping disabled | |
| row_idxs, rows_changed = _resolve_arr(row_idxs, _self.shape[0]) | |
| if range_only_select and rows_changed: | |
| if not isinstance(row_idxs, slice): | |
| raise ValueError("Rows can be selected only by slicing") | |
| if row_idxs.step not in (1, None): | |
| raise ValueError("Slice for selecting rows must have a step of 1 or None") | |
| col_idxs, columns_changed = _resolve_arr(col_idxs, _self.shape_2d[1]) | |
| return dict( | |
| new_wrapper=_self.replace( | |
| **merge_dicts( | |
| dict( | |
| index=new_index, | |
| columns=new_columns, | |
| ndim=new_ndim, | |
| grouped_ndim=None, | |
| group_by=None, | |
| ), | |
| wrapper_kwargs, | |
| ) | |
| ), | |
| row_idxs=row_idxs, | |
| rows_changed=rows_changed, | |
| col_idxs=col_idxs, | |
| columns_changed=columns_changed, | |
| group_idxs=col_idxs, | |
| groups_changed=columns_changed, | |
| ) | |
| def indexing_func(self: ArrayWrapperT, *args, **kwargs) -> ArrayWrapperT: | |
| """Perform indexing on `ArrayWrapper`.""" | |
| return self.indexing_func_meta(*args, **kwargs)["new_wrapper"] | |
| @staticmethod | |
| def select_from_flex_array( | |
| arr: tp.ArrayLike, | |
| row_idxs: tp.Union[int, tp.Array1d, slice] = None, | |
| col_idxs: tp.Union[int, tp.Array1d, slice] = None, | |
| rows_changed: bool = True, | |
| columns_changed: bool = True, | |
| rotate_rows: bool = False, | |
| rotate_cols: bool = True, | |
| ) -> tp.Array2d: | |
| """Select rows and columns from a flexible array. | |
| Always returns a 2-dim NumPy array.""" | |
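| # [Editor's note] Illustrative sketch, not part of the original source. A "flexible" array | |
| # broadcasts along any axis of length one, so that axis is left untouched: | |
| # >>> flex_arr = np.array([[1, 2, 3]])  # a single row, broadcast along rows | |
| # >>> vbt.ArrayWrapper.select_from_flex_array( | |
| # ...     flex_arr, row_idxs=np.array([0, 2]), col_idxs=np.array([2, 0])) | |
| # array([[3, 1]]) | |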
| new_arr = arr_2d = reshaping.to_2d_array(arr) | |
| if row_idxs is not None and rows_changed: | |
| if arr_2d.shape[0] > 1: | |
| if isinstance(row_idxs, slice): | |
| max_idx = row_idxs.stop - 1 | |
| else: | |
| row_idxs = reshaping.to_1d_array(row_idxs) | |
| max_idx = np.max(row_idxs) | |
| if arr_2d.shape[0] <= max_idx: | |
| if rotate_rows and not isinstance(row_idxs, slice): | |
| new_arr = new_arr[row_idxs % arr_2d.shape[0], :] | |
| else: | |
| new_arr = new_arr[row_idxs, :] | |
| else: | |
| new_arr = new_arr[row_idxs, :] | |
| if col_idxs is not None and columns_changed: | |
| if arr_2d.shape[1] > 1: | |
| if isinstance(col_idxs, slice): | |
| max_idx = col_idxs.stop - 1 | |
| else: | |
| col_idxs = reshaping.to_1d_array(col_idxs) | |
| max_idx = np.max(col_idxs) | |
| if arr_2d.shape[1] <= max_idx: | |
| if rotate_cols and not isinstance(col_idxs, slice): | |
| new_arr = new_arr[:, col_idxs % arr_2d.shape[1]] | |
| else: | |
| new_arr = new_arr[:, col_idxs] | |
| else: | |
| new_arr = new_arr[:, col_idxs] | |
| return new_arr | |
| def get_resampler(self, *args, **kwargs) -> tp.Union[Resampler, tp.PandasResampler]: | |
| """See `vectorbtpro.base.accessors.BaseIDXAccessor.get_resampler`.""" | |
| return self.index_acc.get_resampler(*args, **kwargs) | |
| def resample_meta(self: ArrayWrapperT, *args, wrapper_kwargs: tp.KwargsLike = None, **kwargs) -> dict: | |
| """Perform resampling on `ArrayWrapper` and also return metadata. | |
| `*args` and `**kwargs` are passed to `ArrayWrapper.get_resampler`.""" | |
| resampler = self.get_resampler(*args, **kwargs) | |
| if isinstance(resampler, Resampler): | |
| _resampler = resampler | |
| else: | |
| _resampler = Resampler.from_pd_resampler(resampler) | |
| if wrapper_kwargs is None: | |
| wrapper_kwargs = {} | |
| if "index" not in wrapper_kwargs: | |
| wrapper_kwargs["index"] = _resampler.target_index | |
| if "freq" not in wrapper_kwargs: | |
| wrapper_kwargs["freq"] = dt.infer_index_freq(wrapper_kwargs["index"], freq=_resampler.target_freq) | |
| new_wrapper = self.replace(**wrapper_kwargs) | |
| return dict(resampler=resampler, new_wrapper=new_wrapper) | |
| def resample(self: ArrayWrapperT, *args, **kwargs) -> ArrayWrapperT: | |
| """Perform resampling on `ArrayWrapper`. | |
| Uses `ArrayWrapper.resample_meta`.""" | |
| return self.resample_meta(*args, **kwargs)["new_wrapper"] | |
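| # [Editor's note] Illustrative sketch, not part of the original source. Assuming a | |
| # datetime-indexed wrapper, resampling to a coarser rule replaces the index and frequency: | |
| # >>> wrapper = vbt.ArrayWrapper(pd.date_range("2020-01-01", periods=10, freq="D"), ["a"], 1) | |
| # >>> wrapper.resample("W").index  # weekly index spanning the same period | |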
| @property | |
| def wrapper(self) -> "ArrayWrapper": | |
| return self | |
| @property | |
| def index(self) -> tp.Index: | |
| """Index.""" | |
| return self._index | |
| @cached_property(whitelist=True) | |
| def index_acc(self) -> BaseIDXAccessorT: | |
| """Get index accessor of the type `vectorbtpro.base.accessors.BaseIDXAccessor`.""" | |
| from vectorbtpro.base.accessors import BaseIDXAccessor | |
| return BaseIDXAccessor(self.index, freq=self._freq) | |
| @property | |
| def ns_index(self) -> tp.Array1d: | |
| """See `vectorbtpro.base.accessors.BaseIDXAccessor.to_ns`.""" | |
| return self.index_acc.to_ns() | |
| def get_period_ns_index(self, *args, **kwargs) -> tp.Array1d: | |
| """See `vectorbtpro.base.accessors.BaseIDXAccessor.to_period_ns`.""" | |
| return self.index_acc.to_period_ns(*args, **kwargs) | |
| @property | |
| def columns(self) -> tp.Index: | |
| """Columns.""" | |
| return self._columns | |
| def get_columns(self, group_by: tp.GroupByLike = None) -> tp.Index: | |
| """Get group-aware `ArrayWrapper.columns`.""" | |
| return self.resolve(group_by=group_by).columns | |
| @property | |
| def name(self) -> tp.Any: | |
| """Name.""" | |
| if self.ndim == 1: | |
| if self.columns[0] == 0: | |
| return None | |
| return self.columns[0] | |
| return None | |
| def get_name(self, group_by: tp.GroupByLike = None) -> tp.Any: | |
| """Get group-aware `ArrayWrapper.name`.""" | |
| return self.resolve(group_by=group_by).name | |
| @property | |
| def ndim(self) -> int: | |
| """Number of dimensions.""" | |
| return self._ndim | |
| def get_ndim(self, group_by: tp.GroupByLike = None) -> int: | |
| """Get group-aware `ArrayWrapper.ndim`.""" | |
| return self.resolve(group_by=group_by).ndim | |
| @property | |
| def shape(self) -> tp.Shape: | |
| """Shape.""" | |
| if self.ndim == 1: | |
| return (len(self.index),) | |
| return len(self.index), len(self.columns) | |
| def get_shape(self, group_by: tp.GroupByLike = None) -> tp.Shape: | |
| """Get group-aware `ArrayWrapper.shape`.""" | |
| return self.resolve(group_by=group_by).shape | |
| @property | |
| def shape_2d(self) -> tp.Shape: | |
| """Shape as if the instance was two-dimensional.""" | |
| if self.ndim == 1: | |
| return self.shape[0], 1 | |
| return self.shape | |
| def get_shape_2d(self, group_by: tp.GroupByLike = None) -> tp.Shape: | |
| """Get group-aware `ArrayWrapper.shape_2d`.""" | |
| return self.resolve(group_by=group_by).shape_2d | |
| def get_freq(self, *args, **kwargs) -> tp.Union[None, float, tp.PandasFrequency]: | |
| """See `vectorbtpro.base.accessors.BaseIDXAccessor.get_freq`.""" | |
| return self.index_acc.get_freq(*args, **kwargs) | |
| @property | |
| def freq(self) -> tp.Optional[tp.PandasFrequency]: | |
| """See `vectorbtpro.base.accessors.BaseIDXAccessor.freq`.""" | |
| return self.index_acc.freq | |
| @property | |
| def ns_freq(self) -> tp.Optional[int]: | |
| """See `vectorbtpro.base.accessors.BaseIDXAccessor.ns_freq`.""" | |
| return self.index_acc.ns_freq | |
| @property | |
| def any_freq(self) -> tp.Union[None, float, tp.PandasFrequency]: | |
| """See `vectorbtpro.base.accessors.BaseIDXAccessor.any_freq`.""" | |
| return self.index_acc.any_freq | |
| @property | |
| def periods(self) -> int: | |
| """See `vectorbtpro.base.accessors.BaseIDXAccessor.periods`.""" | |
| return self.index_acc.periods | |
| @property | |
| def dt_periods(self) -> float: | |
| """See `vectorbtpro.base.accessors.BaseIDXAccessor.dt_periods`.""" | |
| return self.index_acc.dt_periods | |
| def arr_to_timedelta(self, *args, **kwargs) -> tp.Union[pd.Index, tp.MaybeArray]: | |
| """See `vectorbtpro.base.accessors.BaseIDXAccessor.arr_to_timedelta`.""" | |
| return self.index_acc.arr_to_timedelta(*args, **kwargs) | |
| @property | |
| def parse_index(self) -> tp.Optional[bool]: | |
| """Whether to try to convert the index into a datetime index. | |
| Applied during the initialization and passed to `vectorbtpro.utils.datetime_.prepare_dt_index`.""" | |
| return self._parse_index | |
| @property | |
| def column_only_select(self) -> bool: | |
| from vectorbtpro._settings import settings | |
| wrapping_cfg = settings["wrapping"] | |
| column_only_select = self._column_only_select | |
| if column_only_select is None: | |
| column_only_select = wrapping_cfg["column_only_select"] | |
| return column_only_select | |
| @property | |
| def range_only_select(self) -> bool: | |
| from vectorbtpro._settings import settings | |
| wrapping_cfg = settings["wrapping"] | |
| range_only_select = self._range_only_select | |
| if range_only_select is None: | |
| range_only_select = wrapping_cfg["range_only_select"] | |
| return range_only_select | |
| @property | |
| def group_select(self) -> bool: | |
| from vectorbtpro._settings import settings | |
| wrapping_cfg = settings["wrapping"] | |
| group_select = self._group_select | |
| if group_select is None: | |
| group_select = wrapping_cfg["group_select"] | |
| return group_select | |
| @property | |
| def grouper(self) -> Grouper: | |
| """Column grouper.""" | |
| return self._grouper | |
| @property | |
| def grouped_ndim(self) -> int: | |
| """Number of dimensions under column grouping.""" | |
| if self._grouped_ndim is None: | |
| if self.grouper.is_grouped(): | |
| return 2 if self.grouper.get_group_count() > 1 else 1 | |
| return self.ndim | |
| return self._grouped_ndim | |
| @cached_method(whitelist=True) | |
| def regroup(self: ArrayWrapperT, group_by: tp.GroupByLike, **kwargs) -> ArrayWrapperT: | |
| """Regroup this instance. | |
| Only creates a new instance if grouping has changed, otherwise returns itself.""" | |
| if self.grouper.is_grouping_changed(group_by=group_by): | |
| self.grouper.check_group_by(group_by=group_by) | |
| grouped_ndim = None | |
| if self.grouper.is_grouped(group_by=group_by): | |
| if not self.grouper.is_group_count_changed(group_by=group_by): | |
| grouped_ndim = self.grouped_ndim | |
| return self.replace(grouped_ndim=grouped_ndim, group_by=group_by, **kwargs) | |
| if len(kwargs) > 0: | |
| return self.replace(**kwargs) | |
| return self # important for keeping cache | |
| def flip(self: ArrayWrapperT, **kwargs) -> ArrayWrapperT: | |
| """Flip index and columns.""" | |
| if "grouper" not in kwargs: | |
| kwargs["grouper"] = None | |
| return self.replace(index=self.columns, columns=self.index, **kwargs) | |
| @cached_method(whitelist=True) | |
| def resolve(self: ArrayWrapperT, group_by: tp.GroupByLike = None, **kwargs) -> ArrayWrapperT: | |
| """Resolve this instance. | |
| Replaces columns and other metadata with groups.""" | |
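| # [Editor's note] Illustrative sketch, not part of the original source. With grouping enabled, | |
| # `resolve` turns groups into regular columns (group labels below are assumed): | |
| # >>> grouped = vbt.ArrayWrapper(pd.RangeIndex(5), ["a", "b", "c"], 2, group_by=[0, 0, 1]) | |
| # >>> grouped.resolve().columns  # roughly Index([0, 1], name='group') | |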
| _self = self.regroup(group_by=group_by, **kwargs) | |
| if _self.grouper.is_grouped(): | |
| return _self.replace( | |
| columns=_self.grouper.get_index(), | |
| ndim=_self.grouped_ndim, | |
| grouped_ndim=None, | |
| group_by=None, | |
| ) | |
| return _self # important for keeping cache | |
| def get_index_grouper(self, *args, **kwargs) -> Grouper: | |
| """See `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`.""" | |
| return self.index_acc.get_grouper(*args, **kwargs) | |
| def wrap( | |
| self, | |
| arr: tp.ArrayLike, | |
| group_by: tp.GroupByLike = None, | |
| index: tp.Optional[tp.IndexLike] = None, | |
| columns: tp.Optional[tp.IndexLike] = None, | |
| zero_to_none: tp.Optional[bool] = None, | |
| force_2d: bool = False, | |
| fillna: tp.Optional[tp.Scalar] = None, | |
| dtype: tp.Optional[tp.PandasDTypeLike] = None, | |
| min_precision: tp.Union[None, int, str] = None, | |
| max_precision: tp.Union[None, int, str] = None, | |
| prec_float_only: tp.Optional[bool] = None, | |
| prec_check_bounds: tp.Optional[bool] = None, | |
| prec_strict: tp.Optional[bool] = None, | |
| to_timedelta: bool = False, | |
| to_index: bool = False, | |
| silence_warnings: tp.Optional[bool] = None, | |
| ) -> tp.SeriesFrame: | |
| """Wrap a NumPy array using the stored metadata. | |
| Runs the following pipeline: | |
| 1) Converts to NumPy array | |
| 2) Fills NaN (optional) | |
| 3) Wraps using index, columns, and dtype (optional) | |
| 4) Converts to index (optional) | |
| 5) Converts to timedelta using `ArrayWrapper.arr_to_timedelta` (optional)""" | |
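| # [Editor's note] Illustrative sketch, not part of the original source, showing the pipeline | |
| # above (output is indicative): | |
| # >>> wrapper = vbt.ArrayWrapper(pd.date_range("2020", periods=3), ["a", "b"], 2) | |
| # >>> wrapper.wrap(np.arange(6).reshape((3, 2))) | |
| #             a  b | |
| # 2020-01-01  0  1 | |
| # 2020-01-02  2  3 | |
| # 2020-01-03  4  5 | |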
| from vectorbtpro._settings import settings | |
| wrapping_cfg = settings["wrapping"] | |
| if zero_to_none is None: | |
| zero_to_none = wrapping_cfg["zero_to_none"] | |
| if min_precision is None: | |
| min_precision = wrapping_cfg["min_precision"] | |
| if max_precision is None: | |
| max_precision = wrapping_cfg["max_precision"] | |
| if prec_float_only is None: | |
| prec_float_only = wrapping_cfg["prec_float_only"] | |
| if prec_check_bounds is None: | |
| prec_check_bounds = wrapping_cfg["prec_check_bounds"] | |
| if prec_strict is None: | |
| prec_strict = wrapping_cfg["prec_strict"] | |
| if silence_warnings is None: | |
| silence_warnings = wrapping_cfg["silence_warnings"] | |
| _self = self.resolve(group_by=group_by) | |
| if index is None: | |
| index = _self.index | |
| if not isinstance(index, pd.Index): | |
| index = pd.Index(index) | |
| if columns is None: | |
| columns = _self.columns | |
| if not isinstance(columns, pd.Index): | |
| columns = pd.Index(columns) | |
| if len(columns) == 1: | |
| name = columns[0] | |
| if zero_to_none and name == 0: # was a Series before | |
| name = None | |
| else: | |
| name = None | |
| def _apply_dtype(obj): | |
| if dtype is None: | |
| return obj | |
| return obj.astype(dtype, errors="ignore") | |
| def _wrap(arr): | |
| orig_arr = arr | |
| arr = np.asarray(arr) | |
| if fillna is not None: | |
| arr[pd.isnull(arr)] = fillna | |
| shape_2d = (arr.shape[0] if arr.ndim > 0 else 1, arr.shape[1] if arr.ndim > 1 else 1) | |
| target_shape_2d = (len(index), len(columns)) | |
| if shape_2d != target_shape_2d: | |
| if isinstance(orig_arr, (pd.Series, pd.DataFrame)): | |
| arr = reshaping.align_pd_arrays(orig_arr, to_index=index, to_columns=columns).values | |
| arr = reshaping.broadcast_array_to(arr, target_shape_2d) | |
| arr = reshaping.soft_to_ndim(arr, self.ndim) | |
| if min_precision is not None: | |
| arr = cast_to_min_precision( | |
| arr, | |
| min_precision, | |
| float_only=prec_float_only, | |
| ) | |
| if max_precision is not None: | |
| arr = cast_to_max_precision( | |
| arr, | |
| max_precision, | |
| float_only=prec_float_only, | |
| check_bounds=prec_check_bounds, | |
| strict=prec_strict, | |
| ) | |
| if arr.ndim == 1: | |
| if force_2d: | |
| return _apply_dtype(pd.DataFrame(arr[:, None], index=index, columns=columns)) | |
| return _apply_dtype(pd.Series(arr, index=index, name=name)) | |
| if arr.ndim == 2: | |
| if not force_2d and arr.shape[1] == 1 and _self.ndim == 1: | |
| return _apply_dtype(pd.Series(arr[:, 0], index=index, name=name)) | |
| return _apply_dtype(pd.DataFrame(arr, index=index, columns=columns)) | |
| raise ValueError(f"{arr.ndim}-d input is not supported") | |
| out = _wrap(arr) | |
| if to_index: | |
| # Convert to index | |
| if checks.is_series(out): | |
| out = out.map(lambda x: self.index[x] if x != -1 else np.nan) | |
| else: | |
| out = out.applymap(lambda x: self.index[x] if x != -1 else np.nan) | |
| if to_timedelta: | |
| # Convert to timedelta | |
| out = self.arr_to_timedelta(out, silence_warnings=silence_warnings) | |
| return out | |
| def wrap_reduced( | |
| self, | |
| arr: tp.ArrayLike, | |
| group_by: tp.GroupByLike = None, | |
| name_or_index: tp.NameIndex = None, | |
| columns: tp.Optional[tp.IndexLike] = None, | |
| force_1d: bool = False, | |
| fillna: tp.Optional[tp.Scalar] = None, | |
| dtype: tp.Optional[tp.PandasDTypeLike] = None, | |
| to_timedelta: bool = False, | |
| to_index: bool = False, | |
| silence_warnings: tp.Optional[bool] = None, | |
| ) -> tp.MaybeSeriesFrame: | |
| """Wrap result of reduction. | |
| `name_or_index` can be the name of the resulting series if reducing to a scalar per column, | |
| or the index of the resulting series/dataframe if reducing to an array per column. | |
| `columns` can be set to override the object's default columns. | |
| See `ArrayWrapper.wrap` for the pipeline.""" | |
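| # [Editor's note] Illustrative sketch, not part of the original source (wrapper as above). | |
| # Reducing to one value per column yields a Series indexed by the columns: | |
| # >>> wrapper.wrap_reduced(np.array([10, 20]), name_or_index="sum") | |
| # # -> pd.Series([10, 20], index=['a', 'b'], name='sum') | |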
| from vectorbtpro._settings import settings | |
| wrapping_cfg = settings["wrapping"] | |
| if silence_warnings is None: | |
| silence_warnings = wrapping_cfg["silence_warnings"] | |
| _self = self.resolve(group_by=group_by) | |
| if columns is None: | |
| columns = _self.columns | |
| if not isinstance(columns, pd.Index): | |
| columns = pd.Index(columns) | |
| if to_index: | |
| if dtype is None: | |
| dtype = int_ | |
| if fillna is None: | |
| fillna = -1 | |
| def _apply_dtype(obj): | |
| if dtype is None: | |
| return obj | |
| return obj.astype(dtype, errors="ignore") | |
| def _wrap_reduced(arr): | |
| nonlocal name_or_index | |
| if isinstance(arr, dict): | |
| arr = reshaping.to_pd_array(arr) | |
| if isinstance(arr, pd.Series): | |
| if not checks.is_index_equal(arr.index, columns): | |
| arr = arr.iloc[indexes.align_indexes(arr.index, columns)[0]] | |
| arr = np.asarray(arr) | |
| if force_1d and arr.ndim == 0: | |
| arr = arr[None] | |
| if fillna is not None: | |
| if arr.ndim == 0: | |
| if pd.isnull(arr): | |
| arr = fillna | |
| else: | |
| arr[pd.isnull(arr)] = fillna | |
| if arr.ndim == 0: | |
| # Scalar per Series/DataFrame | |
| return _apply_dtype(pd.Series(arr[None]))[0] | |
| if arr.ndim == 1: | |
| if not force_1d and _self.ndim == 1: | |
| if arr.shape[0] == 1: | |
| # Scalar per Series/DataFrame with one column | |
| return _apply_dtype(pd.Series(arr))[0] | |
| # Array per Series | |
| sr_name = columns[0] | |
| if sr_name == 0: | |
| sr_name = None | |
| if isinstance(name_or_index, str): | |
| name_or_index = None | |
| return _apply_dtype(pd.Series(arr, index=name_or_index, name=sr_name)) | |
| # Scalar per column in DataFrame | |
| if arr.shape[0] == 1 and len(columns) > 1: | |
| arr = reshaping.broadcast_array_to(arr, len(columns)) | |
| return _apply_dtype(pd.Series(arr, index=columns, name=name_or_index)) | |
| if arr.ndim == 2: | |
| if arr.shape[1] == 1 and _self.ndim == 1: | |
| arr = reshaping.soft_to_ndim(arr, 1) | |
| # Array per Series | |
| sr_name = columns[0] | |
| if sr_name == 0: | |
| sr_name = None | |
| if isinstance(name_or_index, str): | |
| name_or_index = None | |
| return _apply_dtype(pd.Series(arr, index=name_or_index, name=sr_name)) | |
| # Array per column in DataFrame | |
| if isinstance(name_or_index, str): | |
| name_or_index = None | |
| if arr.shape[0] == 1 and len(columns) > 1: | |
| arr = reshaping.broadcast_array_to(arr, (arr.shape[0], len(columns))) | |
| return _apply_dtype(pd.DataFrame(arr, index=name_or_index, columns=columns)) | |
| raise ValueError(f"{arr.ndim}-d input is not supported") | |
| out = _wrap_reduced(arr) | |
| if to_index: | |
| # Convert to index | |
| if checks.is_series(out): | |
| out = out.map(lambda x: self.index[x] if x != -1 else np.nan) | |
| elif checks.is_frame(out): | |
| out = out.applymap(lambda x: self.index[x] if x != -1 else np.nan) | |
| else: | |
| out = self.index[out] if out != -1 else np.nan | |
| if to_timedelta: | |
| # Convert to timedelta | |
| out = self.arr_to_timedelta(out, silence_warnings=silence_warnings) | |
| return out | |
| def concat_arrs( | |
| self, | |
| *objs: tp.ArrayLike, | |
| group_by: tp.GroupByLike = None, | |
| wrap: bool = True, | |
| **kwargs, | |
| ) -> tp.AnyArray1d: | |
| """Stack reduced objects along columns and wrap the final object.""" | |
| from vectorbtpro.base.merging import concat_arrays | |
| if len(objs) == 1: | |
| objs = objs[0] | |
| objs = list(objs) | |
| new_objs = [] | |
| for obj in objs: | |
| new_objs.append(reshaping.to_1d_array(obj)) | |
| stacked_obj = concat_arrays(new_objs) | |
| if wrap: | |
| _self = self.resolve(group_by=group_by) | |
| return _self.wrap_reduced(stacked_obj, **kwargs) | |
| return stacked_obj | |
| def row_stack_arrs( | |
| self, | |
| *objs: tp.ArrayLike, | |
| group_by: tp.GroupByLike = None, | |
| wrap: bool = True, | |
| **kwargs, | |
| ) -> tp.AnyArray: | |
| """Stack objects along rows and wrap the final object.""" | |
| from vectorbtpro.base.merging import row_stack_arrays | |
| _self = self.resolve(group_by=group_by) | |
| if len(objs) == 1: | |
| objs = objs[0] | |
| objs = list(objs) | |
| new_objs = [] | |
| for obj in objs: | |
| obj = reshaping.to_2d_array(obj) | |
| if obj.shape[1] != _self.shape_2d[1]: | |
| if obj.shape[1] != 1: | |
| raise ValueError(f"Cannot broadcast {obj.shape[1]} to {_self.shape_2d[1]} columns") | |
| obj = np.repeat(obj, _self.shape_2d[1], axis=1) | |
| new_objs.append(obj) | |
| stacked_obj = row_stack_arrays(new_objs) | |
| if wrap: | |
| return _self.wrap(stacked_obj, **kwargs) | |
| return stacked_obj | |
| def column_stack_arrs( | |
| self, | |
| *objs: tp.ArrayLike, | |
| reindex_kwargs: tp.KwargsLike = None, | |
| group_by: tp.GroupByLike = None, | |
| wrap: bool = True, | |
| **kwargs, | |
| ) -> tp.AnyArray2d: | |
| """Stack objects along columns and wrap the final object. | |
| `reindex_kwargs` will be passed to | |
| [pandas.DataFrame.reindex](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.reindex.html).""" | |
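| # [Editor's note] Illustrative sketch, not part of the original source. Objects are reindexed | |
| # to the wrapper's index before stacking (wrapper as above, with columns 'a' and 'b'): | |
| # >>> sr1 = pd.Series([1, 2, 3], index=wrapper.index) | |
| # >>> sr2 = pd.Series([4, 5, 6], index=wrapper.index) | |
| # >>> wrapper.column_stack_arrs(sr1, sr2)  # -> 3x2 DataFrame wrapped with columns 'a' and 'b' | |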
| from vectorbtpro.base.merging import column_stack_arrays | |
| _self = self.resolve(group_by=group_by) | |
| if len(objs) == 1: | |
| objs = objs[0] | |
| objs = list(objs) | |
| new_objs = [] | |
| for obj in objs: | |
| if not checks.is_index_equal(obj.index, _self.index, check_names=False): | |
| was_bool = (isinstance(obj, pd.Series) and obj.dtype == "bool") or ( | |
| isinstance(obj, pd.DataFrame) and (obj.dtypes == "bool").all() | |
| ) | |
| obj = obj.reindex(_self.index, **resolve_dict(reindex_kwargs)) | |
| is_object = (isinstance(obj, pd.Series) and obj.dtype == "object") or ( | |
| isinstance(obj, pd.DataFrame) and (obj.dtypes == "object").all() | |
| ) | |
| if was_bool and is_object: | |
| obj = obj.astype(None) | |
| new_objs.append(reshaping.to_2d_array(obj)) | |
| stacked_obj = column_stack_arrays(new_objs) | |
| if wrap: | |
| return _self.wrap(stacked_obj, **kwargs) | |
| return stacked_obj | |
| def dummy(self, group_by: tp.GroupByLike = None, **kwargs) -> tp.SeriesFrame: | |
| """Create a dummy Series/DataFrame.""" | |
| _self = self.resolve(group_by=group_by) | |
| return _self.wrap(np.empty(_self.shape), **kwargs) | |
| def fill(self, fill_value: tp.Scalar = np.nan, group_by: tp.GroupByLike = None, **kwargs) -> tp.SeriesFrame: | |
| """Fill a Series/DataFrame.""" | |
| _self = self.resolve(group_by=group_by) | |
| return _self.wrap(np.full(_self.shape_2d, fill_value), **kwargs) | |
| def fill_reduced(self, fill_value: tp.Scalar = np.nan, group_by: tp.GroupByLike = None, **kwargs) -> tp.SeriesFrame: | |
| """Fill a reduced Series/DataFrame.""" | |
| _self = self.resolve(group_by=group_by) | |
| return _self.wrap_reduced(np.full(_self.shape_2d[1], fill_value), **kwargs) | |
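| # [Editor's note] Illustrative sketch, not part of the original source: | |
| # >>> wrapper.fill(0.0)          # DataFrame of zeros with the wrapper's index and columns | |
| # >>> wrapper.fill_reduced(0.0)  # Series of zeros, one element per column | |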
| def apply_to_index( | |
| self: ArrayWrapperT, | |
| apply_func: tp.Callable, | |
| *args, | |
| axis: tp.Optional[int] = None, | |
| **kwargs, | |
| ) -> ArrayWrapperT: | |
| if axis is None: | |
| axis = 0 if self.ndim == 1 else 1 | |
| if self.ndim == 1 and axis == 1: | |
| raise TypeError("Axis 1 is not supported for one dimension") | |
| checks.assert_in(axis, (0, 1)) | |
| if axis == 1: | |
| return self.replace(columns=apply_func(self.columns, *args, **kwargs)) | |
| return self.replace(index=apply_func(self.index, *args, **kwargs)) | |
| def get_index_points(self, *args, **kwargs) -> tp.Array1d: | |
| """See `vectorbtpro.base.accessors.BaseIDXAccessor.get_points`.""" | |
| return self.index_acc.get_points(*args, **kwargs) | |
| def get_index_ranges(self, *args, **kwargs) -> tp.Tuple[tp.Array1d, tp.Array1d]: | |
| """See `vectorbtpro.base.accessors.BaseIDXAccessor.get_ranges`.""" | |
| return self.index_acc.get_ranges(*args, **kwargs) | |
| def fill_and_set( | |
| self, | |
| idx_setter: tp.Union[index_dict, IdxSetter, IdxSetterFactory], | |
| keep_flex: bool = False, | |
| fill_value: tp.Scalar = np.nan, | |
| **kwargs, | |
| ) -> tp.AnyArray: | |
| """Fill a new array using an index object such as `vectorbtpro.base.indexing.index_dict`. | |
| Will be wrapped with `vectorbtpro.base.indexing.IdxSetter` if not already. | |
| Will call `vectorbtpro.base.indexing.IdxSetter.fill_and_set`. | |
| Usage: | |
| * Set a single row: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> index = pd.date_range("2020", periods=5) | |
| >>> columns = pd.Index(["a", "b", "c"]) | |
| >>> wrapper = vbt.ArrayWrapper(index, columns) | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... 1: 2 | |
| ... })) | |
| a b c | |
| 2020-01-01 NaN NaN NaN | |
| 2020-01-02 2.0 2.0 2.0 | |
| 2020-01-03 NaN NaN NaN | |
| 2020-01-04 NaN NaN NaN | |
| 2020-01-05 NaN NaN NaN | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... "2020-01-02": 2 | |
| ... })) | |
| a b c | |
| 2020-01-01 NaN NaN NaN | |
| 2020-01-02 2.0 2.0 2.0 | |
| 2020-01-03 NaN NaN NaN | |
| 2020-01-04 NaN NaN NaN | |
| 2020-01-05 NaN NaN NaN | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... "2020-01-02": [1, 2, 3] | |
| ... })) | |
| a b c | |
| 2020-01-01 NaN NaN NaN | |
| 2020-01-02 1.0 2.0 3.0 | |
| 2020-01-03 NaN NaN NaN | |
| 2020-01-04 NaN NaN NaN | |
| 2020-01-05 NaN NaN NaN | |
| ``` | |
| * Set multiple rows: | |
| ```pycon | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... (1, 3): [2, 3] | |
| ... })) | |
| a b c | |
| 2020-01-01 NaN NaN NaN | |
| 2020-01-02 2.0 2.0 2.0 | |
| 2020-01-03 NaN NaN NaN | |
| 2020-01-04 3.0 3.0 3.0 | |
| 2020-01-05 NaN NaN NaN | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... ("2020-01-02", "2020-01-04"): [[1, 2, 3], [4, 5, 6]] | |
| ... })) | |
| a b c | |
| 2020-01-01 NaN NaN NaN | |
| 2020-01-02 1.0 2.0 3.0 | |
| 2020-01-03 NaN NaN NaN | |
| 2020-01-04 4.0 5.0 6.0 | |
| 2020-01-05 NaN NaN NaN | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... ("2020-01-02", "2020-01-04"): [[1, 2, 3]] | |
| ... })) | |
| a b c | |
| 2020-01-01 NaN NaN NaN | |
| 2020-01-02 1.0 2.0 3.0 | |
| 2020-01-03 NaN NaN NaN | |
| 2020-01-04 1.0 2.0 3.0 | |
| 2020-01-05 NaN NaN NaN | |
| ``` | |
| * Set rows using slices: | |
| ```pycon | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... vbt.hslice(1, 3): 2 | |
| ... })) | |
| a b c | |
| 2020-01-01 NaN NaN NaN | |
| 2020-01-02 2.0 2.0 2.0 | |
| 2020-01-03 2.0 2.0 2.0 | |
| 2020-01-04 NaN NaN NaN | |
| 2020-01-05 NaN NaN NaN | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... vbt.hslice("2020-01-02", "2020-01-04"): 2 | |
| ... })) | |
| a b c | |
| 2020-01-01 NaN NaN NaN | |
| 2020-01-02 2.0 2.0 2.0 | |
| 2020-01-03 2.0 2.0 2.0 | |
| 2020-01-04 NaN NaN NaN | |
| 2020-01-05 NaN NaN NaN | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... ((0, 2), (3, 5)): [[1], [2]] | |
| ... })) | |
| a b c | |
| 2020-01-01 1.0 1.0 1.0 | |
| 2020-01-02 1.0 1.0 1.0 | |
| 2020-01-03 NaN NaN NaN | |
| 2020-01-04 2.0 2.0 2.0 | |
| 2020-01-05 2.0 2.0 2.0 | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... ((0, 2), (3, 5)): [[1, 2, 3], [4, 5, 6]] | |
| ... })) | |
| a b c | |
| 2020-01-01 1.0 2.0 3.0 | |
| 2020-01-02 1.0 2.0 3.0 | |
| 2020-01-03 NaN NaN NaN | |
| 2020-01-04 4.0 5.0 6.0 | |
| 2020-01-05 4.0 5.0 6.0 | |
| ``` | |
| * Set rows using index points: | |
| ```pycon | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... vbt.pointidx(every="2D"): 2 | |
| ... })) | |
| a b c | |
| 2020-01-01 2.0 2.0 2.0 | |
| 2020-01-02 NaN NaN NaN | |
| 2020-01-03 2.0 2.0 2.0 | |
| 2020-01-04 NaN NaN NaN | |
| 2020-01-05 2.0 2.0 2.0 | |
| ``` | |
| * Set rows using index ranges: | |
| ```pycon | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... vbt.rangeidx( | |
| ... start=("2020-01-01", "2020-01-03"), | |
| ... end=("2020-01-02", "2020-01-05") | |
| ... ): 2 | |
| ... })) | |
| a b c | |
| 2020-01-01 2.0 2.0 2.0 | |
| 2020-01-02 NaN NaN NaN | |
| 2020-01-03 2.0 2.0 2.0 | |
| 2020-01-04 2.0 2.0 2.0 | |
| 2020-01-05 NaN NaN NaN | |
| ``` | |
| * Set column indices: | |
| ```pycon | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... vbt.colidx("a"): 2 | |
| ... })) | |
| a b c | |
| 2020-01-01 2.0 NaN NaN | |
| 2020-01-02 2.0 NaN NaN | |
| 2020-01-03 2.0 NaN NaN | |
| 2020-01-04 2.0 NaN NaN | |
| 2020-01-05 2.0 NaN NaN | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... vbt.colidx(("a", "b")): [1, 2] | |
| ... })) | |
| a b c | |
| 2020-01-01 1.0 2.0 NaN | |
| 2020-01-02 1.0 2.0 NaN | |
| 2020-01-03 1.0 2.0 NaN | |
| 2020-01-04 1.0 2.0 NaN | |
| 2020-01-05 1.0 2.0 NaN | |
| >>> multi_columns = pd.MultiIndex.from_arrays( | |
| ... [["a", "a", "b", "b"], [1, 2, 1, 2]], | |
| ... names=["c1", "c2"] | |
| ... ) | |
| >>> multi_wrapper = vbt.ArrayWrapper(index, multi_columns) | |
| >>> multi_wrapper.fill_and_set(vbt.index_dict({ | |
| ... vbt.colidx(("a", 2)): 2 | |
| ... })) | |
| c1 a b | |
| c2 1 2 1 2 | |
| 2020-01-01 NaN 2.0 NaN NaN | |
| 2020-01-02 NaN 2.0 NaN NaN | |
| 2020-01-03 NaN 2.0 NaN NaN | |
| 2020-01-04 NaN 2.0 NaN NaN | |
| 2020-01-05 NaN 2.0 NaN NaN | |
| >>> multi_wrapper.fill_and_set(vbt.index_dict({ | |
| ... vbt.colidx("b", level="c1"): [3, 4] | |
| ... })) | |
| c1 a b | |
| c2 1 2 1 2 | |
| 2020-01-01 NaN NaN 3.0 4.0 | |
| 2020-01-02 NaN NaN 3.0 4.0 | |
| 2020-01-03 NaN NaN 3.0 4.0 | |
| 2020-01-04 NaN NaN 3.0 4.0 | |
| 2020-01-05 NaN NaN 3.0 4.0 | |
| ``` | |
| * Set row and column indices: | |
| ```pycon | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... vbt.idx(2, 2): 2 | |
| ... })) | |
| a b c | |
| 2020-01-01 NaN NaN NaN | |
| 2020-01-02 NaN NaN NaN | |
| 2020-01-03 NaN NaN 2.0 | |
| 2020-01-04 NaN NaN NaN | |
| 2020-01-05 NaN NaN NaN | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... vbt.idx(("2020-01-01", "2020-01-03"), 2): [1, 2] | |
| ... })) | |
| a b c | |
| 2020-01-01 NaN NaN 1.0 | |
| 2020-01-02 NaN NaN NaN | |
| 2020-01-03 NaN NaN 2.0 | |
| 2020-01-04 NaN NaN NaN | |
| 2020-01-05 NaN NaN NaN | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... vbt.idx(("2020-01-01", "2020-01-03"), (0, 2)): [[1, 2], [3, 4]] | |
| ... })) | |
| a b c | |
| 2020-01-01 1.0 NaN 2.0 | |
| 2020-01-02 NaN NaN NaN | |
| 2020-01-03 3.0 NaN 4.0 | |
| 2020-01-04 NaN NaN NaN | |
| 2020-01-05 NaN NaN NaN | |
| >>> multi_wrapper.fill_and_set(vbt.index_dict({ | |
| ... vbt.idx( | |
| ... vbt.pointidx(every="2d"), | |
| ... vbt.colidx(1, level="c2") | |
| ... ): [[1, 2]] | |
| ... })) | |
| c1 a b | |
| c2 1 2 1 2 | |
| 2020-01-01 1.0 NaN 2.0 NaN | |
| 2020-01-02 NaN NaN NaN NaN | |
| 2020-01-03 1.0 NaN 2.0 NaN | |
| 2020-01-04 NaN NaN NaN NaN | |
| 2020-01-05 1.0 NaN 2.0 NaN | |
| >>> multi_wrapper.fill_and_set(vbt.index_dict({ | |
| ... vbt.idx( | |
| ... vbt.pointidx(every="2d"), | |
| ... vbt.colidx(1, level="c2") | |
| ... ): [[1], [2], [3]] | |
| ... })) | |
| c1 a b | |
| c2 1 2 1 2 | |
| 2020-01-01 1.0 NaN 1.0 NaN | |
| 2020-01-02 NaN NaN NaN NaN | |
| 2020-01-03 2.0 NaN 2.0 NaN | |
| 2020-01-04 NaN NaN NaN NaN | |
| 2020-01-05 3.0 NaN 3.0 NaN | |
| ``` | |
| * Set rows using a template: | |
| ```pycon | |
| >>> wrapper.fill_and_set(vbt.index_dict({ | |
| ... vbt.RepEval("index.day % 2 == 0"): 2 | |
| ... })) | |
| a b c | |
| 2020-01-01 NaN NaN NaN | |
| 2020-01-02 2.0 2.0 2.0 | |
| 2020-01-03 NaN NaN NaN | |
| 2020-01-04 2.0 2.0 2.0 | |
| 2020-01-05 NaN NaN NaN | |
| ``` | |
| """ | |
| if isinstance(idx_setter, index_dict): | |
| idx_setter = IdxDict(idx_setter) | |
| if isinstance(idx_setter, IdxSetterFactory): | |
| idx_setter = idx_setter.get() | |
| if not isinstance(idx_setter, IdxSetter): | |
| raise ValueError("Index setter factory must return exactly one index setter") | |
| checks.assert_instance_of(idx_setter, IdxSetter) | |
| arr = idx_setter.fill_and_set( | |
| self.shape, | |
| keep_flex=keep_flex, | |
| fill_value=fill_value, | |
| index=self.index, | |
| columns=self.columns, | |
| freq=self.freq, | |
| **kwargs, | |
| ) | |
| if not keep_flex: | |
| return self.wrap(arr, group_by=False) | |
| return arr | |
| WrappingT = tp.TypeVar("WrappingT", bound="Wrapping") | |
| class Wrapping(Configured, HasWrapper, IndexApplier, AttrResolverMixin): | |
| """Class that uses `ArrayWrapper` globally.""" | |
| @classmethod | |
| def resolve_row_stack_kwargs(cls, *wrappings: tp.MaybeTuple[WrappingT], **kwargs) -> tp.Kwargs: | |
| """Resolve keyword arguments for initializing `Wrapping` after stacking along rows.""" | |
| return kwargs | |
| @classmethod | |
| def resolve_column_stack_kwargs(cls, *wrappings: tp.MaybeTuple[WrappingT], **kwargs) -> tp.Kwargs: | |
| """Resolve keyword arguments for initializing `Wrapping` after stacking along columns.""" | |
| return kwargs | |
| @classmethod | |
| def resolve_stack_kwargs(cls, *wrappings: tp.MaybeTuple[WrappingT], **kwargs) -> tp.Kwargs: | |
| """Resolve keyword arguments for initializing `Wrapping` after stacking. | |
| Should be called after `Wrapping.resolve_row_stack_kwargs` or `Wrapping.resolve_column_stack_kwargs`.""" | |
| return cls.resolve_merge_kwargs(*[wrapping.config for wrapping in wrappings], **kwargs) | |
| @hybrid_method | |
| def row_stack( | |
| cls_or_self: tp.MaybeType[WrappingT], | |
| *objs: tp.MaybeTuple[WrappingT], | |
| wrapper_kwargs: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> WrappingT: | |
| """Stack multiple `Wrapping` instances along rows. | |
| Should use `ArrayWrapper.row_stack`.""" | |
| raise NotImplementedError | |
| @hybrid_method | |
| def column_stack( | |
| cls_or_self: tp.MaybeType[WrappingT], | |
| *objs: tp.MaybeTuple[WrappingT], | |
| wrapper_kwargs: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> WrappingT: | |
| """Stack multiple `Wrapping` instances along columns. | |
| Should use `ArrayWrapper.column_stack`.""" | |
| raise NotImplementedError | |
| def __init__(self, wrapper: ArrayWrapper, **kwargs) -> None: | |
| checks.assert_instance_of(wrapper, ArrayWrapper) | |
| self._wrapper = wrapper | |
| Configured.__init__(self, wrapper=wrapper, **kwargs) | |
| HasWrapper.__init__(self) | |
| AttrResolverMixin.__init__(self) | |
| def indexing_func(self: WrappingT, *args, **kwargs) -> WrappingT: | |
| """Perform indexing on `Wrapping`.""" | |
| new_wrapper = self.wrapper.indexing_func( | |
| *args, | |
| column_only_select=self.column_only_select, | |
| range_only_select=self.range_only_select, | |
| group_select=self.group_select, | |
| **kwargs, | |
| ) | |
| return self.replace(wrapper=new_wrapper) | |
| def resample(self: WrappingT, *args, **kwargs) -> WrappingT: | |
| """Perform resampling on `Wrapping`. | |
| When overriding, make sure to create a resampler by passing `*args` and `**kwargs` | |
| to `ArrayWrapper.get_resampler`.""" | |
| raise NotImplementedError | |
| @property | |
| def wrapper(self) -> ArrayWrapper: | |
| return self._wrapper | |
| def apply_to_index( | |
| self: ArrayWrapperT, | |
| apply_func: tp.Callable, | |
| *args, | |
| axis: tp.Optional[int] = None, | |
| **kwargs, | |
| ) -> ArrayWrapperT: | |
| if axis is None: | |
| axis = 0 if self.wrapper.ndim == 1 else 1 | |
| if self.wrapper.ndim == 1 and axis == 1: | |
| raise TypeError("Axis 1 is not supported for one dimension") | |
| checks.assert_in(axis, (0, 1)) | |
| if axis == 1: | |
| new_wrapper = self.wrapper.replace(columns=apply_func(self.wrapper.columns, *args, **kwargs)) | |
| else: | |
| new_wrapper = self.wrapper.replace(index=apply_func(self.wrapper.index, *args, **kwargs)) | |
| return self.replace(wrapper=new_wrapper) | |
| @property | |
| def column_only_select(self) -> bool: | |
| column_only_select = getattr(self, "_column_only_select", None) | |
| if column_only_select is None: | |
| return self.wrapper.column_only_select | |
| return column_only_select | |
| @property | |
| def range_only_select(self) -> bool: | |
| range_only_select = getattr(self, "_range_only_select", None) | |
| if range_only_select is None: | |
| return self.wrapper.range_only_select | |
| return range_only_select | |
| @property | |
| def group_select(self) -> bool: | |
| group_select = getattr(self, "_group_select", None) | |
| if group_select is None: | |
| return self.wrapper.group_select | |
| return group_select | |
| def regroup(self: WrappingT, group_by: tp.GroupByLike, **kwargs) -> WrappingT: | |
| """Regroup this instance. | |
| Only creates a new instance if grouping has changed, otherwise returns itself. | |
| `**kwargs` will be passed to `ArrayWrapper.regroup`.""" | |
| if self.wrapper.grouper.is_grouping_changed(group_by=group_by): | |
| self.wrapper.grouper.check_group_by(group_by=group_by) | |
| return self.replace(wrapper=self.wrapper.regroup(group_by, **kwargs)) | |
| return self # important for keeping cache | |
| def resolve_self( | |
| self: AttrResolverMixinT, | |
| cond_kwargs: tp.KwargsLike = None, | |
| custom_arg_names: tp.Optional[tp.Set[str]] = None, | |
| impacts_caching: bool = True, | |
| silence_warnings: tp.Optional[bool] = None, | |
| ) -> AttrResolverMixinT: | |
| """Resolve self. | |
| Creates a copy of this instance if a different `freq` can be found in `cond_kwargs`.""" | |
| from vectorbtpro._settings import settings | |
| wrapping_cfg = settings["wrapping"] | |
| if cond_kwargs is None: | |
| cond_kwargs = {} | |
| if custom_arg_names is None: | |
| custom_arg_names = set() | |
| if silence_warnings is None: | |
| silence_warnings = wrapping_cfg["silence_warnings"] | |
| if "freq" in cond_kwargs: | |
| wrapper_copy = self.wrapper.replace(freq=cond_kwargs["freq"]) | |
| if wrapper_copy.freq != self.wrapper.freq: | |
| if not silence_warnings: | |
| warn( | |
| f"Changing the frequency will create a copy of this instance. " | |
| f"Consider setting it upon instantiation to re-use existing cache." | |
| ) | |
| self_copy = self.replace(wrapper=wrapper_copy) | |
| for alias in self.self_aliases: | |
| if alias not in custom_arg_names: | |
| cond_kwargs[alias] = self_copy | |
| cond_kwargs["freq"] = self_copy.wrapper.freq | |
| if impacts_caching: | |
| cond_kwargs["use_caching"] = False | |
| return self_copy | |
| return self | |
| </file> | |
| <file path="data/custom/__init__.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Modules with custom data classes.""" | |
| from typing import TYPE_CHECKING | |
| if TYPE_CHECKING: | |
| from vectorbtpro.data.custom.alpaca import * | |
| from vectorbtpro.data.custom.av import * | |
| from vectorbtpro.data.custom.bento import * | |
| from vectorbtpro.data.custom.binance import * | |
| from vectorbtpro.data.custom.ccxt import * | |
| from vectorbtpro.data.custom.csv import * | |
| from vectorbtpro.data.custom.custom import * | |
| from vectorbtpro.data.custom.db import * | |
| from vectorbtpro.data.custom.duckdb import * | |
| from vectorbtpro.data.custom.feather import * | |
| from vectorbtpro.data.custom.file import * | |
| from vectorbtpro.data.custom.finpy import * | |
| from vectorbtpro.data.custom.gbm import * | |
| from vectorbtpro.data.custom.gbm_ohlc import * | |
| from vectorbtpro.data.custom.hdf import * | |
| from vectorbtpro.data.custom.local import * | |
| from vectorbtpro.data.custom.ndl import * | |
| from vectorbtpro.data.custom.parquet import * | |
| from vectorbtpro.data.custom.polygon import * | |
| from vectorbtpro.data.custom.random import * | |
| from vectorbtpro.data.custom.random_ohlc import * | |
| from vectorbtpro.data.custom.remote import * | |
| from vectorbtpro.data.custom.sql import * | |
| from vectorbtpro.data.custom.synthetic import * | |
| from vectorbtpro.data.custom.tv import * | |
| from vectorbtpro.data.custom.yf import * | |
| </file> | |
| <file path="data/custom/alpaca.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `AlpacaData`.""" | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.remote import RemoteData | |
| from vectorbtpro.utils import datetime_ as dt | |
| from vectorbtpro.utils.config import merge_dicts | |
| from vectorbtpro.utils.parsing import get_func_arg_names | |
| try: | |
| if not tp.TYPE_CHECKING: | |
| raise ImportError | |
| from alpaca.common.rest import RESTClient as AlpacaClientT | |
| except ImportError: | |
| AlpacaClientT = "AlpacaClient" | |
| __all__ = [ | |
| "AlpacaData", | |
| ] | |
| AlpacaDataT = tp.TypeVar("AlpacaDataT", bound="AlpacaData") | |
| class AlpacaData(RemoteData): | |
| """Data class for fetching from Alpaca. | |
| See https://github.com/alpacahq/alpaca-py for API. | |
| See `AlpacaData.fetch_symbol` for arguments. | |
| Usage: | |
| * Set up the API key globally (optional for crypto): | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> vbt.AlpacaData.set_custom_settings( | |
| ... client_config=dict( | |
| ... api_key="YOUR_KEY", | |
| ... secret_key="YOUR_SECRET" | |
| ... ) | |
| ... ) | |
| ``` | |
| * Pull stock data: | |
| ```pycon | |
| >>> data = vbt.AlpacaData.pull( | |
| ... "AAPL", | |
| ... start="2021-01-01", | |
| ... end="2022-01-01", | |
| ... timeframe="1 day" | |
| ... ) | |
| ``` | |
| * Pull crypto data: | |
| ```pycon | |
| >>> data = vbt.AlpacaData.pull( | |
| ... "BTC/USD", | |
| ... start="2021-01-01", | |
| ... end="2022-01-01", | |
| ... timeframe="1 day" | |
| ... ) | |
| ``` | |
| """ | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.alpaca") | |
| @classmethod | |
| def list_symbols( | |
| cls, | |
| pattern: tp.Optional[str] = None, | |
| use_regex: bool = False, | |
| sort: bool = True, | |
| status: tp.Optional[str] = None, | |
| asset_class: tp.Optional[str] = None, | |
| exchange: tp.Optional[str] = None, | |
| trading_client: tp.Optional[AlpacaClientT] = None, | |
| client_config: tp.KwargsLike = None, | |
| ) -> tp.List[str]: | |
| """List all symbols. | |
| Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`. | |
| Arguments `status`, `asset_class`, and `exchange` can be strings, such as `asset_class="crypto"`. | |
| For possible values, take a look into `alpaca.trading.enums`. | |
| !!! note | |
| If you get an authorization error, set the `paper` flag in `client_config` to match | |
| the account whose credentials you are using. | |
| By default, the credentials are assumed to belong to a live trading account (`paper=False`).""" | |
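| # [Editor's note] Illustrative sketch, not part of the original source. For example, listing | |
| # active crypto assets (string values are mapped to the enums mentioned above; stock queries | |
| # may require valid API credentials): | |
| # >>> vbt.AlpacaData.list_symbols(status="active", asset_class="crypto") | |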
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("alpaca") | |
| from alpaca.trading.client import TradingClient | |
| from alpaca.trading.requests import GetAssetsRequest | |
| from alpaca.trading.enums import AssetStatus, AssetClass, AssetExchange | |
| if client_config is None: | |
| client_config = {} | |
| has_client_config = len(client_config) > 0 | |
| client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True) | |
| if trading_client is None: | |
| arg_names = get_func_arg_names(TradingClient.__init__) | |
| client_config = {k: v for k, v in client_config.items() if k in arg_names} | |
| trading_client = TradingClient(**client_config) | |
| elif has_client_config: | |
| raise ValueError("Cannot apply client_config to already initialized client") | |
| if status is not None: | |
| if isinstance(status, str): | |
| status = getattr(AssetStatus, status.upper()) | |
| if asset_class is not None: | |
| if isinstance(asset_class, str): | |
| asset_class = getattr(AssetClass, asset_class.upper()) | |
| if exchange is not None: | |
| if isinstance(exchange, str): | |
| exchange = getattr(AssetExchange, exchange.upper()) | |
| search_params = GetAssetsRequest(status=status, asset_class=asset_class, exchange=exchange) | |
| assets = trading_client.get_all_assets(search_params) | |
| all_symbols = [] | |
| for asset in assets: | |
| symbol = asset.symbol | |
| if pattern is not None: | |
| if not cls.key_match(symbol, pattern, use_regex=use_regex): | |
| continue | |
| all_symbols.append(symbol) | |
| if sort: | |
| return sorted(dict.fromkeys(all_symbols)) | |
| return list(dict.fromkeys(all_symbols)) | |
| @classmethod | |
| def resolve_client( | |
| cls, | |
| client: tp.Optional[AlpacaClientT] = None, | |
| client_type: tp.Optional[str] = None, | |
| **client_config, | |
| ) -> AlpacaClientT: | |
| """Resolve the client. | |
| If provided, must be of the type `alpaca.data.historical.CryptoHistoricalDataClient` | |
| for `client_type="crypto"` and `alpaca.data.historical.StockHistoricalDataClient` for | |
| `client_type="stocks"`. Otherwise, will be created using `client_config`.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("alpaca") | |
| from alpaca.data.historical import CryptoHistoricalDataClient, StockHistoricalDataClient | |
| client = cls.resolve_custom_setting(client, "client") | |
| client_type = cls.resolve_custom_setting(client_type, "client_type") | |
| if client_config is None: | |
| client_config = {} | |
| has_client_config = len(client_config) > 0 | |
| client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True) | |
| if client is None: | |
| if client_type == "crypto": | |
| arg_names = get_func_arg_names(CryptoHistoricalDataClient.__init__) | |
| client_config = {k: v for k, v in client_config.items() if k in arg_names} | |
| client = CryptoHistoricalDataClient(**client_config) | |
| elif client_type == "stocks": | |
| arg_names = get_func_arg_names(StockHistoricalDataClient.__init__) | |
| client_config = {k: v for k, v in client_config.items() if k in arg_names} | |
| client = StockHistoricalDataClient(**client_config) | |
| else: | |
| raise ValueError(f"Invalid client type: '{client_type}'") | |
| elif has_client_config: | |
| raise ValueError("Cannot apply client_config to already initialized client") | |
| return client | |
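| # Sketch: resolving a data client directly (hypothetical credentials; keyword names follow | |
| # the `alpaca.data.historical` client constructors): | |
| # | |
| # >>> client = vbt.AlpacaData.resolve_client( | |
| # ...     client_type="crypto", | |
| # ...     api_key="YOUR_KEY", | |
| # ...     secret_key="YOUR_SECRET", | |
| # ... ) | |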
| @classmethod | |
| def fetch_symbol( | |
| cls, | |
| symbol: str, | |
| client: tp.Optional[AlpacaClientT] = None, | |
| client_type: tp.Optional[str] = None, | |
| client_config: tp.KwargsLike = None, | |
| start: tp.Optional[tp.DatetimeLike] = None, | |
| end: tp.Optional[tp.DatetimeLike] = None, | |
| timeframe: tp.Optional[str] = None, | |
| tz: tp.TimezoneLike = None, | |
| adjustment: tp.Optional[str] = None, | |
| feed: tp.Optional[str] = None, | |
| limit: tp.Optional[int] = None, | |
| ) -> tp.SymbolData: | |
| """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Alpaca. | |
| Args: | |
| symbol (str): Symbol. | |
| client (alpaca.common.rest.RESTClient): Client. | |
| See `AlpacaData.resolve_client`. | |
| client_type (str): Client type. | |
| See `AlpacaData.resolve_client`. | |
| Determined automatically based on the symbol. Crypto symbols contain "/". | |
| client_config (dict): Client config. | |
| See `AlpacaData.resolve_client`. | |
| start (any): Start datetime. | |
| See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. | |
| end (any): End datetime. | |
| See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. | |
| timeframe (str): Timeframe. | |
| Allows human-readable strings such as "15 minutes". | |
| tz (any): Timezone. | |
| See `vectorbtpro.utils.datetime_.to_timezone`. | |
| adjustment (str): Specifies the corporate action adjustment for the returned bars. | |
| Options are: "raw", "split", "dividend" or "all". Default is "raw". | |
| feed (str): The feed to pull market data from. | |
| This is either "iex", "otc", or "sip". Feeds "sip" and "otc" are only available to | |
| those with a subscription. Default is "iex" for free plans and "sip" for paid. | |
| limit (int): The maximum number of returned items. | |
| For defaults, see `custom.alpaca` in `vectorbtpro._settings.data`. | |
| """ | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("alpaca") | |
| from alpaca.data.historical import CryptoHistoricalDataClient, StockHistoricalDataClient | |
| from alpaca.data.requests import CryptoBarsRequest, StockBarsRequest | |
| from alpaca.data.timeframe import TimeFrame, TimeFrameUnit | |
| if client_type is None: | |
| client_type = "crypto" if "/" in symbol else "stocks" | |
| if client_config is None: | |
| client_config = {} | |
| client = cls.resolve_client(client=client, client_type=client_type, **client_config) | |
| start = cls.resolve_custom_setting(start, "start") | |
| end = cls.resolve_custom_setting(end, "end") | |
| timeframe = cls.resolve_custom_setting(timeframe, "timeframe") | |
| tz = cls.resolve_custom_setting(tz, "tz") | |
| adjustment = cls.resolve_custom_setting(adjustment, "adjustment") | |
| feed = cls.resolve_custom_setting(feed, "feed") | |
| limit = cls.resolve_custom_setting(limit, "limit") | |
| freq = timeframe | |
| split = dt.split_freq_str(timeframe) | |
| if split is None: | |
| raise ValueError(f"Invalid timeframe: '{timeframe}'") | |
| multiplier, unit = split | |
| if unit == "m": | |
| unit = TimeFrameUnit.Minute | |
| elif unit == "h": | |
| unit = TimeFrameUnit.Hour | |
| elif unit == "D": | |
| unit = TimeFrameUnit.Day | |
| elif unit == "W": | |
| unit = TimeFrameUnit.Week | |
| elif unit == "M": | |
| unit = TimeFrameUnit.Month | |
| else: | |
| raise ValueError(f"Invalid timeframe: '{timeframe}'") | |
| timeframe = TimeFrame(multiplier, unit) | |
| if start is not None: | |
| start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc") | |
| start_str = start.replace(tzinfo=None).isoformat("T") | |
| else: | |
| start_str = None | |
| if end is not None: | |
| end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc") | |
| end_str = end.replace(tzinfo=None).isoformat("T") | |
| else: | |
| end_str = None | |
| if isinstance(client, CryptoHistoricalDataClient): | |
| request = CryptoBarsRequest( | |
| symbol_or_symbols=symbol, | |
| timeframe=timeframe, | |
| start=start_str, | |
| end=end_str, | |
| limit=limit, | |
| ) | |
| df = client.get_crypto_bars(request).df | |
| elif isinstance(client, StockHistoricalDataClient): | |
| request = StockBarsRequest( | |
| symbol_or_symbols=symbol, | |
| timeframe=timeframe, | |
| start=start_str, | |
| end=end_str, | |
| limit=limit, | |
| adjustment=adjustment, | |
| feed=feed, | |
| ) | |
| df = client.get_stock_bars(request).df | |
| else: | |
| raise TypeError(f"Invalid client of type {type(client)}") | |
| df = df.droplevel("symbol", axis=0) | |
| df.index = df.index.rename("Open time") | |
| df.rename( | |
| columns={ | |
| "open": "Open", | |
| "high": "High", | |
| "low": "Low", | |
| "close": "Close", | |
| "volume": "Volume", | |
| "trade_count": "Trade count", | |
| "vwap": "VWAP", | |
| }, | |
| inplace=True, | |
| ) | |
| if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None: | |
| df = df.tz_localize("utc") | |
| if "Open" in df.columns: | |
| df["Open"] = df["Open"].astype(float) | |
| if "High" in df.columns: | |
| df["High"] = df["High"].astype(float) | |
| if "Low" in df.columns: | |
| df["Low"] = df["Low"].astype(float) | |
| if "Close" in df.columns: | |
| df["Close"] = df["Close"].astype(float) | |
| if "Volume" in df.columns: | |
| df["Volume"] = df["Volume"].astype(float) | |
| if "Trade count" in df.columns: | |
| df["Trade count"] = df["Trade count"].astype(int, errors="ignore") | |
| if "VWAP" in df.columns: | |
| df["VWAP"] = df["VWAP"].astype(float) | |
| if not df.empty: | |
| if start is not None: | |
| start = dt.to_timestamp(start, tz=df.index.tz) | |
| if df.index[0] < start: | |
| df = df[df.index >= start] | |
| if end is not None: | |
| end = dt.to_timestamp(end, tz=df.index.tz) | |
| if df.index[-1] >= end: | |
| df = df[df.index < end] | |
| return df, dict(tz=tz, freq=freq) | |
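| # Sketch of a typical pull using the arguments documented above (hypothetical symbol and | |
| # dates; the "sip" feed requires a paid plan, so "iex" is used here): | |
| # | |
| # >>> data = vbt.AlpacaData.pull( | |
| # ...     "AAPL", | |
| # ...     start="2023-01-01", | |
| # ...     end="2024-01-01", | |
| # ...     timeframe="1 day", | |
| # ...     adjustment="all", | |
| # ...     feed="iex", | |
| # ... ) | |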
| def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData: | |
| fetch_kwargs = self.select_fetch_kwargs(symbol) | |
| fetch_kwargs["start"] = self.select_last_index(symbol) | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| return self.fetch_symbol(symbol, **kwargs) | |
| </file> | |
| <file path="data/custom/av.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `AVData`.""" | |
| import re | |
| import urllib.parse | |
| from functools import lru_cache | |
| import numpy as np | |
| import pandas as pd | |
| import requests | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.remote import RemoteData | |
| from vectorbtpro.utils import datetime_ as dt | |
| from vectorbtpro.utils.config import merge_dicts | |
| from vectorbtpro.utils.module_ import check_installed | |
| from vectorbtpro.utils.parsing import get_func_arg_names | |
| from vectorbtpro.utils.warnings_ import warn | |
| try: | |
| if not tp.TYPE_CHECKING: | |
| raise ImportError | |
| from alpha_vantage.alphavantage import AlphaVantage as AlphaVantageT | |
| except ImportError: | |
| AlphaVantageT = "AlphaVantage" | |
| __all__ = [ | |
| "AVData", | |
| ] | |
| __pdoc__ = {} | |
| AVDataT = tp.TypeVar("AVDataT", bound="AVData") | |
| class AVData(RemoteData): | |
| """Data class for fetching from Alpha Vantage. | |
| See https://www.alphavantage.co/documentation/ for API. | |
| Apart from using the https://github.com/RomelTorres/alpha_vantage package, this class can also | |
| parse the API documentation with `AVData.parse_api_meta` using `BeautifulSoup4` and build | |
| the API query based on this metadata (pass `use_parser=True`). | |
| This approach is the most flexible one since it can react instantly to changes in the Alpha Vantage API. | |
| If the data provider changes its API documentation, you can always adapt the parsing | |
| procedure by overriding `AVData.parse_api_meta`. | |
| If the parser still fails, you can disable parsing entirely and specify all information manually | |
| by setting `function` and disabling `match_params`. | |
| See `AVData.fetch_symbol` for arguments. | |
| Usage: | |
| * Set up the API key globally (optional): | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> vbt.AVData.set_custom_settings( | |
| ... apikey="YOUR_KEY" | |
| ... ) | |
| ``` | |
| * Pull data: | |
| ```pycon | |
| >>> data = vbt.AVData.pull( | |
| ... "GOOGL", | |
| ... timeframe="1 day", | |
| ... ) | |
| >>> data = vbt.AVData.pull( | |
| ... "BTC_USD", | |
| ... timeframe="30 minutes", # premium? | |
| ... category="digital-currency", | |
| ... outputsize="full" | |
| ... ) | |
| >>> data = vbt.AVData.pull( | |
| ... "REAL_GDP", | |
| ... category="economic-indicators" | |
| ... ) | |
| >>> data = vbt.AVData.pull( | |
| ... "IBM", | |
| ... category="technical-indicators", | |
| ... function="STOCHRSI", | |
| ... params=dict(fastkperiod=14) | |
| ... ) | |
| ``` | |
| """ | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.av") | |
| @classmethod | |
| def list_symbols(cls, keywords: str, apikey: tp.Optional[str] = None, sort: bool = True) -> tp.List[str]: | |
| """List all symbols.""" | |
| apikey = cls.resolve_custom_setting(apikey, "apikey") | |
| query = dict() | |
| query["function"] = "SYMBOL_SEARCH" | |
| query["keywords"] = keywords | |
| query["datatype"] = "csv" | |
| query["apikey"] = apikey | |
| url = "https://www.alphavantage.co/query?" + urllib.parse.urlencode(query) | |
| df = pd.read_csv(url) | |
| if sort: | |
| return sorted(dict.fromkeys(df["symbol"].tolist())) | |
| return list(dict.fromkeys(df["symbol"].tolist())) | |
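| # Sketch: `vbt.AVData.list_symbols("tesco")` issues a SYMBOL_SEARCH query and returns the | |
| # de-duplicated (and, by default, sorted) symbols from the CSV response; this assumes a | |
| # valid API key has been set globally. | |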
| @classmethod | |
| @lru_cache() | |
| def parse_api_meta(cls) -> dict: | |
| """Parse API metadata from the documentation at https://www.alphavantage.co/documentation | |
| Cached class method. To avoid re-parsing the same metadata in different runtimes, save it manually.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("bs4") | |
| from bs4 import BeautifulSoup | |
| page = requests.get("https://www.alphavantage.co/documentation") | |
| soup = BeautifulSoup(page.content, "html.parser") | |
| api_meta = {} | |
| for section in soup.select("article section"): | |
| category = {} | |
| function = None | |
| function_args = dict(req_args=set(), opt_args=set()) | |
| for tag in section.find_all(True): | |
| if tag.name == "h6": | |
| if function is not None and tag.select("b")[0].getText().strip() == "API Parameters": | |
| category[function] = function_args | |
| function = None | |
| function_args = dict(req_args=set(), opt_args=set()) | |
| if tag.name == "b": | |
| b_text = tag.getText().strip() | |
| if b_text.startswith("❚ Required"): | |
| arg = tag.select("code")[0].getText().strip() | |
| function_args["req_args"].add(arg) | |
| if tag.name == "p": | |
| p_text = tag.getText().strip() | |
| if p_text.startswith("❚ Optional"): | |
| arg = tag.select("code")[0].getText().strip() | |
| function_args["opt_args"].add(arg) | |
| if tag.name == "code": | |
| code_text = tag.getText().strip() | |
| if code_text.startswith("function="): | |
| function = code_text.replace("function=", "") | |
| if function is not None: | |
| category[function] = function_args | |
| api_meta[section.select("h2")[0]["id"]] = category | |
| return api_meta | |
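| # Because `parse_api_meta` is an `lru_cache`-d classmethod, the documentation is scraped at | |
| # most once per runtime. A sketch of persisting the result across runs (hypothetical file | |
| # name; assumes the pickling helpers `vbt.save`/`vbt.load` are available): | |
| # | |
| # >>> api_meta = vbt.AVData.parse_api_meta() | |
| # >>> vbt.save(api_meta, "av_api_meta.pickle") | |
| # >>> data = vbt.AVData.pull("IBM", api_meta=vbt.load("av_api_meta.pickle")) | |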
| @classmethod | |
| def fetch_symbol( | |
| cls, | |
| symbol: str, | |
| use_parser: tp.Optional[bool] = None, | |
| apikey: tp.Optional[str] = None, | |
| api_meta: tp.Optional[dict] = None, | |
| category: tp.Union[None, str, AlphaVantageT, tp.Type[AlphaVantageT]] = None, | |
| function: tp.Union[None, str, tp.Callable] = None, | |
| timeframe: tp.Optional[str] = None, | |
| tz: tp.TimezoneLike = None, | |
| adjusted: tp.Optional[bool] = None, | |
| extended: tp.Optional[bool] = None, | |
| slice: tp.Optional[str] = None, | |
| series_type: tp.Optional[str] = None, | |
| time_period: tp.Optional[int] = None, | |
| outputsize: tp.Optional[str] = None, | |
| match_params: tp.Optional[bool] = None, | |
| params: tp.KwargsLike = None, | |
| read_csv_kwargs: tp.KwargsLike = None, | |
| silence_warnings: tp.Optional[bool] = None, | |
| ) -> tp.SymbolData: | |
| """Fetch a symbol from Alpha Vantage. | |
| If `use_parser` is False, or None while the `alpha_vantage` package is installed, uses the package. | |
| Otherwise, parses the API documentation and pulls the data directly. | |
| See https://www.alphavantage.co/documentation/ for API endpoints and their parameters. | |
| !!! note | |
| Supports the CSV format only. | |
| Args: | |
| symbol (str): Symbol. | |
| May combine symbol/from_currency and market/to_currency using an underscore. | |
| use_parser (bool): Whether to use the parser instead of the `alpha_vantage` package. | |
| apikey (str): API key. | |
| api_meta (dict): API meta. | |
| If None, will use `AVData.parse_api_meta` if `function` is not provided | |
| or `match_params` is True. | |
| category (str or AlphaVantage): API category of your choice. | |
| Used if `function` is not provided or `match_params` is True. | |
| Supported are: | |
| * `alpha_vantage.alphavantage.AlphaVantage` instance, class, or class name | |
| * "time-series-data" or "time-series" | |
| * "fundamental-data" or "fundamentals" | |
| * "foreign-exchange", "forex", or "fx" | |
| * "digital-currency", "cryptocurrencies", "cryptocurrency", or "crypto" | |
| * "commodities" | |
| * "economic-indicators" | |
| * "technical-indicators" or "indicators" | |
| function (str or callable): API function of your choice. | |
| If None, will try to resolve it based on other arguments, such as `timeframe`, | |
| `adjusted`, and `extended`. Required for technical indicators, economic indicators, | |
| and fundamental data. | |
| See the keys in sub-dictionaries returned by `AVData.parse_api_meta`. | |
| timeframe (str): Timeframe. | |
| Allows human-readable strings such as "15 minutes". | |
| For time series, forex, and crypto, looks for interval type in the function's name. | |
| Defaults to "60min" if extended, otherwise to "daily". | |
| tz (any): Timezone. | |
| See `vectorbtpro.utils.datetime_.to_timezone`. | |
| adjusted (bool): Whether to return time series adjusted by historical split and dividend events. | |
| extended (bool): Whether to return historical intraday time series for the trailing 2 years. | |
| slice (str): Slice of the trailing 2 years. | |
| series_type (str): The desired price type in the time series. | |
| time_period (int): Number of data points used to calculate each window value. | |
| outputsize (str): Output size. | |
| Supported are | |
| * "compact" that returns only the latest 100 data points | |
| * "full" that returns the full-length time series | |
| match_params (bool): Whether to match parameters with the ones required by the endpoint. | |
| Otherwise, uses only (resolved) `function`, `apikey`, `datatype="csv"`, and `params`. | |
| params (dict): Additional keyword arguments passed as key/value pairs in the URL. | |
| read_csv_kwargs (dict): Keyword arguments passed to `pd.read_csv`. | |
| silence_warnings (bool): Whether to silence all warnings. | |
| For defaults, see `custom.av` in `vectorbtpro._settings.data`. | |
| """ | |
| use_parser = cls.resolve_custom_setting(use_parser, "use_parser") | |
| apikey = cls.resolve_custom_setting(apikey, "apikey") | |
| api_meta = cls.resolve_custom_setting(api_meta, "api_meta") | |
| category = cls.resolve_custom_setting(category, "category") | |
| function = cls.resolve_custom_setting(function, "function") | |
| timeframe = cls.resolve_custom_setting(timeframe, "timeframe") | |
| tz = cls.resolve_custom_setting(tz, "tz") | |
| adjusted = cls.resolve_custom_setting(adjusted, "adjusted") | |
| extended = cls.resolve_custom_setting(extended, "extended") | |
| slice = cls.resolve_custom_setting(slice, "slice") | |
| series_type = cls.resolve_custom_setting(series_type, "series_type") | |
| time_period = cls.resolve_custom_setting(time_period, "time_period") | |
| outputsize = cls.resolve_custom_setting(outputsize, "outputsize") | |
| read_csv_kwargs = cls.resolve_custom_setting(read_csv_kwargs, "read_csv_kwargs", merge=True) | |
| match_params = cls.resolve_custom_setting(match_params, "match_params") | |
| params = cls.resolve_custom_setting(params, "params", merge=True) | |
| silence_warnings = cls.resolve_custom_setting(silence_warnings, "silence_warnings") | |
| if use_parser is None: | |
| if api_meta is None and check_installed("alpha_vantage"): | |
| use_parser = False | |
| else: | |
| use_parser = True | |
| if not use_parser: | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("alpha_vantage") | |
| if use_parser and api_meta is None and (function is None or match_params): | |
| if not silence_warnings and cls.parse_api_meta.cache_info().misses == 0: | |
| warn("Parsing API documentation...") | |
| try: | |
| api_meta = cls.parse_api_meta() | |
| except Exception as e: | |
| raise ValueError("Can't fetch/parse the API documentation. Specify function and disable match_params.") | |
| freq = timeframe | |
| interval = None | |
| interval_type = None | |
| if timeframe is not None: | |
| if not isinstance(timeframe, str): | |
| raise ValueError(f"Invalid timeframe: '{timeframe}'") | |
| split = dt.split_freq_str(timeframe) | |
| if split is None: | |
| raise ValueError(f"Invalid timeframe: '{timeframe}'") | |
| multiplier, unit = split | |
| if unit == "m": | |
| interval = str(multiplier) + "min" | |
| interval_type = "intraday" | |
| elif unit == "h": | |
| interval = str(60 * multiplier) + "min" | |
| interval_type = "intraday" | |
| elif unit == "D": | |
| interval = "daily" | |
| interval_type = "daily" | |
| elif unit == "W": | |
| interval = "weekly" | |
| interval_type = "weekly" | |
| elif unit == "M": | |
| interval = "monthly" | |
| interval_type = "monthly" | |
| elif unit == "Q": | |
| interval = "quarterly" | |
| interval_type = "quarterly" | |
| elif unit == "Y": | |
| interval = "annual" | |
| interval_type = "annual" | |
| if interval is None and multiplier > 1: | |
| raise ValueError("Multipliers are supported only for intraday timeframes") | |
| else: | |
| if extended: | |
| interval_type = "intraday" | |
| interval = "60min" | |
| else: | |
| interval_type = "daily" | |
| interval = "daily" | |
| if category is not None: | |
| if isinstance(category, str): | |
| if category.lower() in ("time-series-data", "time-series", "timeseries"): | |
| if use_parser: | |
| category = "time-series-data" | |
| else: | |
| from alpha_vantage.timeseries import TimeSeries | |
| category = TimeSeries | |
| elif category.lower() in ("fundamentals", "fundamental-data", "fundamentaldata"): | |
| if use_parser: | |
| category = "fundamentals" | |
| else: | |
| from alpha_vantage.fundamentaldata import FundamentalData | |
| category = FundamentalData | |
| elif category.lower() in ("fx", "forex", "foreign-exchange", "foreignexchange"): | |
| if use_parser: | |
| category = "fx" | |
| else: | |
| from alpha_vantage.foreignexchange import ForeignExchange | |
| category = ForeignExchange | |
| elif category.lower() in ("digital-currency", "cryptocurrencies", "cryptocurrency", "crypto"): | |
| if use_parser: | |
| category = "digital-currency" | |
| else: | |
| from alpha_vantage.cryptocurrencies import CryptoCurrencies | |
| category = CryptoCurrencies | |
| elif category.lower() in ("commodities",): | |
| if use_parser: | |
| category = "commodities" | |
| else: | |
| raise NotImplementedError(f"Category '{category}' not supported by alpha_vantage. Use parser.") | |
| elif category.lower() in ("economic-indicators",): | |
| if use_parser: | |
| category = "economic-indicators" | |
| else: | |
| raise NotImplementedError(f"Category '{category}' not supported by alpha_vantage. Use parser.") | |
| elif category.lower() in ("technical-indicators", "techindicators", "indicators"): | |
| if use_parser: | |
| category = "technical-indicators" | |
| else: | |
| from alpha_vantage.techindicators import TechIndicators | |
| category = TechIndicators | |
| else: | |
| raise ValueError(f"Invalid category: '{category}'") | |
| else: | |
| if use_parser: | |
| raise TypeError("Category must be a string") | |
| else: | |
| from alpha_vantage.alphavantage import AlphaVantage | |
| if isinstance(category, type): | |
| if not issubclass(category, AlphaVantage): | |
| raise TypeError("Category must be a subclass of AlphaVantage") | |
| elif not isinstance(category, AlphaVantage): | |
| raise TypeError("Category must be an instance of AlphaVantage") | |
| if use_parser: | |
| if function is None: | |
| if category is not None: | |
| if category in ("commodities", "economic-indicators"): | |
| function = symbol | |
| if function is None: | |
| if category is None: | |
| category = "time-series-data" | |
| if category in ("fundamentals", "technical-indicators"): | |
| raise ValueError("Function is required") | |
| adjusted_in_functions = False | |
| extended_in_functions = False | |
| matched_functions = [] | |
| for k in api_meta[category]: | |
| if interval_type is None or interval_type.upper() in k: | |
| if "ADJUSTED" in k: | |
| adjusted_in_functions = True | |
| if "EXTENDED" in k: | |
| extended_in_functions = True | |
| matched_functions.append(k) | |
| if adjusted_in_functions: | |
| matched_functions = [ | |
| k | |
| for k in matched_functions | |
| if (adjusted and "ADJUSTED" in k) or (not adjusted and "ADJUSTED" not in k) | |
| ] | |
| if extended_in_functions: | |
| matched_functions = [ | |
| k | |
| for k in matched_functions | |
| if (extended and "EXTENDED" in k) or (not extended and "EXTENDED" not in k) | |
| ] | |
| if len(matched_functions) == 0: | |
| raise ValueError("No functions satisfy the requirements") | |
| if len(matched_functions) > 1: | |
| raise ValueError("More than one function satisfies the requirements") | |
| function = matched_functions[0] | |
| if match_params: | |
| if function is not None and category is None: | |
| category = None | |
| for k, v in api_meta.items(): | |
| if function in v: | |
| category = k | |
| break | |
| if category is None: | |
| raise ValueError("Category is required") | |
| req_args = api_meta[category][function]["req_args"] | |
| opt_args = api_meta[category][function]["opt_args"] | |
| args = set(req_args) | set(opt_args) | |
| matched_params = dict() | |
| matched_params["function"] = function | |
| matched_params["datatype"] = "csv" | |
| matched_params["apikey"] = apikey | |
| if "symbol" in args and "market" in args: | |
| matched_params["symbol"] = symbol.split("_")[0] | |
| matched_params["market"] = symbol.split("_")[1] | |
| elif "from_" in args and "to_currency" in args: | |
| matched_params["from_currency"] = symbol.split("_")[0] | |
| matched_params["to_currency"] = symbol.split("_")[1] | |
| elif "from_currency" in args and "to_currency" in args: | |
| matched_params["from_currency"] = symbol.split("_")[0] | |
| matched_params["to_currency"] = symbol.split("_")[1] | |
| elif "symbol" in args: | |
| matched_params["symbol"] = symbol | |
| if "interval" in args: | |
| matched_params["interval"] = interval | |
| if "adjusted" in args: | |
| matched_params["adjusted"] = adjusted | |
| if "extended" in args: | |
| matched_params["extended"] = extended | |
| if "extended_hours" in args: | |
| matched_params["extended_hours"] = extended | |
| if "slice" in args: | |
| matched_params["slice"] = slice | |
| if "series_type" in args: | |
| matched_params["series_type"] = series_type | |
| if "time_period" in args: | |
| matched_params["time_period"] = time_period | |
| if "outputsize" in args: | |
| matched_params["outputsize"] = outputsize | |
| for k, v in params.items(): | |
| if k in args: | |
| matched_params[k] = v | |
| else: | |
| raise ValueError(f"Function '{function}' does not expect parameter '{k}'") | |
| for arg in req_args: | |
| if arg not in matched_params: | |
| raise ValueError(f"Function '{function}' requires parameter '{arg}'") | |
| else: | |
| matched_params = dict(params) | |
| matched_params["function"] = function | |
| matched_params["apikey"] = apikey | |
| matched_params["datatype"] = "csv" | |
| url = "https://www.alphavantage.co/query?" + urllib.parse.urlencode(matched_params) | |
| df = pd.read_csv(url, **read_csv_kwargs) | |
| else: | |
| from alpha_vantage.alphavantage import AlphaVantage | |
| from alpha_vantage.timeseries import TimeSeries | |
| from alpha_vantage.fundamentaldata import FundamentalData | |
| from alpha_vantage.foreignexchange import ForeignExchange | |
| from alpha_vantage.cryptocurrencies import CryptoCurrencies | |
| from alpha_vantage.techindicators import TechIndicators | |
| if isinstance(category, type) and issubclass(category, AlphaVantage): | |
| category = category(key=apikey, output_format="pandas") | |
| if function is None: | |
| if category is None: | |
| category = TimeSeries(key=apikey, output_format="pandas") | |
| if isinstance(category, (TechIndicators, FundamentalData)): | |
| raise ValueError("Function is required") | |
| adjusted_in_methods = False | |
| extended_in_methods = False | |
| matched_methods = [] | |
| for k in dir(category): | |
| if interval_type is None or interval_type in k: | |
| if "adjusted" in k: | |
| adjusted_in_methods = True | |
| if "extended" in k: | |
| extended_in_methods = True | |
| matched_methods.append(k) | |
| if adjusted_in_methods: | |
| matched_methods = [ | |
| k | |
| for k in matched_methods | |
| if (adjusted and "adjusted" in k) or (not adjusted and "adjusted" not in k) | |
| ] | |
| if extended_in_methods: | |
| matched_methods = [ | |
| k | |
| for k in matched_methods | |
| if (extended and "extended" in k) or (not extended and "extended" not in k) | |
| ] | |
| if len(matched_methods) == 0: | |
| raise ValueError("No methods satisfy the requirements") | |
| if len(matched_methods) > 1: | |
| raise ValueError("More than one method satisfies the requirements") | |
| function = matched_methods[0] | |
| if isinstance(function, str): | |
| function = function.lower() | |
| if not function.startswith("get_"): | |
| function = "get_" + function | |
| if category is not None: | |
| function = getattr(category, function) | |
| else: | |
| categories = [ | |
| TimeSeries, | |
| FundamentalData, | |
| ForeignExchange, | |
| CryptoCurrencies, | |
| TechIndicators, | |
| ] | |
| matched_methods = [] | |
| for category in categories: | |
| if function in dir(category): | |
| matched_methods.append(getattr(category, function)) | |
| if len(matched_methods) == 0: | |
| raise ValueError("No methods satisfy the requirements") | |
| if len(matched_methods) > 1: | |
| raise ValueError("More than one method satisfies the requirements") | |
| function = matched_methods[0] | |
| if match_params: | |
| args = set(get_func_arg_names(function)) | |
| matched_params = dict() | |
| if "symbol" in args and "market" in args: | |
| matched_params["symbol"] = symbol.split("_")[0] | |
| matched_params["market"] = symbol.split("_")[1] | |
| elif "from_" in args and "to_currency" in args: | |
| matched_params["from_currency"] = symbol.split("_")[0] | |
| matched_params["to_currency"] = symbol.split("_")[1] | |
| elif "from_currency" in args and "to_currency" in args: | |
| matched_params["from_currency"] = symbol.split("_")[0] | |
| matched_params["to_currency"] = symbol.split("_")[1] | |
| elif "symbol" in args: | |
| matched_params["symbol"] = symbol | |
| if "interval" in args: | |
| matched_params["interval"] = interval | |
| if "adjusted" in args: | |
| matched_params["adjusted"] = adjusted | |
| if "extended" in args: | |
| matched_params["extended"] = extended | |
| if "extended_hours" in args: | |
| matched_params["extended_hours"] = extended | |
| if "slice" in args: | |
| matched_params["slice"] = slice | |
| if "series_type" in args: | |
| matched_params["series_type"] = series_type | |
| if "time_period" in args: | |
| matched_params["time_period"] = time_period | |
| if "outputsize" in args: | |
| matched_params["outputsize"] = outputsize | |
| else: | |
| matched_params = dict(params) | |
| df, df_metadata = function(**matched_params) | |
| for k, v in df_metadata.items(): | |
| if "Time Zone" in k: | |
| if tz is None: | |
| if v.endswith(" Time"): | |
| v = v[: -len(" Time")] | |
| tz = v | |
| df.index.name = None | |
| new_columns = [] | |
| for c in df.columns: | |
| new_c = re.sub(r"^\d+\w*\.\s*", "", c) | |
| new_c = new_c[0].title() + new_c[1:] | |
| if new_c.endswith(" (USD)"): | |
| new_c = new_c[: -len(" (USD)")] | |
| new_columns.append(new_c) | |
| df = df.rename(columns=dict(zip(df.columns, new_columns))) | |
| df = df.loc[:, ~df.columns.duplicated()] | |
| for c in df.columns: | |
| if df[c].dtype == "O": | |
| df[c] = df[c].replace({".": np.nan}) | |
| df = df.apply(pd.to_numeric, errors="ignore") | |
| if not df.empty and df.index[0] > df.index[1]: | |
| df = df.iloc[::-1] | |
| if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None and tz is not None: | |
| df = df.tz_localize(tz) | |
| return df, dict(tz=tz, freq=freq) | |
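| # Sketch: bypassing both function resolution and parameter matching by spelling the endpoint | |
| # out explicitly (hypothetical endpoint and parameters; see the Alpha Vantage docs for valid | |
| # values). With `match_params=False`, everything except `function`, `apikey`, and `datatype` | |
| # must be supplied via `params`: | |
| # | |
| # >>> data = vbt.AVData.pull( | |
| # ...     "IBM", | |
| # ...     use_parser=True, | |
| # ...     match_params=False, | |
| # ...     function="TIME_SERIES_DAILY", | |
| # ...     params=dict(symbol="IBM", outputsize="compact"), | |
| # ... ) | |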
| def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData: | |
| fetch_kwargs = self.select_fetch_kwargs(symbol) | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| return self.fetch_symbol(symbol, **kwargs) | |
| </file> | |
| <file path="data/custom/bento.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `BentoData`.""" | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.remote import RemoteData | |
| from vectorbtpro.utils import datetime_ as dt | |
| from vectorbtpro.utils.config import merge_dicts | |
| from vectorbtpro.utils.parsing import get_func_arg_names | |
| try: | |
| if not tp.TYPE_CHECKING: | |
| raise ImportError | |
| from databento import Historical as HistoricalT | |
| except ImportError: | |
| HistoricalT = "Historical" | |
| __all__ = [ | |
| "BentoData", | |
| ] | |
| class BentoData(RemoteData): | |
| """Data class for fetching from Databento. | |
| See https://github.com/databento/databento-python for API. | |
| See `BentoData.fetch_symbol` for arguments. | |
| Usage: | |
| * Set up the API key globally (optional): | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> vbt.BentoData.set_custom_settings( | |
| ... client_config=dict( | |
| ... key="YOUR_KEY" | |
| ... ) | |
| ... ) | |
| ``` | |
| * Pull data: | |
| ```pycon | |
| >>> data = vbt.BentoData.pull( | |
| ... "AAPL", | |
| ... dataset="XNAS.ITCH" | |
| ... ) | |
| ``` | |
| ```pycon | |
| >>> data = vbt.BentoData.pull( | |
| ... "AAPL", | |
| ... dataset="XNAS.ITCH", | |
| ... timeframe="hourly", | |
| ... start="one week ago" | |
| ... ) | |
| ``` | |
| ```pycon | |
| >>> data = vbt.BentoData.pull( | |
| ... "ES.FUT", | |
| ... dataset="GLBX.MDP3", | |
| ... stype_in="parent", | |
| ... schema="mbo", | |
| ... start="2022-06-10T14:30", | |
| ... end="2022-06-11", | |
| ... limit=1000 | |
| ... ) | |
| ``` | |
| """ | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.bento") | |
| @classmethod | |
| def resolve_client(cls, client: tp.Optional[HistoricalT] = None, **client_config) -> HistoricalT: | |
| """Resolve the client. | |
| If provided, must be of the type `databento.historical.client.Historical`. | |
| Otherwise, will be created using `client_config`.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("databento") | |
| from databento import Historical | |
| client = cls.resolve_custom_setting(client, "client") | |
| if client_config is None: | |
| client_config = {} | |
| has_client_config = len(client_config) > 0 | |
| client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True) | |
| if client is None: | |
| client = Historical(**client_config) | |
| elif has_client_config: | |
| raise ValueError("Cannot apply client_config to already initialized client") | |
| return client | |
| @classmethod | |
| def get_cost(cls, symbols: tp.MaybeSymbols, **kwargs) -> float: | |
| """Get the cost of calling `BentoData.fetch_symbol` on one or more symbols.""" | |
| if isinstance(symbols, str): | |
| symbols = [symbols] | |
| costs = [] | |
| for symbol in symbols: | |
| client, params = cls.fetch_symbol(symbol, **kwargs, return_params=True) | |
| cost_arg_names = get_func_arg_names(client.metadata.get_cost) | |
| for k in list(params.keys()): | |
| if k not in cost_arg_names: | |
| del params[k] | |
| costs.append(client.metadata.get_cost(**params, mode="historical")) | |
| return sum(costs) | |
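| # Sketch: estimating the cost of a pull before fetching (hypothetical symbol, dataset, and | |
| # dates; the same keyword arguments as for `fetch_symbol` are accepted): | |
| # | |
| # >>> vbt.BentoData.get_cost( | |
| # ...     "AAPL", | |
| # ...     dataset="XNAS.ITCH", | |
| # ...     timeframe="1 minute", | |
| # ...     start="2024-01-01", | |
| # ...     end="2024-02-01", | |
| # ... ) | |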
| @classmethod | |
| def fetch_symbol( | |
| cls, | |
| symbol: str, | |
| client: tp.Optional[HistoricalT] = None, | |
| client_config: tp.KwargsLike = None, | |
| start: tp.Optional[tp.DatetimeLike] = None, | |
| end: tp.Optional[tp.DatetimeLike] = None, | |
| resolve_dates: tp.Optional[bool] = None, | |
| timeframe: tp.Optional[str] = None, | |
| tz: tp.TimezoneLike = None, | |
| dataset: tp.Optional[str] = None, | |
| schema: tp.Optional[str] = None, | |
| return_params: bool = False, | |
| df_kwargs: tp.KwargsLike = None, | |
| **params, | |
| ) -> tp.Union[float, tp.SymbolData]: | |
| """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Databento. | |
| Args: | |
| symbol (str): Symbol. | |
| Symbol can be in the `DATASET:SYMBOL` format if `dataset` is None. | |
| client (databento.historical.client.Historical): Client. | |
| See `BentoData.resolve_client`. | |
| client_config (dict): Client config. | |
| See `BentoData.resolve_client`. | |
| start (any): Start datetime. | |
| See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. | |
| end (any): End datetime. | |
| See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. | |
| resolve_dates (bool): Whether to resolve `start` and `end`, or pass them as they are. | |
| timeframe (str): Timeframe to create `schema` from. | |
| Allows human-readable strings such as "1 minute". | |
| If both `timeframe` and a `schema` other than "ohlcv" are provided, an error is raised. | |
| tz (any): Timezone. | |
| See `vectorbtpro.utils.datetime_.to_timezone`. | |
| dataset (str): See `databento.historical.client.Historical.get_range`. | |
| schema (str): See `databento.historical.client.Historical.get_range`. | |
| return_params (bool): Whether to return the client and (final) parameters instead of data. | |
| Used by `BentoData.get_cost`. | |
| df_kwargs (dict): Keyword arguments passed to `databento.common.dbnstore.DBNStore.to_df`. | |
| **params: Keyword arguments passed to `databento.historical.client.Historical.get_range`. | |
| For defaults, see `custom.bento` in `vectorbtpro._settings.data`. | |
| """ | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("databento") | |
| if client_config is None: | |
| client_config = {} | |
| client = cls.resolve_client(client=client, **client_config) | |
| start = cls.resolve_custom_setting(start, "start") | |
| end = cls.resolve_custom_setting(end, "end") | |
| resolve_dates = cls.resolve_custom_setting(resolve_dates, "resolve_dates") | |
| timeframe = cls.resolve_custom_setting(timeframe, "timeframe") | |
| tz = cls.resolve_custom_setting(tz, "tz") | |
| dataset = cls.resolve_custom_setting(dataset, "dataset") | |
| schema = cls.resolve_custom_setting(schema, "schema") | |
| params = cls.resolve_custom_setting(params, "params", merge=True) | |
| df_kwargs = cls.resolve_custom_setting(df_kwargs, "df_kwargs", merge=True) | |
| if dataset is None: | |
| if ":" in symbol: | |
| dataset, symbol = symbol.split(":") | |
| if timeframe is None and schema is None: | |
| schema = "ohlcv-1d" | |
| freq = "1d" | |
| elif timeframe is not None: | |
| freq = timeframe | |
| split = dt.split_freq_str(timeframe) | |
| if split is not None: | |
| multiplier, unit = split | |
| timeframe = str(multiplier) + unit | |
| if schema is None or schema.lower() == "ohlcv": | |
| schema = f"ohlcv-{timeframe}" | |
| else: | |
| raise ValueError("Timeframe cannot be used together with schema") | |
| else: | |
| if schema.startswith("ohlcv-"): | |
| freq = schema[len("ohlcv-") :] | |
| else: | |
| freq = None | |
| if resolve_dates: | |
| dataset_range = client.metadata.get_dataset_range(dataset) | |
| if "start_date" in dataset_range: | |
| start_date = dt.to_tzaware_timestamp(dataset_range["start_date"], naive_tz="utc", tz="utc") | |
| else: | |
| start_date = dt.to_tzaware_timestamp(dataset_range["start"], naive_tz="utc", tz="utc") | |
| if "end_date" in dataset_range: | |
| end_date = dt.to_tzaware_timestamp(dataset_range["end_date"], naive_tz="utc", tz="utc") | |
| else: | |
| end_date = dt.to_tzaware_timestamp(dataset_range["end"], naive_tz="utc", tz="utc") | |
| if start is not None: | |
| start = dt.to_tzaware_timestamp(start, naive_tz=tz, tz="utc") | |
| if start < start_date: | |
| start = start_date | |
| else: | |
| start = start_date | |
| if end is not None: | |
| end = dt.to_tzaware_timestamp(end, naive_tz=tz, tz="utc") | |
| if end > end_date: | |
| end = end_date | |
| else: | |
| end = end_date | |
| if start.floor("d") == start: | |
| start = start.date().isoformat() | |
| else: | |
| start = start.isoformat() | |
| if end.floor("d") == end: | |
| end = end.date().isoformat() | |
| else: | |
| end = end.isoformat() | |
| params = merge_dicts( | |
| dict( | |
| dataset=dataset, | |
| start=start, | |
| end=end, | |
| symbols=symbol, | |
| schema=schema, | |
| ), | |
| params, | |
| ) | |
| if return_params: | |
| return client, params | |
| df = client.timeseries.get_range(**params).to_df(**df_kwargs) | |
| return df, dict(tz=tz, freq=freq) | |
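| # Sketch: when `dataset` is None, it can be embedded in the symbol using the | |
| # `DATASET:SYMBOL` format (hypothetical values): | |
| # | |
| # >>> data = vbt.BentoData.pull("XNAS.ITCH:AAPL", timeframe="1 day", start="2024-01-01") | |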
| def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData: | |
| fetch_kwargs = self.select_fetch_kwargs(symbol) | |
| fetch_kwargs["start"] = self.select_last_index(symbol) | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| return self.fetch_symbol(symbol, **kwargs) | |
| </file> | |
| <file path="data/custom/binance.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `BinanceData`.""" | |
| import time | |
| import traceback | |
| from functools import partial | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.remote import RemoteData | |
| from vectorbtpro.generic import nb as generic_nb | |
| from vectorbtpro.utils import datetime_ as dt | |
| from vectorbtpro.utils.config import merge_dicts, Config, HybridConfig | |
| from vectorbtpro.utils.enum_ import map_enum_fields | |
| from vectorbtpro.utils.pbar import ProgressBar | |
| from vectorbtpro.utils.warnings_ import warn | |
| try: | |
| if not tp.TYPE_CHECKING: | |
| raise ImportError | |
| from binance.client import Client as BinanceClientT | |
| except ImportError: | |
| BinanceClientT = "BinanceClient" | |
| __all__ = [ | |
| "BinanceData", | |
| ] | |
| __pdoc__ = {} | |
| BinanceDataT = tp.TypeVar("BinanceDataT", bound="BinanceData") | |
| class BinanceData(RemoteData): | |
| """Data class for fetching from Binance. | |
| See https://github.com/sammchardy/python-binance for API. | |
| See `BinanceData.fetch_symbol` for arguments. | |
| !!! note | |
| If you are using an exchange from the US, Japan, or another TLD, make sure to pass the | |
| appropriate `tld` (e.g., `tld="us"`) in `client_config` when creating the client. | |
| Usage: | |
| * Set up the API key globally (optional): | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> vbt.BinanceData.set_custom_settings( | |
| ... client_config=dict( | |
| ... api_key="YOUR_KEY", | |
| ... api_secret="YOUR_SECRET" | |
| ... ) | |
| ... ) | |
| ``` | |
| * Pull data: | |
| ```pycon | |
| >>> data = vbt.BinanceData.pull( | |
| ... "BTCUSDT", | |
| ... start="2020-01-01", | |
| ... end="2021-01-01", | |
| ... timeframe="1 day" | |
| ... ) | |
| ``` | |
| """ | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.binance") | |
| _feature_config: tp.ClassVar[Config] = HybridConfig( | |
| { | |
| "Quote volume": dict( | |
| resample_func=lambda self, obj, resampler: obj.vbt.resample_apply( | |
| resampler, | |
| generic_nb.sum_reduce_nb, | |
| ) | |
| ), | |
| "Taker base volume": dict( | |
| resample_func=lambda self, obj, resampler: obj.vbt.resample_apply( | |
| resampler, | |
| generic_nb.sum_reduce_nb, | |
| ) | |
| ), | |
| "Taker quote volume": dict( | |
| resample_func=lambda self, obj, resampler: obj.vbt.resample_apply( | |
| resampler, | |
| generic_nb.sum_reduce_nb, | |
| ) | |
| ), | |
| } | |
| ) | |
| @property | |
| def feature_config(self) -> Config: | |
| return self._feature_config | |
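| # The overrides above make the extra volume features aggregate by sum when the data is | |
| # resampled, e.g. (a sketch, assuming `data` is a `BinanceData` instance and that | |
| # `Data.resample` is available): | |
| # | |
| # >>> data.resample("1W")  # "Quote volume" and the taker volumes are summed per week | |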
| @classmethod | |
| def resolve_client(cls, client: tp.Optional[BinanceClientT] = None, **client_config) -> BinanceClientT: | |
| """Resolve the client. | |
| If provided, must be of the type `binance.client.Client`. | |
| Otherwise, will be created using `client_config`.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("binance") | |
| from binance.client import Client | |
| client = cls.resolve_custom_setting(client, "client") | |
| if client_config is None: | |
| client_config = {} | |
| has_client_config = len(client_config) > 0 | |
| client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True) | |
| if client is None: | |
| client = Client(**client_config) | |
| elif has_client_config: | |
| raise ValueError("Cannot apply client_config to already initialized client") | |
| return client | |
| @classmethod | |
| def list_symbols( | |
| cls, | |
| pattern: tp.Optional[str] = None, | |
| use_regex: bool = False, | |
| sort: bool = True, | |
| client: tp.Optional[BinanceClientT] = None, | |
| client_config: tp.KwargsLike = None, | |
| ) -> tp.List[str]: | |
| """List all symbols. | |
| Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`.""" | |
| if client_config is None: | |
| client_config = {} | |
| client = cls.resolve_client(client=client, **client_config) | |
| all_symbols = [] | |
| for dct in client.get_exchange_info()["symbols"]: | |
| symbol = dct["symbol"] | |
| if pattern is not None: | |
| if not cls.key_match(symbol, pattern, use_regex=use_regex): | |
| continue | |
| all_symbols.append(symbol) | |
| if sort: | |
| return sorted(dict.fromkeys(all_symbols)) | |
| return list(dict.fromkeys(all_symbols)) | |
| @classmethod | |
| def fetch_symbol( | |
| cls, | |
| symbol: str, | |
| client: tp.Optional[BinanceClientT] = None, | |
| client_config: tp.KwargsLike = None, | |
| start: tp.Optional[tp.DatetimeLike] = None, | |
| end: tp.Optional[tp.DatetimeLike] = None, | |
| timeframe: tp.Optional[str] = None, | |
| tz: tp.TimezoneLike = None, | |
| klines_type: tp.Union[None, int, str] = None, | |
| limit: tp.Optional[int] = None, | |
| delay: tp.Optional[float] = None, | |
| show_progress: tp.Optional[bool] = None, | |
| pbar_kwargs: tp.KwargsLike = None, | |
| silence_warnings: tp.Optional[bool] = None, | |
| **get_klines_kwargs, | |
| ) -> tp.SymbolData: | |
| """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Binance. | |
| Args: | |
| symbol (str): Symbol. | |
| client (binance.client.Client): Client. | |
| See `BinanceData.resolve_client`. | |
| client_config (dict): Client config. | |
| See `BinanceData.resolve_client`. | |
| start (any): Start datetime. | |
| See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. | |
| end (any): End datetime. | |
| See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. | |
| timeframe (str): Timeframe. | |
| Allows human-readable strings such as "15 minutes". | |
| tz (any): Timezone. | |
| See `vectorbtpro.utils.datetime_.to_timezone`. | |
| klines_type (int or str): Kline type. | |
| See `binance.enums.HistoricalKlinesType`. Supports strings. | |
| limit (int): The maximum number of returned items. | |
| delay (float): Time to sleep after each request (in seconds). | |
| show_progress (bool): Whether to show the progress bar. | |
| pbar_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.pbar.ProgressBar`. | |
| silence_warnings (bool): Whether to silence all warnings. | |
| **get_klines_kwargs: Keyword arguments passed to `binance.client.Client.get_klines`. | |
| For defaults, see `custom.binance` in `vectorbtpro._settings.data`. | |
| """ | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("binance") | |
| from binance.enums import HistoricalKlinesType | |
| if client_config is None: | |
| client_config = {} | |
| client = cls.resolve_client(client=client, **client_config) | |
| start = cls.resolve_custom_setting(start, "start") | |
| end = cls.resolve_custom_setting(end, "end") | |
| timeframe = cls.resolve_custom_setting(timeframe, "timeframe") | |
| tz = cls.resolve_custom_setting(tz, "tz") | |
| klines_type = cls.resolve_custom_setting(klines_type, "klines_type") | |
| if isinstance(klines_type, str): | |
| klines_type = map_enum_fields(klines_type, HistoricalKlinesType) | |
| if isinstance(klines_type, int): | |
| klines_type = {i.value: i for i in HistoricalKlinesType}[klines_type] | |
| limit = cls.resolve_custom_setting(limit, "limit") | |
| delay = cls.resolve_custom_setting(delay, "delay") | |
| show_progress = cls.resolve_custom_setting(show_progress, "show_progress") | |
| pbar_kwargs = cls.resolve_custom_setting(pbar_kwargs, "pbar_kwargs", merge=True) | |
| if "bar_id" not in pbar_kwargs: | |
| pbar_kwargs["bar_id"] = "binance" | |
| silence_warnings = cls.resolve_custom_setting(silence_warnings, "silence_warnings") | |
| get_klines_kwargs = cls.resolve_custom_setting(get_klines_kwargs, "get_klines_kwargs", merge=True) | |
| # Prepare parameters | |
| freq = timeframe | |
| split = dt.split_freq_str(timeframe) | |
| if split is not None: | |
| multiplier, unit = split | |
| if unit == "D": | |
| unit = "d" | |
| elif unit == "W": | |
| unit = "w" | |
| timeframe = str(multiplier) + unit | |
| if start is not None: | |
| start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")) | |
| first_valid_ts = client._get_earliest_valid_timestamp(symbol, timeframe, klines_type) | |
| start_ts = max(start_ts, first_valid_ts) | |
| else: | |
| start_ts = None | |
| prev_end_ts = None | |
| if end is not None: | |
| end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc")) | |
| else: | |
| end_ts = None | |
| def _ts_to_str(ts: tp.Optional[int]) -> str: | |
| if ts is None: | |
| return "?" | |
| return dt.readable_datetime(pd.Timestamp(ts, unit="ms", tz="utc"), freq=timeframe) | |
| def _filter_func(d: tp.Sequence, _prev_end_ts: tp.Optional[int] = None) -> bool: | |
| if start_ts is not None: | |
| if d[0] < start_ts: | |
| return False | |
| if _prev_end_ts is not None: | |
| if d[0] <= _prev_end_ts: | |
| return False | |
| if end_ts is not None: | |
| if d[0] >= end_ts: | |
| return False | |
| return True | |
| # Iteratively collect the data | |
| data = [] | |
| try: | |
| with ProgressBar(show_progress=show_progress, **pbar_kwargs) as pbar: | |
| pbar.set_description("{} → ?".format(_ts_to_str(start_ts if prev_end_ts is None else prev_end_ts))) | |
| while True: | |
| # Fetch the klines for the next timeframe | |
| next_data = client._klines( | |
| symbol=symbol, | |
| interval=timeframe, | |
| limit=limit, | |
| startTime=start_ts if prev_end_ts is None else prev_end_ts, | |
| endTime=end_ts, | |
| klines_type=klines_type, | |
| **get_klines_kwargs, | |
| ) | |
| next_data = list(filter(partial(_filter_func, _prev_end_ts=prev_end_ts), next_data)) | |
| # Update the timestamps and the progress bar | |
| if not len(next_data): | |
| break | |
| data += next_data | |
| if start_ts is None: | |
| start_ts = next_data[0][0] | |
| pbar.set_description("{} → {}".format(_ts_to_str(start_ts), _ts_to_str(next_data[-1][0]))) | |
| pbar.update() | |
| prev_end_ts = next_data[-1][0] | |
| if end_ts is not None and prev_end_ts >= end_ts: | |
| break | |
| if delay is not None: | |
| time.sleep(delay) # be kind to api | |
| except Exception as e: | |
| if not silence_warnings: | |
| warn(traceback.format_exc()) | |
| warn( | |
| f"Symbol '{str(symbol)}' raised an exception. Returning incomplete data. " | |
| "Use update() method to fetch missing data." | |
| ) | |
| # Convert data to a DataFrame | |
| df = pd.DataFrame( | |
| data, | |
| columns=[ | |
| "Open time", | |
| "Open", | |
| "High", | |
| "Low", | |
| "Close", | |
| "Volume", | |
| "Close time", | |
| "Quote volume", | |
| "Trade count", | |
| "Taker base volume", | |
| "Taker quote volume", | |
| "Ignore", | |
| ], | |
| ) | |
| df.index = pd.to_datetime(df["Open time"], unit="ms", utc=True) | |
| df["Open"] = df["Open"].astype(float) | |
| df["High"] = df["High"].astype(float) | |
| df["Low"] = df["Low"].astype(float) | |
| df["Close"] = df["Close"].astype(float) | |
| df["Volume"] = df["Volume"].astype(float) | |
| df["Quote volume"] = df["Quote volume"].astype(float) | |
| df["Trade count"] = df["Trade count"].astype(int, errors="ignore") | |
| df["Taker base volume"] = df["Taker base volume"].astype(float) | |
| df["Taker quote volume"] = df["Taker quote volume"].astype(float) | |
| del df["Open time"] | |
| del df["Close time"] | |
| del df["Ignore"] | |
| return df, dict(tz=tz, freq=freq) | |
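| # Sketch of a pull exercising a few of the arguments above (hypothetical dates; string kline | |
| # types are mapped through `binance.enums.HistoricalKlinesType`, e.g. "spot" or "futures"): | |
| # | |
| # >>> data = vbt.BinanceData.pull( | |
| # ...     "BTCUSDT", | |
| # ...     start="2021-01-01", | |
| # ...     end="2021-06-01", | |
| # ...     timeframe="1 hour", | |
| # ...     klines_type="spot", | |
| # ...     delay=0.5, | |
| # ... ) | |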
| def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData: | |
| fetch_kwargs = self.select_fetch_kwargs(symbol) | |
| fetch_kwargs["start"] = self.select_last_index(symbol) | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| return self.fetch_symbol(symbol, **kwargs) | |
| BinanceData.override_feature_config_doc(__pdoc__) | |
| </file> | |
| <file path="data/custom/ccxt.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `CCXTData`.""" | |
| import time | |
| import traceback | |
| from functools import wraps, partial | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.remote import RemoteData | |
| from vectorbtpro.utils import datetime_ as dt | |
| from vectorbtpro.utils.config import merge_dicts | |
| from vectorbtpro.utils.pbar import ProgressBar | |
| from vectorbtpro.utils.warnings_ import warn | |
| try: | |
| if not tp.TYPE_CHECKING: | |
| raise ImportError | |
| from ccxt.base.exchange import Exchange as CCXTExchangeT | |
| except ImportError: | |
| CCXTExchangeT = "CCXTExchange" | |
| __all__ = [ | |
| "CCXTData", | |
| ] | |
| __pdoc__ = {} | |
| class CCXTData(RemoteData): | |
| """Data class for fetching using CCXT. | |
| See https://github.com/ccxt/ccxt for API. | |
| See `CCXTData.fetch_symbol` for arguments. | |
| Usage: | |
| * Set up the API key globally (optional): | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> vbt.CCXTData.set_exchange_settings( | |
| ... exchange_name="binance", | |
| ... populate_=True, | |
| ... exchange_config=dict( | |
| ... apiKey="YOUR_KEY", | |
| ... secret="YOUR_SECRET" | |
| ... ) | |
| ... ) | |
| ``` | |
| * Pull data: | |
| ```pycon | |
| >>> data = vbt.CCXTData.pull( | |
| ... "BTCUSDT", | |
| ... exchange="binance", | |
| ... start="2020-01-01", | |
| ... end="2021-01-01", | |
| ... timeframe="1 day" | |
| ... ) | |
| ``` | |
| """ | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.ccxt") | |
| @classmethod | |
| def get_exchange_settings(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> dict: | |
| """`CCXTData.get_custom_settings` with `sub_path=exchange_name`.""" | |
| if exchange_name is not None: | |
| sub_path = "exchanges." + exchange_name | |
| else: | |
| sub_path = None | |
| return cls.get_custom_settings(*args, sub_path=sub_path, **kwargs) | |
| @classmethod | |
| def has_exchange_settings(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> bool: | |
| """`CCXTData.has_custom_settings` with `sub_path=exchange_name`.""" | |
| if exchange_name is not None: | |
| sub_path = "exchanges." + exchange_name | |
| else: | |
| sub_path = None | |
| return cls.has_custom_settings(*args, sub_path=sub_path, **kwargs) | |
| @classmethod | |
| def get_exchange_setting(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> tp.Any: | |
| """`CCXTData.get_custom_setting` with `sub_path=exchange_name`.""" | |
| if exchange_name is not None: | |
| sub_path = "exchanges." + exchange_name | |
| else: | |
| sub_path = None | |
| return cls.get_custom_setting(*args, sub_path=sub_path, **kwargs) | |
| @classmethod | |
| def has_exchange_setting(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> bool: | |
| """`CCXTData.has_custom_setting` with `sub_path=exchange_name`.""" | |
| if exchange_name is not None: | |
| sub_path = "exchanges." + exchange_name | |
| else: | |
| sub_path = None | |
| return cls.has_custom_setting(*args, sub_path=sub_path, **kwargs) | |
| @classmethod | |
| def resolve_exchange_setting(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> tp.Any: | |
| """`CCXTData.resolve_custom_setting` with `sub_path=exchange_name`.""" | |
| if exchange_name is not None: | |
| sub_path = "exchanges." + exchange_name | |
| else: | |
| sub_path = None | |
| return cls.resolve_custom_setting(*args, sub_path=sub_path, **kwargs) | |
| @classmethod | |
| def set_exchange_settings(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> None: | |
| """`CCXTData.set_custom_settings` with `sub_path=exchange_name`.""" | |
| if exchange_name is not None: | |
| sub_path = "exchanges." + exchange_name | |
| else: | |
| sub_path = None | |
| cls.set_custom_settings(*args, sub_path=sub_path, **kwargs) | |
| @classmethod | |
| def list_symbols( | |
| cls, | |
| pattern: tp.Optional[str] = None, | |
| use_regex: bool = False, | |
| sort: bool = True, | |
| exchange: tp.Union[None, str, CCXTExchangeT] = None, | |
| exchange_config: tp.Optional[tp.KwargsLike] = None, | |
| ) -> tp.List[str]: | |
| """List all symbols. | |
| Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`.""" | |
| if exchange_config is None: | |
| exchange_config = {} | |
| exchange = cls.resolve_exchange(exchange=exchange, **exchange_config) | |
| all_symbols = [] | |
| for symbol in exchange.load_markets(): | |
| if pattern is not None: | |
| if not cls.key_match(symbol, pattern, use_regex=use_regex): | |
| continue | |
| all_symbols.append(symbol) | |
| if sort: | |
| return sorted(dict.fromkeys(all_symbols)) | |
| return list(dict.fromkeys(all_symbols)) | |
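| # Usage sketch (illustrative, not part of the original module): list the symbols | |
| # of an exchange that match a glob-style pattern. The exchange name and pattern | |
| # below are assumptions; see the docstring above for `pattern` vs `use_regex`. | |
| # | |
| # >>> vbt.CCXTData.list_symbols("BTC/*", exchange="binance") | |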
| @classmethod | |
| def resolve_exchange( | |
| cls, | |
| exchange: tp.Union[None, str, CCXTExchangeT] = None, | |
| **exchange_config, | |
| ) -> CCXTExchangeT: | |
| """Resolve the exchange. | |
| If provided, can be an exchange identifier (string) or an instance of `ccxt.base.exchange.Exchange`. | |
| If None or a string, the exchange will be instantiated using `exchange_config`; None falls back to the global setting or "binance".""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("ccxt") | |
| import ccxt | |
| exchange = cls.resolve_exchange_setting(exchange, "exchange") | |
| if exchange is None: | |
| exchange = "binance" | |
| if isinstance(exchange, str): | |
| exchange = exchange.lower() | |
| exchange_name = exchange | |
| elif isinstance(exchange, ccxt.Exchange): | |
| exchange_name = type(exchange).__name__ | |
| else: | |
| raise ValueError(f"Unknown exchange of type {type(exchange)}") | |
| if exchange_config is None: | |
| exchange_config = {} | |
| has_exchange_config = len(exchange_config) > 0 | |
| exchange_config = cls.resolve_exchange_setting( | |
| exchange_config, "exchange_config", merge=True, exchange_name=exchange_name | |
| ) | |
| if isinstance(exchange, str): | |
| if not hasattr(ccxt, exchange): | |
| raise ValueError(f"Exchange '{exchange}' not found in CCXT") | |
| exchange = getattr(ccxt, exchange)(exchange_config) | |
| else: | |
| if has_exchange_config: | |
| raise ValueError("Cannot apply config after instantiation of the exchange") | |
| return exchange | |
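| # Usage sketch (illustrative, not part of the original module): resolve an | |
| # exchange object by name; the credentials below are placeholders. The result | |
| # is a regular `ccxt` exchange instance that can be reused across pulls. | |
| # | |
| # >>> exchange = vbt.CCXTData.resolve_exchange( | |
| # ...     exchange="binance", | |
| # ...     apiKey="YOUR_KEY", | |
| # ...     secret="YOUR_SECRET" | |
| # ... ) | |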
| @staticmethod | |
| def _find_earliest_date( | |
| fetch_func: tp.Callable, | |
| start: tp.DatetimeLike = 0, | |
| end: tp.DatetimeLike = "now", | |
| tz: tp.TimezoneLike = None, | |
| for_internal_use: bool = False, | |
| ) -> tp.Optional[pd.Timestamp]: | |
| """Find the earliest date using binary search.""" | |
| if start is not None: | |
| start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")) | |
| fetched_data = fetch_func(start_ts, 1) | |
| if for_internal_use and len(fetched_data) > 0: | |
| return pd.Timestamp(start_ts, unit="ms", tz="utc") | |
| else: | |
| fetched_data = [] | |
| if len(fetched_data) == 0 and start != 0: | |
| fetched_data = fetch_func(0, 1) | |
| if for_internal_use and len(fetched_data) > 0: | |
| return pd.Timestamp(0, unit="ms", tz="utc") | |
| if len(fetched_data) == 0: | |
| if start is not None: | |
| start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")) | |
| else: | |
| start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(0, naive_tz=tz, tz="utc")) | |
| start_ts = start_ts - start_ts % 86400000  # align to the day boundary (86400000 ms = 1 day) | |
| if end is not None: | |
| end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc")) | |
| else: | |
| end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime("now", naive_tz=tz, tz="utc")) | |
| end_ts = end_ts - end_ts % 86400000 + 86400000 | |
| start_time = start_ts | |
| end_time = end_ts | |
| while True: | |
| mid_time = (start_time + end_time) // 2 | |
| mid_time = mid_time - mid_time % 86400000 | |
| if mid_time == start_time: | |
| break | |
| _fetched_data = fetch_func(mid_time, 1) | |
| if len(_fetched_data) == 0: | |
| start_time = mid_time | |
| else: | |
| end_time = mid_time | |
| fetched_data = _fetched_data | |
| if len(fetched_data) > 0: | |
| return pd.Timestamp(fetched_data[0][0], unit="ms", tz="utc") | |
| return None | |
| @classmethod | |
| def find_earliest_date(cls, symbol: str, for_internal_use: bool = False, **kwargs) -> tp.Optional[pd.Timestamp]: | |
| """Find the earliest date using binary search. | |
| See `CCXTData.fetch_symbol` for arguments.""" | |
| return cls._find_earliest_date( | |
| **cls.fetch_symbol(symbol, return_fetch_method=True, **kwargs), | |
| for_internal_use=for_internal_use, | |
| ) | |
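| # Usage sketch (illustrative, not part of the original module): locate the first | |
| # available candle of a symbol via the binary search above. Exchange and | |
| # timeframe are assumed values. | |
| # | |
| # >>> vbt.CCXTData.find_earliest_date("BTC/USDT", exchange="binance", timeframe="1 day") | |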
| @classmethod | |
| def fetch_symbol( | |
| cls, | |
| symbol: str, | |
| exchange: tp.Union[None, str, CCXTExchangeT] = None, | |
| exchange_config: tp.Optional[tp.KwargsLike] = None, | |
| start: tp.Optional[tp.DatetimeLike] = None, | |
| end: tp.Optional[tp.DatetimeLike] = None, | |
| timeframe: tp.Optional[str] = None, | |
| tz: tp.TimezoneLike = None, | |
| find_earliest_date: tp.Optional[bool] = None, | |
| limit: tp.Optional[int] = None, | |
| delay: tp.Optional[float] = None, | |
| retries: tp.Optional[int] = None, | |
| fetch_params: tp.Optional[tp.KwargsLike] = None, | |
| show_progress: tp.Optional[bool] = None, | |
| pbar_kwargs: tp.KwargsLike = None, | |
| silence_warnings: tp.Optional[bool] = None, | |
| return_fetch_method: bool = False, | |
| ) -> tp.Union[dict, tp.SymbolData]: | |
| """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from CCXT. | |
| Args: | |
| symbol (str): Symbol. | |
| The symbol can be in the `EXCHANGE:SYMBOL` format; in this case, the `exchange` argument will be ignored. | |
| exchange (str or object): Exchange identifier or an exchange object. | |
| See `CCXTData.resolve_exchange`. | |
| exchange_config (dict): Exchange config. | |
| See `CCXTData.resolve_exchange`. | |
| start (any): Start datetime. | |
| See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. | |
| end (any): End datetime. | |
| See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. | |
| timeframe (str): Timeframe. | |
| Allows human-readable strings such as "15 minutes". | |
| tz (any): Timezone. | |
| See `vectorbtpro.utils.datetime_.to_timezone`. | |
| find_earliest_date (bool): Whether to find the earliest date using `CCXTData.find_earliest_date`. | |
| limit (int): The maximum number of returned items. | |
| delay (float): Time to sleep after each request (in seconds). | |
| !!! note | |
| Use only if `enableRateLimit` is not set. | |
| retries (int): The number of retries on failure to fetch data. | |
| fetch_params (dict): Exchange-specific keyword arguments passed to `fetch_ohlcv`. | |
| show_progress (bool): Whether to show the progress bar. | |
| pbar_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.pbar.ProgressBar`. | |
| silence_warnings (bool): Whether to silence all warnings. | |
| return_fetch_method (bool): Whether to return the fetch function and resolved arguments instead of the data. Required by `CCXTData.find_earliest_date`. | |
| For defaults, see `custom.ccxt` in `vectorbtpro._settings.data`. | |
| Global settings can be provided per exchange id using the `exchanges` dictionary. | |
| """ | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("ccxt") | |
| import ccxt | |
| exchange = cls.resolve_custom_setting(exchange, "exchange") | |
| if exchange is None and ":" in symbol: | |
| exchange, symbol = symbol.split(":") | |
| if exchange_config is None: | |
| exchange_config = {} | |
| exchange = cls.resolve_exchange(exchange=exchange, **exchange_config) | |
| exchange_name = type(exchange).__name__ | |
| start = cls.resolve_exchange_setting(start, "start", exchange_name=exchange_name) | |
| end = cls.resolve_exchange_setting(end, "end", exchange_name=exchange_name) | |
| timeframe = cls.resolve_exchange_setting(timeframe, "timeframe", exchange_name=exchange_name) | |
| tz = cls.resolve_exchange_setting(tz, "tz", exchange_name=exchange_name) | |
| find_earliest_date = cls.resolve_exchange_setting( | |
| find_earliest_date, "find_earliest_date", exchange_name=exchange_name | |
| ) | |
| limit = cls.resolve_exchange_setting(limit, "limit", exchange_name=exchange_name) | |
| delay = cls.resolve_exchange_setting(delay, "delay", exchange_name=exchange_name) | |
| retries = cls.resolve_exchange_setting(retries, "retries", exchange_name=exchange_name) | |
| fetch_params = cls.resolve_exchange_setting( | |
| fetch_params, "fetch_params", merge=True, exchange_name=exchange_name | |
| ) | |
| show_progress = cls.resolve_exchange_setting(show_progress, "show_progress", exchange_name=exchange_name) | |
| pbar_kwargs = cls.resolve_exchange_setting(pbar_kwargs, "pbar_kwargs", merge=True, exchange_name=exchange_name) | |
| if "bar_id" not in pbar_kwargs: | |
| pbar_kwargs["bar_id"] = "ccxt" | |
| silence_warnings = cls.resolve_exchange_setting( | |
| silence_warnings, "silence_warnings", exchange_name=exchange_name | |
| ) | |
| if not exchange.has["fetchOHLCV"]: | |
| raise ValueError(f"Exchange {exchange} does not support OHLCV") | |
| if exchange.has["fetchOHLCV"] == "emulated": | |
| if not silence_warnings: | |
| warn("Using emulated OHLCV candles") | |
| freq = timeframe | |
| split = dt.split_freq_str(timeframe) | |
| if split is not None: | |
| multiplier, unit = split | |
| if unit == "D": | |
| unit = "d" | |
| elif unit == "W": | |
| unit = "w" | |
| elif unit == "Y": | |
| unit = "y" | |
| timeframe = str(multiplier) + unit | |
| if timeframe not in exchange.timeframes: | |
| raise ValueError(f"Exchange {exchange} does not support {timeframe} timeframe") | |
| def _retry(method): | |
| @wraps(method) | |
| def retry_method(*args, **kwargs): | |
| for i in range(retries): | |
| try: | |
| return method(*args, **kwargs) | |
| except ccxt.NetworkError as e: | |
| if i == retries - 1: | |
| raise e | |
| if not silence_warnings: | |
| warn(traceback.format_exc()) | |
| if delay is not None: | |
| time.sleep(delay) | |
| return retry_method | |
| @_retry | |
| def _fetch(_since, _limit): | |
| return exchange.fetch_ohlcv( | |
| symbol, | |
| timeframe=timeframe, | |
| since=_since, | |
| limit=_limit, | |
| params=fetch_params, | |
| ) | |
| if return_fetch_method: | |
| return dict(fetch_func=_fetch, start=start, end=end, tz=tz) | |
| # Establish the timestamps | |
| if find_earliest_date and start is not None: | |
| start = cls._find_earliest_date(_fetch, start=start, end=end, tz=tz, for_internal_use=True) | |
| if start is not None: | |
| start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")) | |
| else: | |
| start_ts = None | |
| if end is not None: | |
| end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(end, naive_tz=tz, tz="UTC")) | |
| else: | |
| end_ts = None | |
| prev_end_ts = None | |
| def _ts_to_str(ts: tp.Optional[int]) -> str: | |
| if ts is None: | |
| return "?" | |
| return dt.readable_datetime(pd.Timestamp(ts, unit="ms", tz="utc"), freq=timeframe) | |
| def _filter_func(d: tp.Sequence, _prev_end_ts: tp.Optional[int] = None) -> bool: | |
| if start_ts is not None: | |
| if d[0] < start_ts: | |
| return False | |
| if _prev_end_ts is not None: | |
| if d[0] <= _prev_end_ts: | |
| return False | |
| if end_ts is not None: | |
| if d[0] >= end_ts: | |
| return False | |
| return True | |
| # Iteratively collect the data | |
| data = [] | |
| try: | |
| with ProgressBar(show_progress=show_progress, **pbar_kwargs) as pbar: | |
| pbar.set_description("{} → ?".format(_ts_to_str(start_ts if prev_end_ts is None else prev_end_ts))) | |
| while True: | |
| # Fetch the klines for the next timeframe | |
| next_data = _fetch(start_ts if prev_end_ts is None else prev_end_ts, limit) | |
| next_data = list(filter(partial(_filter_func, _prev_end_ts=prev_end_ts), next_data)) | |
| # Update the timestamps and the progress bar | |
| if not len(next_data): | |
| break | |
| data += next_data | |
| if start_ts is None: | |
| start_ts = next_data[0][0] | |
| pbar.set_description("{} → {}".format(_ts_to_str(start_ts), _ts_to_str(next_data[-1][0]))) | |
| pbar.update() | |
| prev_end_ts = next_data[-1][0] | |
| if end_ts is not None and prev_end_ts >= end_ts: | |
| break | |
| if delay is not None: | |
| time.sleep(delay) # be kind to api | |
| except Exception as e: | |
| if not silence_warnings: | |
| warn(traceback.format_exc()) | |
| warn( | |
| f"Symbol '{str(symbol)}' raised an exception. Returning incomplete data. " | |
| "Use update() method to fetch missing data." | |
| ) | |
| # Convert data to a DataFrame | |
| df = pd.DataFrame(data, columns=["Open time", "Open", "High", "Low", "Close", "Volume"]) | |
| df.index = pd.to_datetime(df["Open time"], unit="ms", utc=True) | |
| del df["Open time"] | |
| if "Open" in df.columns: | |
| df["Open"] = df["Open"].astype(float) | |
| if "High" in df.columns: | |
| df["High"] = df["High"].astype(float) | |
| if "Low" in df.columns: | |
| df["Low"] = df["Low"].astype(float) | |
| if "Close" in df.columns: | |
| df["Close"] = df["Close"].astype(float) | |
| if "Volume" in df.columns: | |
| df["Volume"] = df["Volume"].astype(float) | |
| return df, dict(tz=tz, freq=freq) | |
| def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData: | |
| fetch_kwargs = self.select_fetch_kwargs(symbol) | |
| fetch_kwargs["start"] = self.select_last_index(symbol) | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| return self.fetch_symbol(symbol, **kwargs) | |
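| # Usage sketch (illustrative, not part of the original module): `update_symbol` | |
| # above re-fetches from the last stored index onward; it is normally invoked | |
| # through `update` on a previously pulled instance (assumed to exist as `data`). | |
| # | |
| # >>> data = data.update() | |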
| </file> | |
| <file path="data/custom/csv.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `CSVData`.""" | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.file import FileData | |
| from vectorbtpro.utils import datetime_ as dt | |
| from vectorbtpro.utils.config import merge_dicts | |
| __all__ = [ | |
| "CSVData", | |
| ] | |
| __pdoc__ = {} | |
| CSVDataT = tp.TypeVar("CSVDataT", bound="CSVData") | |
| class CSVData(FileData): | |
| """Data class for fetching CSV data.""" | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.csv") | |
| @classmethod | |
| def is_csv_file(cls, path: tp.PathLike) -> bool: | |
| """Return whether the path is a CSV/TSV file.""" | |
| if not isinstance(path, Path): | |
| path = Path(path) | |
| if path.exists() and path.is_file() and ".csv" in path.suffixes: | |
| return True | |
| if path.exists() and path.is_file() and ".tsv" in path.suffixes: | |
| return True | |
| return False | |
| @classmethod | |
| def is_file_match(cls, path: tp.PathLike) -> bool: | |
| return cls.is_csv_file(path) | |
| @classmethod | |
| def resolve_keys_meta( | |
| cls, | |
| keys: tp.Union[None, dict, tp.MaybeKeys] = None, | |
| keys_are_features: tp.Optional[bool] = None, | |
| features: tp.Union[None, dict, tp.MaybeFeatures] = None, | |
| symbols: tp.Union[None, dict, tp.MaybeSymbols] = None, | |
| paths: tp.Any = None, | |
| ) -> tp.Kwargs: | |
| keys_meta = FileData.resolve_keys_meta( | |
| keys=keys, | |
| keys_are_features=keys_are_features, | |
| features=features, | |
| symbols=symbols, | |
| ) | |
| if keys_meta["keys"] is None and paths is None: | |
| keys_meta["keys"] = cls.list_paths() | |
| return keys_meta | |
| @classmethod | |
| def fetch_key( | |
| cls, | |
| key: tp.Key, | |
| path: tp.Any = None, | |
| start: tp.Optional[tp.DatetimeLike] = None, | |
| end: tp.Optional[tp.DatetimeLike] = None, | |
| tz: tp.TimezoneLike = None, | |
| start_row: tp.Optional[int] = None, | |
| end_row: tp.Optional[int] = None, | |
| header: tp.Optional[tp.MaybeSequence[int]] = None, | |
| index_col: tp.Optional[int] = None, | |
| parse_dates: tp.Optional[bool] = None, | |
| chunk_func: tp.Optional[tp.Callable] = None, | |
| squeeze: tp.Optional[bool] = None, | |
| **read_kwargs, | |
| ) -> tp.KeyData: | |
| """Fetch the CSV file of a feature or symbol. | |
| Args: | |
| key (hashable): Feature or symbol. | |
| path (str): Path. | |
| If `path` is None, uses `key` as the path to the CSV file. | |
| start (any): Start datetime. | |
| Will use the timezone of the object. See `vectorbtpro.utils.datetime_.to_timestamp`. | |
| end (any): End datetime. | |
| Will use the timezone of the object. See `vectorbtpro.utils.datetime_.to_timestamp`. | |
| tz (any): Target timezone. | |
| See `vectorbtpro.utils.datetime_.to_timezone`. | |
| start_row (int): Start row (inclusive). | |
| Must exclude header rows. | |
| end_row (int): End row (exclusive). | |
| Must exclude header rows. | |
| header (int or sequence of int): See `pd.read_csv`. | |
| index_col (int): See `pd.read_csv`. | |
| If False, will pass None. | |
| parse_dates (bool): See `pd.read_csv`. | |
| chunk_func (callable): Function to select and concatenate chunks from `TextFileReader`. | |
| Gets called only if `iterator` or `chunksize` are set. | |
| squeeze (bool): Whether to squeeze a DataFrame with one column into a Series. | |
| **read_kwargs: Other keyword arguments passed to `pd.read_csv`. | |
| `skiprows` and `nrows` will be automatically calculated based on `start_row` and `end_row`. | |
| When either `start` or `end` is provided, will fetch the entire data first and filter it thereafter. | |
| See https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for other arguments. | |
| For defaults, see `custom.csv` in `vectorbtpro._settings.data`.""" | |
| from pandas.io.parsers import TextFileReader | |
| from pandas.api.types import is_object_dtype | |
| start = cls.resolve_custom_setting(start, "start") | |
| end = cls.resolve_custom_setting(end, "end") | |
| tz = cls.resolve_custom_setting(tz, "tz") | |
| start_row = cls.resolve_custom_setting(start_row, "start_row") | |
| if start_row is None: | |
| start_row = 0 | |
| end_row = cls.resolve_custom_setting(end_row, "end_row") | |
| header = cls.resolve_custom_setting(header, "header") | |
| index_col = cls.resolve_custom_setting(index_col, "index_col") | |
| if index_col is False: | |
| index_col = None | |
| parse_dates = cls.resolve_custom_setting(parse_dates, "parse_dates") | |
| chunk_func = cls.resolve_custom_setting(chunk_func, "chunk_func") | |
| squeeze = cls.resolve_custom_setting(squeeze, "squeeze") | |
| read_kwargs = cls.resolve_custom_setting(read_kwargs, "read_kwargs", merge=True) | |
| if path is None: | |
| path = key | |
| if isinstance(header, int): | |
| header = [header] | |
| header_rows = header[-1] + 1 | |
| start_row += header_rows | |
| if end_row is not None: | |
| end_row += header_rows | |
| skiprows = range(header_rows, start_row) | |
| if end_row is not None: | |
| nrows = end_row - start_row | |
| else: | |
| nrows = None | |
| sep = read_kwargs.pop("sep", None) | |
| if isinstance(path, (str, Path)): | |
| try: | |
| _path = Path(path) | |
| if _path.suffix.lower() == ".csv": | |
| if sep is None: | |
| sep = "," | |
| if _path.suffix.lower() == ".tsv": | |
| if sep is None: | |
| sep = "\t" | |
| except Exception as e: | |
| pass | |
| if sep is None: | |
| sep = "," | |
| obj = pd.read_csv( | |
| path, | |
| sep=sep, | |
| header=header, | |
| index_col=index_col, | |
| parse_dates=parse_dates, | |
| skiprows=skiprows, | |
| nrows=nrows, | |
| **read_kwargs, | |
| ) | |
| if isinstance(obj, TextFileReader): | |
| if chunk_func is None: | |
| obj = pd.concat(list(obj), axis=0) | |
| else: | |
| obj = chunk_func(obj) | |
| if isinstance(obj, pd.DataFrame) and squeeze: | |
| obj = obj.squeeze("columns") | |
| if isinstance(obj, pd.Series) and obj.name == "0": | |
| obj.name = None | |
| if index_col is not None and parse_dates and is_object_dtype(obj.index.dtype): | |
| obj.index = pd.to_datetime(obj.index, utc=True) | |
| if tz is not None: | |
| obj.index = obj.index.tz_convert(tz) | |
| if isinstance(obj.index, pd.DatetimeIndex) and tz is None: | |
| tz = obj.index.tz | |
| if start is not None or end is not None: | |
| if not isinstance(obj.index, pd.DatetimeIndex): | |
| raise TypeError("Cannot filter index that is not DatetimeIndex") | |
| if obj.index.tz is not None: | |
| if start is not None: | |
| start = dt.to_tzaware_timestamp(start, naive_tz=tz, tz=obj.index.tz) | |
| if end is not None: | |
| end = dt.to_tzaware_timestamp(end, naive_tz=tz, tz=obj.index.tz) | |
| else: | |
| if start is not None: | |
| start = dt.to_naive_timestamp(start, tz=tz) | |
| if end is not None: | |
| end = dt.to_naive_timestamp(end, tz=tz) | |
| mask = True | |
| if start is not None: | |
| mask &= obj.index >= start | |
| if end is not None: | |
| mask &= obj.index < end | |
| mask_indices = np.flatnonzero(mask) | |
| if len(mask_indices) == 0: | |
| return None | |
| obj = obj.iloc[mask_indices[0] : mask_indices[-1] + 1] | |
| start_row += mask_indices[0] | |
| return obj, dict(last_row=start_row - header_rows + len(obj.index) - 1, tz=tz) | |
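| # Usage sketch (illustrative, not part of the original module): pull a local CSV | |
| # file and filter it by date. The file name is an assumption; filtering by | |
| # `start`/`end` requires the index to be parsed as a datetime index (see above). | |
| # | |
| # >>> data = vbt.CSVData.pull("BTCUSDT.csv", start="2021-01-01", end="2022-01-01") | |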
| @classmethod | |
| def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData: | |
| """Fetch the CSV file of a feature. | |
| Uses `CSVData.fetch_key`.""" | |
| return cls.fetch_key(feature, **kwargs) | |
| @classmethod | |
| def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData: | |
| """Fetch the CSV file of a symbol. | |
| Uses `CSVData.fetch_key`.""" | |
| return cls.fetch_key(symbol, **kwargs) | |
| def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData: | |
| """Update data of a feature or symbol.""" | |
| fetch_kwargs = self.select_fetch_kwargs(key) | |
| returned_kwargs = self.select_returned_kwargs(key) | |
| fetch_kwargs["start_row"] = returned_kwargs["last_row"] | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| if key_is_feature: | |
| return self.fetch_feature(key, **kwargs) | |
| return self.fetch_symbol(key, **kwargs) | |
| def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData: | |
| """Update data of a feature. | |
| Uses `CSVData.update_key` with `key_is_feature=True`.""" | |
| return self.update_key(feature, key_is_feature=True, **kwargs) | |
| def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData: | |
| """Update data for a symbol. | |
| Uses `CSVData.update_key` with `key_is_feature=False`.""" | |
| return self.update_key(symbol, key_is_feature=False, **kwargs) | |
| </file> | |
| <file path="data/custom/custom.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `CustomData`.""" | |
| import fnmatch | |
| import re | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.base import Data | |
| __all__ = [ | |
| "CustomData", | |
| ] | |
| __pdoc__ = {} | |
| class CustomData(Data): | |
| """Data class for fetching custom data.""" | |
| _settings_path: tp.SettingsPath = dict(custom=None) | |
| @classmethod | |
| def get_custom_settings(cls, *args, **kwargs) -> dict: | |
| """`CustomData.get_settings` with `path_id="custom"`.""" | |
| return cls.get_settings(*args, path_id="custom", **kwargs) | |
| @classmethod | |
| def has_custom_settings(cls, *args, **kwargs) -> bool: | |
| """`CustomData.has_settings` with `path_id="custom"`.""" | |
| return cls.has_settings(*args, path_id="custom", **kwargs) | |
| @classmethod | |
| def get_custom_setting(cls, *args, **kwargs) -> tp.Any: | |
| """`CustomData.get_setting` with `path_id="custom"`.""" | |
| return cls.get_setting(*args, path_id="custom", **kwargs) | |
| @classmethod | |
| def has_custom_setting(cls, *args, **kwargs) -> bool: | |
| """`CustomData.has_setting` with `path_id="custom"`.""" | |
| return cls.has_setting(*args, path_id="custom", **kwargs) | |
| @classmethod | |
| def resolve_custom_setting(cls, *args, **kwargs) -> tp.Any: | |
| """`CustomData.resolve_setting` with `path_id="custom"`.""" | |
| return cls.resolve_setting(*args, path_id="custom", **kwargs) | |
| @classmethod | |
| def set_custom_settings(cls, *args, **kwargs) -> None: | |
| """`CustomData.set_settings` with `path_id="custom"`.""" | |
| cls.set_settings(*args, path_id="custom", **kwargs) | |
| @staticmethod | |
| def key_match(key: str, pattern: str, use_regex: bool = False): | |
| """Return whether key matches pattern. | |
| If `use_regex` is True, checks against a regular expression. | |
| Otherwise, checks against a glob-style pattern.""" | |
| if use_regex: | |
| return re.match(pattern, key) | |
| return re.match(fnmatch.translate(pattern), key) | |
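| # Usage sketch (illustrative, not part of the original module): `key_match` is | |
| # what the data classes above use to filter symbols, tables, and file names. | |
| # | |
| # >>> bool(CustomData.key_match("BTC/USDT", "BTC/*"))                 # glob-style | |
| # True | |
| # >>> bool(CustomData.key_match("BTC/USDT", r"^BTC", use_regex=True))  # regex | |
| # True | |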
| </file> | |
| <file path="data/custom/db.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `DBData`.""" | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.local import LocalData | |
| __all__ = [ | |
| "DBData", | |
| ] | |
| __pdoc__ = {} | |
| class DBData(LocalData): | |
| """Data class for fetching database data.""" | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.db") | |
| </file> | |
| <file path="data/custom/duckdb.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `DuckDBData`.""" | |
| from pathlib import Path | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.base import key_dict | |
| from vectorbtpro.data.custom.db import DBData | |
| from vectorbtpro.data.custom.file import FileData | |
| from vectorbtpro.utils import checks, datetime_ as dt | |
| from vectorbtpro.utils.config import merge_dicts | |
| try: | |
| if not tp.TYPE_CHECKING: | |
| raise ImportError | |
| from duckdb import DuckDBPyConnection as DuckDBPyConnectionT, DuckDBPyRelation as DuckDBPyRelationT | |
| except ImportError: | |
| DuckDBPyConnectionT = "DuckDBPyConnection" | |
| DuckDBPyRelationT = "DuckDBPyRelation" | |
| __all__ = [ | |
| "DuckDBData", | |
| ] | |
| __pdoc__ = {} | |
| DuckDBDataT = tp.TypeVar("DuckDBDataT", bound="DuckDBData") | |
| class DuckDBData(DBData): | |
| """Data class for fetching data using DuckDB. | |
| See `DuckDBData.pull` and `DuckDBData.fetch_key` for arguments. | |
| Usage: | |
| * Set up the connection settings globally (optional): | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> vbt.DuckDBData.set_custom_settings(connection="database.duckdb") | |
| ``` | |
| * Pull tables: | |
| ```pycon | |
| >>> data = vbt.DuckDBData.pull(["TABLE1", "TABLE2"]) | |
| ``` | |
| * Rename tables: | |
| ```pycon | |
| >>> data = vbt.DuckDBData.pull( | |
| ... ["SYMBOL1", "SYMBOL2"], | |
| ... table=vbt.key_dict({ | |
| ... "SYMBOL1": "TABLE1", | |
| ... "SYMBOL2": "TABLE2" | |
| ... }) | |
| ... ) | |
| ``` | |
| * Pull queries: | |
| ```pycon | |
| >>> data = vbt.DuckDBData.pull( | |
| ... ["SYMBOL1", "SYMBOL2"], | |
| ... query=vbt.key_dict({ | |
| ... "SYMBOL1": "SELECT * FROM TABLE1", | |
| ... "SYMBOL2": "SELECT * FROM TABLE2" | |
| ... }) | |
| ... ) | |
| ``` | |
| * Pull Parquet files: | |
| ```pycon | |
| >>> data = vbt.DuckDBData.pull( | |
| ... ["SYMBOL1", "SYMBOL2"], | |
| ... read_path=vbt.key_dict({ | |
| ... "SYMBOL1": "s1.parquet", | |
| ... "SYMBOL2": "s2.parquet" | |
| ... }) | |
| ... ) | |
| ``` | |
| """ | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.duckdb") | |
| @classmethod | |
| def resolve_connection( | |
| cls, | |
| connection: tp.Union[None, str, tp.PathLike, DuckDBPyConnectionT] = None, | |
| read_only: bool = True, | |
| return_meta: bool = False, | |
| **connection_config, | |
| ) -> tp.Union[DuckDBPyConnectionT, dict]: | |
| """Resolve the connection.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("duckdb") | |
| from duckdb import connect, default_connection | |
| connection_meta = {} | |
| connection = cls.resolve_custom_setting(connection, "connection") | |
| if connection_config is None: | |
| connection_config = {} | |
| has_connection_config = len(connection_config) > 0 | |
| connection_config["read_only"] = read_only | |
| connection_config = cls.resolve_custom_setting(connection_config, "connection_config", merge=True) | |
| read_only = connection_config.pop("read_only", read_only) | |
| should_close = False | |
| if connection is None: | |
| if len(connection_config) == 0: | |
| connection = default_connection | |
| else: | |
| database = connection_config.pop("database", None) | |
| if "config" in connection_config or len(connection_config) == 0: | |
| connection = connect(database, read_only=read_only, **connection_config) | |
| else: | |
| connection = connect(database, read_only=read_only, config=connection_config) | |
| should_close = True | |
| elif isinstance(connection, (str, Path)): | |
| if "config" in connection_config or len(connection_config) == 0: | |
| connection = connect(str(connection), read_only=read_only, **connection_config) | |
| else: | |
| connection = connect(str(connection), read_only=read_only, config=connection_config) | |
| should_close = True | |
| elif has_connection_config: | |
| raise ValueError("Cannot apply connection_config to already initialized connection") | |
| if return_meta: | |
| connection_meta["connection"] = connection | |
| connection_meta["should_close"] = should_close | |
| return connection_meta | |
| return connection | |
| @classmethod | |
| def list_catalogs( | |
| cls, | |
| pattern: tp.Optional[str] = None, | |
| use_regex: bool = False, | |
| sort: bool = True, | |
| incl_system: bool = False, | |
| connection: tp.Union[None, str, DuckDBPyConnectionT] = None, | |
| connection_config: tp.KwargsLike = None, | |
| ) -> tp.List[str]: | |
| """List all catalogs. | |
| Catalogs "system" and "temp" are skipped if `incl_system` is False. | |
| Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each catalog against `pattern`.""" | |
| if connection_config is None: | |
| connection_config = {} | |
| connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config) | |
| connection = connection_meta["connection"] | |
| should_close = connection_meta["should_close"] | |
| schemata_df = connection.sql("SELECT * FROM information_schema.schemata").df() | |
| catalogs = [] | |
| for catalog in schemata_df["catalog_name"].tolist(): | |
| if pattern is not None: | |
| if not cls.key_match(catalog, pattern, use_regex=use_regex): | |
| continue | |
| if not incl_system and catalog == "system": | |
| continue | |
| if not incl_system and catalog == "temp": | |
| continue | |
| catalogs.append(catalog) | |
| if should_close: | |
| connection.close() | |
| if sort: | |
| return sorted(dict.fromkeys(catalogs)) | |
| return list(dict.fromkeys(catalogs)) | |
| @classmethod | |
| def list_schemas( | |
| cls, | |
| catalog_pattern: tp.Optional[str] = None, | |
| schema_pattern: tp.Optional[str] = None, | |
| use_regex: bool = False, | |
| sort: bool = True, | |
| catalog: tp.Optional[str] = None, | |
| incl_system: bool = False, | |
| connection: tp.Union[None, str, DuckDBPyConnectionT] = None, | |
| connection_config: tp.KwargsLike = None, | |
| ) -> tp.List[str]: | |
| """List all schemas. | |
| If `catalog` is None, searches for all catalog names in the database and prefixes each schema | |
| with the respective catalog name. If `catalog` is provided, returns the schemas corresponding | |
| to this catalog without a prefix. Schemas "information_schema" and "pg_catalog" are skipped | |
| if `incl_system` is False. | |
| Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each catalog against `catalog_pattern` and each schema against `schema_pattern`.""" | |
| if connection_config is None: | |
| connection_config = {} | |
| connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config) | |
| connection = connection_meta["connection"] | |
| should_close = connection_meta["should_close"] | |
| if catalog is None: | |
| catalogs = cls.list_catalogs( | |
| pattern=catalog_pattern, | |
| use_regex=use_regex, | |
| sort=sort, | |
| incl_system=incl_system, | |
| connection=connection, | |
| connection_config=connection_config, | |
| ) | |
| if len(catalogs) == 1: | |
| prefix_catalog = False | |
| else: | |
| prefix_catalog = True | |
| else: | |
| catalogs = [catalog] | |
| prefix_catalog = False | |
| schemata_df = connection.sql("SELECT * FROM information_schema.schemata").df() | |
| schemas = [] | |
| for catalog in catalogs: | |
| all_schemas = schemata_df[schemata_df["catalog_name"] == catalog]["schema_name"].tolist() | |
| for schema in all_schemas: | |
| if schema_pattern is not None: | |
| if not cls.key_match(schema, schema_pattern, use_regex=use_regex): | |
| continue | |
| if not incl_system and schema == "information_schema": | |
| continue | |
| if not incl_system and schema == "pg_catalog": | |
| continue | |
| if prefix_catalog: | |
| schema = catalog + ":" + schema | |
| schemas.append(schema) | |
| if should_close: | |
| connection.close() | |
| if sort: | |
| return sorted(dict.fromkeys(schemas)) | |
| return list(dict.fromkeys(schemas)) | |
| @classmethod | |
| def get_current_schema( | |
| cls, | |
| connection: tp.Union[None, str, DuckDBPyConnectionT] = None, | |
| connection_config: tp.KwargsLike = None, | |
| ) -> str: | |
| """Get the current schema.""" | |
| if connection_config is None: | |
| connection_config = {} | |
| connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config) | |
| connection = connection_meta["connection"] | |
| should_close = connection_meta["should_close"] | |
| current_schema = connection.sql("SELECT current_schema()").fetchall()[0][0] | |
| if should_close: | |
| connection.close() | |
| return current_schema | |
| @classmethod | |
| def list_tables( | |
| cls, | |
| *, | |
| catalog_pattern: tp.Optional[str] = None, | |
| schema_pattern: tp.Optional[str] = None, | |
| table_pattern: tp.Optional[str] = None, | |
| use_regex: bool = False, | |
| sort: bool = True, | |
| catalog: tp.Optional[str] = None, | |
| schema: tp.Optional[str] = None, | |
| incl_system: bool = False, | |
| incl_temporary: bool = False, | |
| incl_views: bool = True, | |
| connection: tp.Union[None, str, DuckDBPyConnectionT] = None, | |
| connection_config: tp.KwargsLike = None, | |
| ) -> tp.List[str]: | |
| """List all tables and views. | |
| If `schema` is None, searches for all schema names in the database and prefixes each table | |
| with the respective catalog and schema name (unless there's only one schema which is the current | |
| schema or `schema` is `current_schema`). If `schema` is provided, returns the tables corresponding | |
| to this schema without a prefix. | |
| Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each schema against | |
| `schema_pattern` and each table against `table_pattern`.""" | |
| if connection_config is None: | |
| connection_config = {} | |
| connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config) | |
| connection = connection_meta["connection"] | |
| should_close = connection_meta["should_close"] | |
| if catalog is None: | |
| catalogs = cls.list_catalogs( | |
| pattern=catalog_pattern, | |
| use_regex=use_regex, | |
| sort=sort, | |
| incl_system=incl_system, | |
| connection=connection, | |
| connection_config=connection_config, | |
| ) | |
| if catalog_pattern is None and len(catalogs) == 1: | |
| prefix_catalog = False | |
| else: | |
| prefix_catalog = True | |
| else: | |
| catalogs = [catalog] | |
| prefix_catalog = False | |
| current_schema = cls.get_current_schema( | |
| connection=connection, | |
| connection_config=connection_config, | |
| ) | |
| if schema is None: | |
| catalogs_schemas = [] | |
| for catalog in catalogs: | |
| catalog_schemas = cls.list_schemas( | |
| schema_pattern=schema_pattern, | |
| use_regex=use_regex, | |
| sort=sort, | |
| catalog=catalog, | |
| incl_system=incl_system, | |
| connection=connection, | |
| connection_config=connection_config, | |
| ) | |
| for schema in catalog_schemas: | |
| catalogs_schemas.append((catalog, schema)) | |
| if len(catalogs_schemas) == 1 and catalogs_schemas[0][1] == current_schema: | |
| prefix_schema = False | |
| else: | |
| prefix_schema = True | |
| else: | |
| if schema == "current_schema": | |
| schema = current_schema | |
| catalogs_schemas = [] | |
| for catalog in catalogs: | |
| catalogs_schemas.append((catalog, schema)) | |
| prefix_schema = prefix_catalog | |
| tables_df = connection.sql("SELECT * FROM information_schema.tables").df() | |
| tables = [] | |
| for catalog, schema in catalogs_schemas: | |
| all_tables = [] | |
| all_tables.extend( | |
| tables_df[ | |
| (tables_df["table_catalog"] == catalog) | |
| & (tables_df["table_schema"] == schema) | |
| & (tables_df["table_type"] == "BASE TABLE") | |
| ]["table_name"].tolist() | |
| ) | |
| if incl_temporary: | |
| all_tables.extend( | |
| tables_df[ | |
| (tables_df["table_catalog"] == catalog) | |
| & (tables_df["table_schema"] == schema) | |
| & (tables_df["table_type"] == "LOCAL TEMPORARY") | |
| ]["table_name"].tolist() | |
| ) | |
| if incl_views: | |
| all_tables.extend( | |
| tables_df[ | |
| (tables_df["table_catalog"] == catalog) | |
| & (tables_df["table_schema"] == schema) | |
| & (tables_df["table_type"] == "VIEW") | |
| ]["table_name"].tolist() | |
| ) | |
| for table in all_tables: | |
| if table_pattern is not None: | |
| if not cls.key_match(table, table_pattern, use_regex=use_regex): | |
| continue | |
| if not prefix_catalog and prefix_schema: | |
| table = schema + ":" + table | |
| elif prefix_catalog or prefix_schema: | |
| table = catalog + ":" + schema + ":" + table | |
| tables.append(table) | |
| if should_close: | |
| connection.close() | |
| if sort: | |
| return sorted(dict.fromkeys(tables)) | |
| return list(dict.fromkeys(tables)) | |
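| # Usage sketch (illustrative, not part of the original module): list the tables | |
| # of one schema in an on-disk database. The database file and schema name are | |
| # assumptions. | |
| # | |
| # >>> vbt.DuckDBData.list_tables(connection="database.duckdb", schema="main") | |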
| @classmethod | |
| def resolve_keys_meta( | |
| cls, | |
| keys: tp.Union[None, dict, tp.MaybeKeys] = None, | |
| keys_are_features: tp.Optional[bool] = None, | |
| features: tp.Union[None, dict, tp.MaybeFeatures] = None, | |
| symbols: tp.Union[None, dict, tp.MaybeSymbols] = None, | |
| catalog: tp.Optional[str] = None, | |
| schema: tp.Optional[str] = None, | |
| list_tables_kwargs: tp.KwargsLike = None, | |
| read_path: tp.Optional[tp.PathLike] = None, | |
| read_format: tp.Optional[str] = None, | |
| connection: tp.Union[None, str, DuckDBPyConnectionT] = None, | |
| connection_config: tp.KwargsLike = None, | |
| ) -> tp.Kwargs: | |
| keys_meta = DBData.resolve_keys_meta( | |
| keys=keys, | |
| keys_are_features=keys_are_features, | |
| features=features, | |
| symbols=symbols, | |
| ) | |
| if keys_meta["keys"] is None: | |
| if cls.has_key_dict(catalog): | |
| raise ValueError("Cannot populate keys if catalog is defined per key") | |
| if cls.has_key_dict(schema): | |
| raise ValueError("Cannot populate keys if schema is defined per key") | |
| if cls.has_key_dict(list_tables_kwargs): | |
| raise ValueError("Cannot populate keys if list_tables_kwargs is defined per key") | |
| if cls.has_key_dict(connection): | |
| raise ValueError("Cannot populate keys if connection is defined per key") | |
| if cls.has_key_dict(connection_config): | |
| raise ValueError("Cannot populate keys if connection_config is defined per key") | |
| if cls.has_key_dict(read_path): | |
| raise ValueError("Cannot populate keys if read_path is defined per key") | |
| if cls.has_key_dict(read_format): | |
| raise ValueError("Cannot populate keys if read_format is defined per key") | |
| if read_path is not None or read_format is not None: | |
| if read_path is None: | |
| read_path = "." | |
| if read_format is not None: | |
| read_format = read_format.lower() | |
| checks.assert_in(read_format, ["csv", "parquet", "json"], arg_name="read_format") | |
| keys_meta["keys"] = FileData.list_paths(read_path, extension=read_format) | |
| else: | |
| if list_tables_kwargs is None: | |
| list_tables_kwargs = {} | |
| keys_meta["keys"] = cls.list_tables( | |
| catalog=catalog, | |
| schema=schema, | |
| connection=connection, | |
| connection_config=connection_config, | |
| **list_tables_kwargs, | |
| ) | |
| return keys_meta | |
| @classmethod | |
| def pull( | |
| cls: tp.Type[DuckDBDataT], | |
| keys: tp.Union[tp.MaybeKeys] = None, | |
| *, | |
| keys_are_features: tp.Optional[bool] = None, | |
| features: tp.Union[tp.MaybeFeatures] = None, | |
| symbols: tp.Union[tp.MaybeSymbols] = None, | |
| catalog: tp.Optional[str] = None, | |
| schema: tp.Optional[str] = None, | |
| list_tables_kwargs: tp.KwargsLike = None, | |
| read_path: tp.Optional[tp.PathLike] = None, | |
| read_format: tp.Optional[str] = None, | |
| connection: tp.Union[None, str, DuckDBPyConnectionT] = None, | |
| connection_config: tp.KwargsLike = None, | |
| share_connection: tp.Optional[bool] = None, | |
| **kwargs, | |
| ) -> DuckDBDataT: | |
| """Override `vectorbtpro.data.base.Data.pull` to resolve and share the connection among the keys | |
| and use the table names available in the database in case no keys were provided.""" | |
| if share_connection is None: | |
| if not cls.has_key_dict(connection) and not cls.has_key_dict(connection_config): | |
| share_connection = True | |
| else: | |
| share_connection = False | |
| if share_connection: | |
| if connection_config is None: | |
| connection_config = {} | |
| connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config) | |
| connection = connection_meta["connection"] | |
| should_close = connection_meta["should_close"] | |
| else: | |
| should_close = False | |
| keys_meta = cls.resolve_keys_meta( | |
| keys=keys, | |
| keys_are_features=keys_are_features, | |
| features=features, | |
| symbols=symbols, | |
| catalog=catalog, | |
| schema=schema, | |
| list_tables_kwargs=list_tables_kwargs, | |
| read_path=read_path, | |
| read_format=read_format, | |
| connection=connection, | |
| connection_config=connection_config, | |
| ) | |
| keys = keys_meta["keys"] | |
| if isinstance(read_path, key_dict): | |
| new_read_path = read_path.copy() | |
| else: | |
| new_read_path = key_dict() | |
| if isinstance(keys, dict): | |
| new_keys = {} | |
| for k, v in keys.items(): | |
| if isinstance(k, Path): | |
| new_k = FileData.path_to_key(k) | |
| new_read_path[new_k] = k | |
| k = new_k | |
| new_keys[k] = v | |
| keys = new_keys | |
| elif cls.has_multiple_keys(keys): | |
| new_keys = [] | |
| for k in keys: | |
| if isinstance(k, Path): | |
| new_k = FileData.path_to_key(k) | |
| new_read_path[new_k] = k | |
| k = new_k | |
| new_keys.append(k) | |
| keys = new_keys | |
| else: | |
| if isinstance(keys, Path): | |
| new_keys = FileData.path_to_key(keys) | |
| new_read_path[new_keys] = keys | |
| keys = new_keys | |
| if len(new_read_path) > 0: | |
| read_path = new_read_path | |
| keys_are_features = keys_meta["keys_are_features"] | |
| outputs = super(DBData, cls).pull( | |
| keys, | |
| keys_are_features=keys_are_features, | |
| catalog=catalog, | |
| schema=schema, | |
| read_path=read_path, | |
| read_format=read_format, | |
| connection=connection, | |
| connection_config=connection_config, | |
| **kwargs, | |
| ) | |
| if should_close: | |
| connection.close() | |
| return outputs | |
| @classmethod | |
| def format_write_option(cls, option: tp.Any) -> str: | |
| """Format a write option.""" | |
| if isinstance(option, str): | |
| return f"'{option}'" | |
| if isinstance(option, (tuple, list)): | |
| return "(" + ", ".join(map(str, option)) + ")" | |
| if isinstance(option, dict): | |
| return "{" + ", ".join(map(lambda y: f"{y[0]}: {cls.format_write_option(y[1])}", option.items())) + "}" | |
| return f"{option}" | |
| @classmethod | |
| def format_write_options(cls, options: tp.Union[str, dict]) -> str: | |
| """Format write options.""" | |
| if isinstance(options, str): | |
| return options | |
| new_options = [] | |
| for k, v in options.items(): | |
| new_options.append(f"{k.upper()} {cls.format_write_option(v)}") | |
| return ", ".join(new_options) | |
| @classmethod | |
| def format_read_option(cls, option: tp.Any) -> str: | |
| """Format a read option.""" | |
| if isinstance(option, str): | |
| return f"'{option}'" | |
| if isinstance(option, (tuple, list)): | |
| return "[" + ", ".join(map(cls.format_read_option, option)) + "]" | |
| if isinstance(option, dict): | |
| return "{" + ", ".join(map(lambda y: f"'{y[0]}': {cls.format_read_option(y[1])}", option.items())) + "}" | |
| return f"{option}" | |
| @classmethod | |
| def format_read_options(cls, options: tp.Union[str, dict]) -> str: | |
| """Format read options.""" | |
| if isinstance(options, str): | |
| return options | |
| new_options = [] | |
| for k, v in options.items(): | |
| new_options.append(f"{k.lower()}={cls.format_read_option(v)}") | |
| return ", ".join(new_options) | |
| @classmethod | |
| def fetch_key( | |
| cls, | |
| key: str, | |
| table: tp.Optional[str] = None, | |
| schema: tp.Optional[str] = None, | |
| catalog: tp.Optional[str] = None, | |
| read_path: tp.Optional[tp.PathLike] = None, | |
| read_format: tp.Optional[str] = None, | |
| read_options: tp.Union[None, str, dict] = None, | |
| query: tp.Union[None, str, DuckDBPyRelationT] = None, | |
| connection: tp.Union[None, str, DuckDBPyConnectionT] = None, | |
| connection_config: tp.KwargsLike = None, | |
| start: tp.Optional[tp.Any] = None, | |
| end: tp.Optional[tp.Any] = None, | |
| align_dates: tp.Optional[bool] = None, | |
| parse_dates: tp.Union[None, bool, tp.Sequence[str]] = None, | |
| to_utc: tp.Union[None, bool, str, tp.Sequence[str]] = None, | |
| tz: tp.TimezoneLike = None, | |
| index_col: tp.Optional[tp.MaybeSequence[tp.IntStr]] = None, | |
| squeeze: tp.Optional[bool] = None, | |
| df_kwargs: tp.KwargsLike = None, | |
| **sql_kwargs, | |
| ) -> tp.KeyData: | |
| """Fetch a feature or symbol from a DuckDB database. | |
| Can use a table name (which defaults to the key) or a custom query. | |
| Args: | |
| key (str): Feature or symbol. | |
| If `table` and `query` are both None, becomes the table name. | |
| The key can be in the `SCHEMA:TABLE` or `CATALOG:SCHEMA:TABLE` format; in this case, the `schema` (and `catalog`) arguments will be ignored. | |
| table (str): Table name. | |
| Cannot be used together with `read_path` or `query`. | |
| schema (str): Schema name. | |
| Cannot be used together with `read_path` or `query`. | |
| catalog (str): Catalog name. | |
| Cannot be used together with `read_path` or `query`. | |
| read_path (path_like): Path to a file to read. | |
| Cannot be used together with `table`, `schema`, `catalog`, or `query`. | |
| read_format (str): Format of the file to read. | |
| Allowed values are "csv", "parquet", and "json". | |
| Requires `read_path` to be set. | |
| read_options (str or dict): Options used to read the file. | |
| Requires `read_path` and `read_format` to be set. | |
| Uses `DuckDBData.format_read_options` to transform a dictionary to a string. | |
| query (str or DuckDBPyRelation): Custom query. | |
| Cannot be used together with `catalog`, `schema`, and `table`. | |
| connection (str or object): See `DuckDBData.resolve_connection`. | |
| connection_config (dict): See `DuckDBData.resolve_connection`. | |
| start (any): Start datetime (if datetime index) or any other start value. | |
| Will parse with `vectorbtpro.utils.datetime_.to_timestamp` if `align_dates` is True | |
| and the index is a datetime index. Otherwise, you must ensure the correct type is provided. | |
| Cannot be used together with `query`; include the condition in the query itself. | |
| end (any): End datetime (if datetime index) or any other end value. | |
| Will parse with `vectorbtpro.utils.datetime_.to_timestamp` if `align_dates` is True | |
| and the index is a datetime index. Otherwise, you must ensure the correct type is provided. | |
| Cannot be used together with `query`; include the condition in the query itself. | |
| align_dates (bool): Whether to align `start` and `end` to the timezone of the index. | |
| Will pull one row (using `LIMIT 1`) and use `SQLData.prepare_dt` to get the index. | |
| parse_dates (bool or sequence of str): See `DuckDBData.prepare_dt`. | |
| to_utc (bool, str, or sequence of str): See `DuckDBData.prepare_dt`. | |
| tz (any): Timezone. | |
| See `vectorbtpro.utils.datetime_.to_timezone`. | |
| index_col (int, str, or list): One or more columns that should become the index. | |
| squeeze (bool): Whether to squeeze a DataFrame with one column into a Series. | |
| df_kwargs (dict): Keyword arguments passed to `relation.df` to convert a relation to a DataFrame. | |
| **sql_kwargs: Other keyword arguments passed to `connection.execute` to run a SQL query. | |
| For defaults, see `custom.duckdb` in `vectorbtpro._settings.data`.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("duckdb") | |
| from duckdb import DuckDBPyRelation | |
| if connection_config is None: | |
| connection_config = {} | |
| connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config) | |
| connection = connection_meta["connection"] | |
| should_close = connection_meta["should_close"] | |
| if catalog is not None and query is not None: | |
| raise ValueError("Cannot use catalog and query together") | |
| if schema is not None and query is not None: | |
| raise ValueError("Cannot use schema and query together") | |
| if table is not None and query is not None: | |
| raise ValueError("Cannot use table and query together") | |
| if read_path is not None and query is not None: | |
| raise ValueError("Cannot use read_path and query together") | |
| if read_path is not None and (catalog is not None or schema is not None or table is not None): | |
| raise ValueError("Cannot use read_path and catalog/schema/table together") | |
| if table is None and read_path is None and read_format is None and query is None: | |
| if ":" in key: | |
| key_parts = key.split(":") | |
| if len(key_parts) == 2: | |
| schema, table = key_parts | |
| else: | |
| catalog, schema, table = key_parts | |
| else: | |
| table = key | |
| if read_format is not None: | |
| read_format = read_format.lower() | |
| checks.assert_in(read_format, ["csv", "parquet", "json"], arg_name="read_format") | |
| if read_path is None: | |
| read_path = (Path(".") / key).with_suffix("." + read_format) | |
| else: | |
| if read_path is not None: | |
| if isinstance(read_path, str): | |
| read_path = Path(read_path) | |
| if read_path.suffix[1:] in ["csv", "parquet", "json"]: | |
| read_format = read_path.suffix[1:] | |
| if read_path is not None: | |
| if isinstance(read_path, Path): | |
| read_path = str(read_path) | |
| read_path = cls.format_read_option(read_path) | |
| if read_options is not None: | |
| if read_format is None: | |
| raise ValueError("Must provide read_format for read_options") | |
| read_options = cls.format_read_options(read_options) | |
| catalog = cls.resolve_custom_setting(catalog, "catalog") | |
| schema = cls.resolve_custom_setting(schema, "schema") | |
| start = cls.resolve_custom_setting(start, "start") | |
| end = cls.resolve_custom_setting(end, "end") | |
| align_dates = cls.resolve_custom_setting(align_dates, "align_dates") | |
| parse_dates = cls.resolve_custom_setting(parse_dates, "parse_dates") | |
| to_utc = cls.resolve_custom_setting(to_utc, "to_utc") | |
| tz = cls.resolve_custom_setting(tz, "tz") | |
| index_col = cls.resolve_custom_setting(index_col, "index_col") | |
| squeeze = cls.resolve_custom_setting(squeeze, "squeeze") | |
| df_kwargs = cls.resolve_custom_setting(df_kwargs, "df_kwargs", merge=True) | |
| sql_kwargs = cls.resolve_custom_setting(sql_kwargs, "sql_kwargs", merge=True) | |
| if query is None: | |
| if read_path is not None: | |
| if read_options is not None: | |
| query = f"SELECT * FROM read_{read_format}({read_path}, {read_options})" | |
| elif read_format is not None: | |
| query = f"SELECT * FROM read_{read_format}({read_path})" | |
| else: | |
| query = f"SELECT * FROM {read_path}" | |
| else: | |
| if catalog is not None: | |
| if schema is None: | |
| schema = cls.get_current_schema( | |
| connection=connection, | |
| connection_config=connection_config, | |
| ) | |
| query = f'SELECT * FROM "{catalog}"."{schema}"."{table}"' | |
| elif schema is not None: | |
| query = f'SELECT * FROM "{schema}"."{table}"' | |
| else: | |
| query = f'SELECT * FROM "{table}"' | |
| if start is not None or end is not None: | |
| if index_col is None: | |
| raise ValueError("Must provide index column for filtering by start and end") | |
| if not checks.is_int(index_col) and not isinstance(index_col, str): | |
| raise ValueError("Index column must be integer or string for filtering by start and end") | |
| if checks.is_int(index_col) or align_dates: | |
| metadata_df = connection.sql("DESCRIBE " + query + " LIMIT 1").df() | |
| else: | |
| metadata_df = None | |
| if checks.is_int(index_col): | |
| index_name = metadata_df["column_name"].tolist()[0] | |
| else: | |
| index_name = index_col | |
| if parse_dates: | |
| index_column_type = metadata_df[metadata_df["column_name"] == index_name]["column_type"].item() | |
| if index_column_type in ( | |
| "TIMESTAMP_NS", | |
| "TIMESTAMP_MS", | |
| "TIMESTAMP_S", | |
| "TIMESTAMP", | |
| "DATETIME", | |
| ): | |
| if start is not None: | |
| if ( | |
| to_utc is True | |
| or (isinstance(to_utc, str) and to_utc.lower() == "index") | |
| or (checks.is_sequence(to_utc) and index_name in to_utc) | |
| ): | |
| start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc") | |
| start = dt.to_naive_datetime(start) | |
| else: | |
| start = dt.to_naive_datetime(start, tz=tz) | |
| if end is not None: | |
| if ( | |
| to_utc is True | |
| or (isinstance(to_utc, str) and to_utc.lower() == "index") | |
| or (checks.is_sequence(to_utc) and index_name in to_utc) | |
| ): | |
| end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc") | |
| end = dt.to_naive_datetime(end) | |
| else: | |
| end = dt.to_naive_datetime(end, tz=tz) | |
| elif index_column_type in ("TIMESTAMPTZ", "TIMESTAMP WITH TIME ZONE"): | |
| if start is not None: | |
| if ( | |
| to_utc is True | |
| or (isinstance(to_utc, str) and to_utc.lower() == "index") | |
| or (checks.is_sequence(to_utc) and index_name in to_utc) | |
| ): | |
| start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc") | |
| else: | |
| start = dt.to_tzaware_datetime(start, naive_tz=tz) | |
| if end is not None: | |
| if ( | |
| to_utc is True | |
| or (isinstance(to_utc, str) and to_utc.lower() == "index") | |
| or (checks.is_sequence(to_utc) and index_name in to_utc) | |
| ): | |
| end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc") | |
| else: | |
| end = dt.to_tzaware_datetime(end, naive_tz=tz) | |
| if start is not None and end is not None: | |
| query += f' WHERE "{index_name}" >= $start AND "{index_name}" < $end' | |
| elif start is not None: | |
| query += f' WHERE "{index_name}" >= $start' | |
| elif end is not None: | |
| query += f' WHERE "{index_name}" < $end' | |
| params = sql_kwargs.get("params", None) | |
| if params is None: | |
| params = {} | |
| else: | |
| params = dict(params) | |
| if not isinstance(params, dict): | |
| raise ValueError("Parameters must be a dictionary for filtering by start and end") | |
| if start is not None: | |
| if "start" in params: | |
| raise ValueError("Start is already in params") | |
| params["start"] = start | |
| if end is not None: | |
| if "end" in params: | |
| raise ValueError("End is already in params") | |
| params["end"] = end | |
| sql_kwargs["params"] = params | |
| else: | |
| if start is not None: | |
| raise ValueError("Start cannot be applied to custom queries") | |
| if end is not None: | |
| raise ValueError("End cannot be applied to custom queries") | |
| if not isinstance(query, DuckDBPyRelation): | |
| relation = connection.sql(query, **sql_kwargs) | |
| else: | |
| relation = query | |
| obj = relation.df(**df_kwargs) | |
| if isinstance(obj, pd.DataFrame) and checks.is_default_index(obj.index): | |
| if index_col is not None: | |
| if checks.is_int(index_col): | |
| keys = obj.columns[index_col] | |
| elif isinstance(index_col, str): | |
| keys = index_col | |
| else: | |
| keys = [] | |
| for col in index_col: | |
| if checks.is_int(col): | |
| keys.append(obj.columns[col]) | |
| else: | |
| keys.append(col) | |
| obj = obj.set_index(keys) | |
| if not isinstance(obj.index, pd.MultiIndex): | |
| if obj.index.name == "index": | |
| obj.index.name = None | |
| obj = cls.prepare_dt(obj, to_utc=to_utc, parse_dates=parse_dates) | |
| if not isinstance(obj.index, pd.MultiIndex): | |
| if obj.index.name == "index": | |
| obj.index.name = None | |
| if isinstance(obj.index, pd.DatetimeIndex) and tz is None: | |
| tz = obj.index.tz | |
| if isinstance(obj, pd.DataFrame) and squeeze: | |
| obj = obj.squeeze("columns") | |
| if isinstance(obj, pd.Series) and obj.name == "0": | |
| obj.name = None | |
| if should_close: | |
| connection.close() | |
| return obj, dict(tz=tz) | |
| @classmethod | |
| def fetch_feature(cls, feature: str, **kwargs) -> tp.FeatureData: | |
| """Fetch the table of a feature. | |
| Uses `DuckDBData.fetch_key`.""" | |
| return cls.fetch_key(feature, **kwargs) | |
| @classmethod | |
| def fetch_symbol(cls, symbol: str, **kwargs) -> tp.SymbolData: | |
| """Fetch the table for a symbol. | |
| Uses `DuckDBData.fetch_key`.""" | |
| return cls.fetch_key(symbol, **kwargs) | |
| def update_key(self, key: str, from_last_index: tp.Optional[bool] = None, **kwargs) -> tp.KeyData: | |
| """Update data of a feature or symbol.""" | |
| fetch_kwargs = self.select_fetch_kwargs(key) | |
| pre_kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| if from_last_index is None: | |
| if pre_kwargs.get("query", None) is not None: | |
| from_last_index = False | |
| else: | |
| from_last_index = True | |
| if from_last_index: | |
| fetch_kwargs["start"] = self.select_last_index(key) | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| if self.feature_oriented: | |
| return self.fetch_feature(key, **kwargs) | |
| return self.fetch_symbol(key, **kwargs) | |
| def update_feature(self, feature: str, **kwargs) -> tp.FeatureData: | |
| """Update data of a feature. | |
| Uses `DuckDBData.update_key`.""" | |
| return self.update_key(feature, **kwargs) | |
| def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData: | |
| """Update data for a symbol. | |
| Uses `DuckDBData.update_key`.""" | |
| return self.update_key(symbol, **kwargs) | |
| </file> | |
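| The `fetch_key` logic above turns `start` and `end` into a parameterized `WHERE` clause on the resolved index column, passing the values via the `$start`/`$end` placeholders. A minimal usage sketch, assuming a hypothetical database file "data.duckdb" with an "ohlcv" table whose first column is a timestamp; passing the file path via `connection` is an assumption, since connection resolution happens earlier in this file: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> data = vbt.DuckDBData.pull( | |
| ...     "ohlcv",                   # table name used as the symbol (hypothetical) | |
| ...     connection="data.duckdb",  # resolved to a DuckDB connection (assumption) | |
| ...     index_col=0,               # first column becomes the index | |
| ...     start="2024-01-01",        # -> WHERE "<index>" >= $start | |
| ...     end="2024-02-01",          # -> AND "<index>" < $end | |
| ... ) | |
| ``` | |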
| <file path="data/custom/feather.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `FeatherData`.""" | |
| from pathlib import Path | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.file import FileData | |
| from vectorbtpro.utils import checks | |
| from vectorbtpro.utils.config import merge_dicts | |
| __all__ = [ | |
| "FeatherData", | |
| ] | |
| __pdoc__ = {} | |
| FeatherDataT = tp.TypeVar("FeatherDataT", bound="FeatherData") | |
| class FeatherData(FileData): | |
| """Data class for fetching Feather data using PyArrow.""" | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.feather") | |
| @classmethod | |
| def list_paths(cls, path: tp.PathLike = ".", **match_path_kwargs) -> tp.List[Path]: | |
| if not isinstance(path, Path): | |
| path = Path(path) | |
| if path.exists() and path.is_dir(): | |
| path = path / "*.feather" | |
| return cls.match_path(path, **match_path_kwargs) | |
| @classmethod | |
| def resolve_keys_meta( | |
| cls, | |
| keys: tp.Union[None, dict, tp.MaybeKeys] = None, | |
| keys_are_features: tp.Optional[bool] = None, | |
| features: tp.Union[None, dict, tp.MaybeFeatures] = None, | |
| symbols: tp.Union[None, dict, tp.MaybeSymbols] = None, | |
| paths: tp.Any = None, | |
| ) -> tp.Kwargs: | |
| keys_meta = FileData.resolve_keys_meta( | |
| keys=keys, | |
| keys_are_features=keys_are_features, | |
| features=features, | |
| symbols=symbols, | |
| ) | |
| if keys_meta["keys"] is None and paths is None: | |
| keys_meta["keys"] = "*.feather" | |
| return keys_meta | |
| @classmethod | |
| def fetch_key( | |
| cls, | |
| key: tp.Key, | |
| path: tp.Any = None, | |
| tz: tp.TimezoneLike = None, | |
| index_col: tp.Optional[tp.MaybeSequence[tp.IntStr]] = None, | |
| squeeze: tp.Optional[bool] = None, | |
| **read_kwargs, | |
| ) -> tp.KeyData: | |
| """Fetch the Feather file of a feature or symbol. | |
| Args: | |
| key (hashable): Feature or symbol. | |
| path (str): Path. | |
| If `path` is None, uses `key` as the path to the Feather file. | |
| tz (any): Target timezone. | |
| See `vectorbtpro.utils.datetime_.to_timezone`. | |
| index_col (int, str, or sequence): Position(s) or name(s) of column(s) that should become the index. | |
| Will only apply if the fetched object has a default index. | |
| squeeze (bool): Whether to squeeze a DataFrame with one column into a Series. | |
| **read_kwargs: Other keyword arguments passed to `pd.read_feather`. | |
| See https://pandas.pydata.org/docs/reference/api/pandas.read_feather.html for other arguments. | |
| For defaults, see `custom.feather` in `vectorbtpro._settings.data`.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("pyarrow") | |
| tz = cls.resolve_custom_setting(tz, "tz") | |
| index_col = cls.resolve_custom_setting(index_col, "index_col") | |
| if index_col is False: | |
| index_col = None | |
| squeeze = cls.resolve_custom_setting(squeeze, "squeeze") | |
| read_kwargs = cls.resolve_custom_setting(read_kwargs, "read_kwargs", merge=True) | |
| if path is None: | |
| path = key | |
| obj = pd.read_feather(path, **read_kwargs) | |
| if isinstance(obj, pd.DataFrame) and checks.is_default_index(obj.index): | |
| if index_col is not None: | |
| if checks.is_int(index_col): | |
| keys = obj.columns[index_col] | |
| elif isinstance(index_col, str): | |
| keys = index_col | |
| else: | |
| keys = [] | |
| for col in index_col: | |
| if checks.is_int(col): | |
| keys.append(obj.columns[col]) | |
| else: | |
| keys.append(col) | |
| obj = obj.set_index(keys) | |
| if not isinstance(obj.index, pd.MultiIndex): | |
| if obj.index.name == "index": | |
| obj.index.name = None | |
| if isinstance(obj.index, pd.DatetimeIndex) and tz is None: | |
| tz = obj.index.tz | |
| if isinstance(obj, pd.DataFrame) and squeeze: | |
| obj = obj.squeeze("columns") | |
| if isinstance(obj, pd.Series) and obj.name == "0": | |
| obj.name = None | |
| return obj, dict(tz=tz) | |
| @classmethod | |
| def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData: | |
| """Fetch the Feather file of a feature. | |
| Uses `FeatherData.fetch_key`.""" | |
| return cls.fetch_key(feature, **kwargs) | |
| @classmethod | |
| def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData: | |
| """Fetch the Feather file of a symbol. | |
| Uses `FeatherData.fetch_key`.""" | |
| return cls.fetch_key(symbol, **kwargs) | |
| def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData: | |
| """Update data of a feature or symbol.""" | |
| fetch_kwargs = self.select_fetch_kwargs(key) | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| if key_is_feature: | |
| return self.fetch_feature(key, **kwargs) | |
| return self.fetch_symbol(key, **kwargs) | |
| def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData: | |
| """Update data of a feature. | |
| Uses `FeatherData.update_key` with `key_is_feature=True`.""" | |
| return self.update_key(feature, key_is_feature=True, **kwargs) | |
| def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData: | |
| """Update data for a symbol. | |
| Uses `FeatherData.update_key` with `key_is_feature=False`.""" | |
| return self.update_key(symbol, key_is_feature=False, **kwargs) | |
| </file> | |
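| `FeatherData` simply globs for "*.feather" files (see `resolve_keys_meta` and `list_paths` above) and reads each match with `pd.read_feather`. A minimal sketch, assuming a hypothetical file "BTC-USD.feather" in the working directory: | |
| ```pycon | |
| >>> # Explicit path: the file stem "BTC-USD" becomes the symbol | |
| >>> data = vbt.FeatherData.pull("BTC-USD.feather", index_col=0) | |
| >>> # No arguments: every *.feather file in the current directory is pulled | |
| >>> data = vbt.FeatherData.pull() | |
| ``` | |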
| <file path="data/custom/file.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `FileData`.""" | |
| import re | |
| from glob import glob | |
| from pathlib import Path | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.base import key_dict | |
| from vectorbtpro.data.custom.local import LocalData | |
| from vectorbtpro.utils import checks | |
| __all__ = [ | |
| "FileData", | |
| ] | |
| __pdoc__ = {} | |
| FileDataT = tp.TypeVar("FileDataT", bound="FileData") | |
| class FileData(LocalData): | |
| """Data class for fetching file data.""" | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.file") | |
| @classmethod | |
| def is_dir_match(cls, path: tp.PathLike) -> bool: | |
| """Return whether a directory is a valid match.""" | |
| return False | |
| @classmethod | |
| def is_file_match(cls, path: tp.PathLike) -> bool: | |
| """Return whether a file is a valid match.""" | |
| return True | |
| @classmethod | |
| def match_path( | |
| cls, | |
| path: tp.PathLike, | |
| match_regex: tp.Optional[str] = None, | |
| sort_paths: bool = True, | |
| recursive: bool = True, | |
| extension: tp.Optional[str] = None, | |
| **kwargs, | |
| ) -> tp.List[Path]: | |
| """Get the list of all paths matching a path. | |
| If `FileData.is_dir_match` returns True for a directory, it gets returned as-is. | |
| Otherwise, iterates through all files in that directory and invokes `FileData.is_file_match`. | |
| If a pattern was provided, these methods aren't invoked.""" | |
| if not isinstance(path, Path): | |
| path = Path(path) | |
| if path.exists(): | |
| if path.is_dir() and not cls.is_dir_match(path): | |
| sub_paths = [] | |
| for p in path.iterdir(): | |
| if p.is_dir() and cls.is_dir_match(p): | |
| sub_paths.append(p) | |
| if p.is_file() and cls.is_file_match(p): | |
| if extension is None or p.suffix == "." + extension: | |
| sub_paths.append(p) | |
| else: | |
| sub_paths = [path] | |
| else: | |
| sub_paths = list([Path(p) for p in glob(str(path), recursive=recursive)]) | |
| if match_regex is not None: | |
| sub_paths = [p for p in sub_paths if re.match(match_regex, str(p))] | |
| if sort_paths: | |
| sub_paths = sorted(sub_paths) | |
| return sub_paths | |
| @classmethod | |
| def list_paths(cls, path: tp.PathLike = ".", **match_path_kwargs) -> tp.List[Path]: | |
| """List all features or symbols under a path.""" | |
| return cls.match_path(path, **match_path_kwargs) | |
| @classmethod | |
| def path_to_key(cls, path: tp.PathLike, **kwargs) -> str: | |
| """Convert a path into a feature or symbol.""" | |
| return Path(path).stem | |
| @classmethod | |
| def resolve_keys_meta( | |
| cls, | |
| keys: tp.Union[None, dict, tp.MaybeKeys] = None, | |
| keys_are_features: tp.Optional[bool] = None, | |
| features: tp.Union[None, dict, tp.MaybeFeatures] = None, | |
| symbols: tp.Union[None, dict, tp.MaybeSymbols] = None, | |
| paths: tp.Any = None, | |
| ) -> tp.Kwargs: | |
| return LocalData.resolve_keys_meta( | |
| keys=keys, | |
| keys_are_features=keys_are_features, | |
| features=features, | |
| symbols=symbols, | |
| ) | |
| @classmethod | |
| def pull( | |
| cls: tp.Type[FileDataT], | |
| keys: tp.Union[tp.MaybeKeys] = None, | |
| *, | |
| keys_are_features: tp.Optional[bool] = None, | |
| features: tp.Union[tp.MaybeFeatures] = None, | |
| symbols: tp.Union[tp.MaybeSymbols] = None, | |
| paths: tp.Any = None, | |
| match_paths: tp.Optional[bool] = None, | |
| match_regex: tp.Optional[str] = None, | |
| sort_paths: tp.Optional[bool] = None, | |
| match_path_kwargs: tp.KwargsLike = None, | |
| path_to_key_kwargs: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> FileDataT: | |
| """Override `vectorbtpro.data.base.Data.pull` to take care of paths. | |
| Use either features, symbols, or `paths` to specify one or multiple file paths. | |
| Paths can be provided as strings, `pathlib.Path` objects, or glob expressions accepted by `glob.glob`. | |
| Set `match_paths` to False to not parse paths and behave like a regular | |
| `vectorbtpro.data.base.Data` instance. | |
| For defaults, see `custom.local` in `vectorbtpro._settings.data`. | |
| """ | |
| keys_meta = cls.resolve_keys_meta( | |
| keys=keys, | |
| keys_are_features=keys_are_features, | |
| features=features, | |
| symbols=symbols, | |
| paths=paths, | |
| ) | |
| keys = keys_meta["keys"] | |
| keys_are_features = keys_meta["keys_are_features"] | |
| dict_type = keys_meta["dict_type"] | |
| match_paths = cls.resolve_custom_setting(match_paths, "match_paths") | |
| match_regex = cls.resolve_custom_setting(match_regex, "match_regex") | |
| sort_paths = cls.resolve_custom_setting(sort_paths, "sort_paths") | |
| if match_paths: | |
| sync = False | |
| if paths is None: | |
| paths = keys | |
| sync = True | |
| elif keys is None: | |
| sync = True | |
| if paths is None: | |
| if keys_are_features: | |
| raise ValueError("At least features or paths must be set") | |
| else: | |
| raise ValueError("At least symbols or paths must be set") | |
| if match_path_kwargs is None: | |
| match_path_kwargs = {} | |
| if path_to_key_kwargs is None: | |
| path_to_key_kwargs = {} | |
| single_key = False | |
| if isinstance(keys, (str, Path)): | |
| # Single key | |
| keys = [keys] | |
| single_key = True | |
| single_path = False | |
| if isinstance(paths, (str, Path)): | |
| # Single path | |
| paths = [paths] | |
| single_path = True | |
| if sync: | |
| single_key = True | |
| cls.check_dict_type(paths, "paths", dict_type=dict_type) | |
| if isinstance(paths, key_dict): | |
| # Dict of path per key | |
| if sync: | |
| keys = list(paths.keys()) | |
| elif len(keys) != len(paths): | |
| if keys_are_features: | |
| raise ValueError("The number of features must be equal to the number of matched paths") | |
| else: | |
| raise ValueError("The number of symbols must be equal to the number of matched paths") | |
| elif checks.is_iterable(paths) or checks.is_sequence(paths): | |
| # Multiple paths | |
| matched_paths = [ | |
| p | |
| for sub_path in paths | |
| for p in cls.match_path( | |
| sub_path, | |
| match_regex=match_regex, | |
| sort_paths=sort_paths, | |
| **match_path_kwargs, | |
| ) | |
| ] | |
| if len(matched_paths) == 0: | |
| raise FileNotFoundError(f"No paths could be matched with {paths}") | |
| if sync: | |
| keys = [] | |
| paths = key_dict() | |
| for p in matched_paths: | |
| s = cls.path_to_key(p, **path_to_key_kwargs) | |
| keys.append(s) | |
| paths[s] = p | |
| elif len(keys) != len(matched_paths): | |
| if keys_are_features: | |
| raise ValueError("The number of features must be equal to the number of matched paths") | |
| else: | |
| raise ValueError("The number of symbols must be equal to the number of matched paths") | |
| else: | |
| paths = key_dict({s: matched_paths[i] for i, s in enumerate(keys)}) | |
| if len(matched_paths) == 1 and single_path: | |
| paths = matched_paths[0] | |
| else: | |
| raise TypeError(f"Path '{paths}' is not supported") | |
| if len(keys) == 1 and single_key: | |
| keys = keys[0] | |
| return super(FileData, cls).pull( | |
| keys, | |
| keys_are_features=keys_are_features, | |
| path=paths, | |
| **kwargs, | |
| ) | |
| </file> | |
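| `FileData.pull` first matches paths and then derives features/symbols from them via `path_to_key` (the file stem by default). A minimal sketch, assuming the `CSVData` subclass and a hypothetical "csv_data" folder; the resulting symbols are illustrative: | |
| ```pycon | |
| >>> # Hypothetical layout: csv_data/BTC-USD.csv and csv_data/ETH-USD.csv | |
| >>> data = vbt.CSVData.pull(paths="csv_data/*.csv") | |
| >>> data.symbols  # derived from the matched paths via path_to_key | |
| ['BTC-USD', 'ETH-USD'] | |
| ``` | |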
| <file path="data/custom/finpy.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `FinPyData`.""" | |
| from itertools import product | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.remote import RemoteData | |
| from vectorbtpro.utils import datetime_ as dt | |
| from vectorbtpro.utils.config import merge_dicts | |
| try: | |
| if not tp.TYPE_CHECKING: | |
| raise ImportError | |
| from findatapy.market import Market as MarketT | |
| from findatapy.util import ConfigManager as ConfigManagerT | |
| except ImportError: | |
| MarketT = "Market" | |
| ConfigManagerT = "ConfigManager" | |
| __all__ = [ | |
| "FinPyData", | |
| ] | |
| FinPyDataT = tp.TypeVar("FinPyDataT", bound="FinPyData") | |
| class FinPyData(RemoteData): | |
| """Data class for fetching using findatapy. | |
| See https://github.com/cuemacro/findatapy for API. | |
| See `FinPyData.fetch_symbol` for arguments. | |
| Usage: | |
| * Pull data (keyword argument format): | |
| ```pycon | |
| >>> data = vbt.FinPyData.pull( | |
| ... "EURUSD", | |
| ... start="14 June 2016", | |
| ... end="15 June 2016", | |
| ... timeframe="tick", | |
| ... category="fx", | |
| ... fields=["bid", "ask"], | |
| ... data_source="dukascopy" | |
| ... ) | |
| ``` | |
| * Pull data (string format): | |
| ```pycon | |
| >>> data = vbt.FinPyData.pull( | |
| ... "fx.dukascopy.tick.NYC.EURUSD.bid,ask", | |
| ... start="14 June 2016", | |
| ... end="15 June 2016", | |
| ... ) | |
| ``` | |
| """ | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.finpy") | |
| @classmethod | |
| def resolve_market( | |
| cls, | |
| market: tp.Optional[MarketT] = None, | |
| **market_config, | |
| ) -> MarketT: | |
| """Resolve the market. | |
| If provided, must be of the type `findatapy.market.market.Market`.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("findatapy") | |
| from findatapy.market import Market, MarketDataGenerator | |
| market = cls.resolve_custom_setting(market, "market") | |
| if market_config is None: | |
| market_config = {} | |
| has_market_config = len(market_config) > 0 | |
| market_config = cls.resolve_custom_setting(market_config, "market_config", merge=True) | |
| if "market_data_generator" not in market_config: | |
| market_config["market_data_generator"] = MarketDataGenerator() | |
| if market is None: | |
| market = Market(**market_config) | |
| elif has_market_config: | |
| raise ValueError("Cannot apply market_config to already initialized market") | |
| return market | |
| @classmethod | |
| def resolve_config_manager( | |
| cls, | |
| config_manager: tp.Optional[ConfigManagerT] = None, | |
| **config_manager_config, | |
| ) -> ConfigManagerT: | |
| """Resolve the config manager. | |
| If provided, must be of the type `findatapy.util.ConfigManager`.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("findatapy") | |
| from findatapy.util import ConfigManager | |
| config_manager = cls.resolve_custom_setting(config_manager, "config_manager") | |
| if config_manager_config is None: | |
| config_manager_config = {} | |
| has_config_manager_config = len(config_manager_config) > 0 | |
| config_manager_config = cls.resolve_custom_setting(config_manager_config, "config_manager_config", merge=True) | |
| if config_manager is None: | |
| config_manager = ConfigManager().get_instance(**config_manager_config) | |
| elif has_config_manager_config: | |
| raise ValueError("Cannot apply config_manager_config to already initialized config_manager") | |
| return config_manager | |
| @classmethod | |
| def list_symbols( | |
| cls, | |
| pattern: tp.Optional[str] = None, | |
| use_regex: bool = False, | |
| sort: bool = True, | |
| config_manager: tp.Optional[ConfigManagerT] = None, | |
| config_manager_config: tp.KwargsLike = None, | |
| category: tp.Optional[tp.MaybeList[str]] = None, | |
| data_source: tp.Optional[tp.MaybeList[str]] = None, | |
| freq: tp.Optional[tp.MaybeList[str]] = None, | |
| cut: tp.Optional[tp.MaybeList[str]] = None, | |
| tickers: tp.Optional[tp.MaybeList[str]] = None, | |
| dict_filter: tp.DictLike = None, | |
| smart_group: bool = False, | |
| return_fields: tp.Optional[tp.MaybeList[str]] = None, | |
| combine_parts: bool = True, | |
| ) -> tp.List[str]: | |
| """List all symbols. | |
| Passes most arguments to `findatapy.util.ConfigManager.free_form_tickers_regex_query`. | |
| Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`.""" | |
| if config_manager_config is None: | |
| config_manager_config = {} | |
| config_manager = cls.resolve_config_manager(config_manager=config_manager, **config_manager_config) | |
| if dict_filter is None: | |
| dict_filter = {} | |
| def_ret_fields = ["category", "data_source", "freq", "cut", "tickers"] | |
| if return_fields is None: | |
| ret_fields = def_ret_fields | |
| elif isinstance(return_fields, str): | |
| if return_fields.lower() == "all": | |
| ret_fields = def_ret_fields + ["fields"] | |
| else: | |
| ret_fields = [return_fields] | |
| else: | |
| ret_fields = return_fields | |
| df = config_manager.free_form_tickers_regex_query( | |
| category=category, | |
| data_source=data_source, | |
| freq=freq, | |
| cut=cut, | |
| tickers=tickers, | |
| dict_filter=dict_filter, | |
| smart_group=smart_group, | |
| ret_fields=ret_fields, | |
| ) | |
| all_symbols = [] | |
| for _, row in df.iterrows(): | |
| parts = [] | |
| if "category" in row.index: | |
| parts.append(row.loc["category"]) | |
| if "data_source" in row.index: | |
| parts.append(row.loc["data_source"]) | |
| if "freq" in row.index: | |
| parts.append(row.loc["freq"]) | |
| if "cut" in row.index: | |
| parts.append(row.loc["cut"]) | |
| if "tickers" in row.index: | |
| parts.append(row.loc["tickers"]) | |
| if "fields" in row.index: | |
| parts.append(row.loc["fields"]) | |
| if combine_parts: | |
| split_parts = [part.split(",") for part in parts] | |
| combinations = list(product(*split_parts)) | |
| else: | |
| combinations = [parts] | |
| for symbol in [".".join(combination) for combination in combinations]: | |
| if pattern is not None: | |
| if not cls.key_match(symbol, pattern, use_regex=use_regex): | |
| continue | |
| all_symbols.append(symbol) | |
| if sort: | |
| return sorted(dict.fromkeys(all_symbols)) | |
| return list(dict.fromkeys(all_symbols)) | |
| @classmethod | |
| def fetch_symbol( | |
| cls, | |
| symbol: str, | |
| market: tp.Optional[MarketT] = None, | |
| market_config: tp.KwargsLike = None, | |
| start: tp.Optional[tp.DatetimeLike] = None, | |
| end: tp.Optional[tp.DatetimeLike] = None, | |
| timeframe: tp.Optional[str] = None, | |
| tz: tp.TimezoneLike = None, | |
| **request_kwargs, | |
| ) -> tp.SymbolData: | |
| """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from findatapy. | |
| Args: | |
| symbol (str): Symbol. | |
| Also accepts a full request string such as "fx.bloomberg.daily.NYC.EURUSD.close". | |
| The fields `freq`, `cut`, `tickers`, and `fields` here are optional. | |
| market (findatapy.market.market.Market): Market. | |
| See `FinPyData.resolve_market`. | |
| market_config (dict): Client config. | |
| See `FinPyData.resolve_market`. | |
| start (any): Start datetime. | |
| See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. | |
| end (any): End datetime. | |
| See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. | |
| timeframe (str): Timeframe. | |
| Allows human-readable strings such as "15 minutes". | |
| tz (any): Timezone. | |
| See `vectorbtpro.utils.datetime_.to_timezone`. | |
| **request_kwargs: Other keyword arguments passed to `findatapy.market.marketdatarequest.MarketDataRequest`. | |
| For defaults, see `custom.finpy` in `vectorbtpro._settings.data`. | |
| Global settings can be provided per exchange id using the `exchanges` dictionary. | |
| """ | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("findatapy") | |
| from findatapy.market import MarketDataRequest | |
| if market_config is None: | |
| market_config = {} | |
| market = cls.resolve_market(market=market, **market_config) | |
| start = cls.resolve_custom_setting(start, "start") | |
| end = cls.resolve_custom_setting(end, "end") | |
| timeframe = cls.resolve_custom_setting(timeframe, "timeframe") | |
| tz = cls.resolve_custom_setting(tz, "tz") | |
| request_kwargs = cls.resolve_custom_setting(request_kwargs, "request_kwargs", merge=True) | |
| split = dt.split_freq_str(timeframe) | |
| if split is None: | |
| raise ValueError(f"Invalid timeframe: '{timeframe}'") | |
| multiplier, unit = split | |
| if unit == "s": | |
| unit = "second" | |
| freq = timeframe | |
| elif unit == "m": | |
| unit = "minute" | |
| freq = timeframe | |
| elif unit == "h": | |
| unit = "hourly" | |
| freq = timeframe | |
| elif unit == "D": | |
| unit = "daily" | |
| freq = timeframe | |
| elif unit == "W": | |
| unit = "weekly" | |
| freq = timeframe | |
| elif unit == "M": | |
| unit = "monthly" | |
| freq = timeframe | |
| elif unit == "Q": | |
| unit = "quarterly" | |
| freq = timeframe | |
| elif unit == "Y": | |
| unit = "annually" | |
| freq = timeframe | |
| else: | |
| freq = None | |
| if "resample" in request_kwargs: | |
| freq = request_kwargs["resample"] | |
| if start is not None: | |
| start = dt.to_naive_datetime(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")) | |
| if end is not None: | |
| end = dt.to_naive_datetime(dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc")) | |
| if "md_request" in request_kwargs: | |
| md_request = request_kwargs["md_request"] | |
| elif "md_request_df" in request_kwargs: | |
| md_request = market.create_md_request_from_dataframe( | |
| md_request_df=request_kwargs["md_request_df"], | |
| start_date=start, | |
| finish_date=end, | |
| freq_mult=multiplier, | |
| freq=unit, | |
| **request_kwargs, | |
| ) | |
| elif "md_request_str" in request_kwargs: | |
| md_request = market.create_md_request_from_str( | |
| md_request_str=request_kwargs["md_request_str"], | |
| start_date=start, | |
| finish_date=end, | |
| freq_mult=multiplier, | |
| freq=unit, | |
| **request_kwargs, | |
| ) | |
| elif "md_request_dict" in request_kwargs: | |
| md_request = market.create_md_request_from_dict( | |
| md_request_dict=request_kwargs["md_request_dict"], | |
| start_date=start, | |
| finish_date=end, | |
| freq_mult=multiplier, | |
| freq=unit, | |
| **request_kwargs, | |
| ) | |
| elif symbol.count(".") >= 2: | |
| md_request = market.create_md_request_from_str( | |
| md_request_str=symbol, | |
| start_date=start, | |
| finish_date=end, | |
| freq_mult=multiplier, | |
| freq=unit, | |
| **request_kwargs, | |
| ) | |
| else: | |
| md_request = MarketDataRequest( | |
| tickers=symbol, | |
| start_date=start, | |
| finish_date=end, | |
| freq_mult=multiplier, | |
| freq=unit, | |
| **request_kwargs, | |
| ) | |
| df = market.fetch_market(md_request=md_request) | |
| if df is None: | |
| return None | |
| if isinstance(md_request.tickers, str): | |
| ticker = md_request.tickers | |
| elif len(md_request.tickers) == 1: | |
| ticker = md_request.tickers[0] | |
| else: | |
| ticker = None | |
| if ticker is not None: | |
| df.columns = df.columns.map(lambda x: x.replace(ticker + ".", "")) | |
| if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None: | |
| df = df.tz_localize("utc") | |
| return df, dict(tz=tz, freq=freq) | |
| def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData: | |
| fetch_kwargs = self.select_fetch_kwargs(symbol) | |
| fetch_kwargs["start"] = self.select_last_index(symbol) | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| return self.fetch_symbol(symbol, **kwargs) | |
| </file> | |
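| `fetch_symbol` above splits a human-readable `timeframe` into `freq_mult` and `freq` before building the `MarketDataRequest`, and `update_symbol` re-fetches from the last stored index. A minimal sketch extending the hypothetical Dukascopy request from the class docstring: | |
| ```pycon | |
| >>> data = vbt.FinPyData.pull( | |
| ...     "EURUSD", | |
| ...     category="fx", | |
| ...     data_source="dukascopy", | |
| ...     fields=["bid", "ask"], | |
| ...     timeframe="1 minute",  # split into freq_mult=1 and freq="minute" | |
| ...     start="14 June 2016", | |
| ...     end="15 June 2016", | |
| ... ) | |
| >>> data = data.update(end="16 June 2016")  # re-fetches from the last index | |
| ``` | |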
| <file path="data/custom/gbm_ohlc.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `GBMOHLCData`.""" | |
| import numpy as np | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.base.reshaping import broadcast_array_to | |
| from vectorbtpro.data import nb | |
| from vectorbtpro.data.custom.synthetic import SyntheticData | |
| from vectorbtpro.ohlcv import nb as ohlcv_nb | |
| from vectorbtpro.registries.jit_registry import jit_reg | |
| from vectorbtpro.utils.config import merge_dicts | |
| from vectorbtpro.utils.random_ import set_seed | |
| from vectorbtpro.utils.template import substitute_templates | |
| __all__ = [ | |
| "GBMOHLCData", | |
| ] | |
| __pdoc__ = {} | |
| class GBMOHLCData(SyntheticData): | |
| """`SyntheticData` for data generated using `vectorbtpro.data.nb.generate_gbm_data_1d_nb` | |
| and then resampled using `vectorbtpro.ohlcv.nb.ohlc_every_1d_nb`.""" | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.gbm_ohlc") | |
| @classmethod | |
| def generate_symbol( | |
| cls, | |
| symbol: tp.Symbol, | |
| index: tp.Index, | |
| n_ticks: tp.Optional[tp.ArrayLike] = None, | |
| start_value: tp.Optional[float] = None, | |
| mean: tp.Optional[float] = None, | |
| std: tp.Optional[float] = None, | |
| dt: tp.Optional[float] = None, | |
| seed: tp.Optional[int] = None, | |
| jitted: tp.JittedOption = None, | |
| template_context: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> tp.SymbolData: | |
| """Generate a symbol. | |
| Args: | |
| symbol (hashable): Symbol. | |
| index (pd.Index): Pandas index. | |
| n_ticks (int or array_like): Number of ticks per bar. | |
| Flexible argument. Can be a template with a context containing `symbol` and `index`. | |
| start_value (float): Value at time 0. | |
| Does not appear as the first value in the output data. | |
| mean (float): Drift, or mean of the percentage change. | |
| std (float): Standard deviation of the percentage change. | |
| dt (float): Time change (one period of time). | |
| seed (int): Seed to make output deterministic. | |
| jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`. | |
| template_context (dict): Context used to substitute templates. | |
| For defaults, see `custom.gbm_ohlc` in `vectorbtpro._settings.data`. | |
| !!! note | |
| When setting a seed, remember to pass a seed per symbol using `vectorbtpro.data.base.symbol_dict`. | |
| """ | |
| n_ticks = cls.resolve_custom_setting(n_ticks, "n_ticks") | |
| template_context = merge_dicts(dict(symbol=symbol, index=index), template_context) | |
| n_ticks = substitute_templates(n_ticks, template_context, eval_id="n_ticks") | |
| n_ticks = broadcast_array_to(n_ticks, len(index)) | |
| start_value = cls.resolve_custom_setting(start_value, "start_value") | |
| mean = cls.resolve_custom_setting(mean, "mean") | |
| std = cls.resolve_custom_setting(std, "std") | |
| dt = cls.resolve_custom_setting(dt, "dt") | |
| seed = cls.resolve_custom_setting(seed, "seed") | |
| if seed is not None: | |
| set_seed(seed) | |
| func = jit_reg.resolve_option(nb.generate_gbm_data_1d_nb, jitted) | |
| ticks = func( | |
| np.sum(n_ticks), | |
| start_value=start_value, | |
| mean=mean, | |
| std=std, | |
| dt=dt, | |
| ) | |
| func = jit_reg.resolve_option(ohlcv_nb.ohlc_every_1d_nb, jitted) | |
| out = func(ticks, n_ticks) | |
| return pd.DataFrame(out, index=index, columns=["Open", "High", "Low", "Close"]) | |
| def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData: | |
| fetch_kwargs = self.select_fetch_kwargs(symbol) | |
| fetch_kwargs["start"] = self.select_last_index(symbol) | |
| _ = fetch_kwargs.pop("start_value", None) | |
| start_value = self.data[symbol]["Open"].iloc[-1] | |
| fetch_kwargs["seed"] = None | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| return self.fetch_symbol(symbol, start_value=start_value, **kwargs) | |
| </file> | |
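| `generate_symbol` above draws `sum(n_ticks)` GBM ticks and compresses every `n_ticks` of them into one OHLC bar. A minimal sketch; the index arguments (`start`, `end`, `timeframe`) are assumed to be handled by `SyntheticData`, which is outside this excerpt: | |
| ```pycon | |
| >>> data = vbt.GBMOHLCData.pull( | |
| ...     "GBM", | |
| ...     start="2024-01-01",  # index construction by SyntheticData (assumption) | |
| ...     end="2024-02-01", | |
| ...     timeframe="1d", | |
| ...     n_ticks=10,          # 10 simulated ticks are aggregated into each bar | |
| ...     seed=42, | |
| ... ) | |
| >>> data.get("Close") | |
| ``` | |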
| <file path="data/custom/gbm.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `GBMData`.""" | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.base.reshaping import to_1d_array | |
| from vectorbtpro.data import nb | |
| from vectorbtpro.data.custom.synthetic import SyntheticData | |
| from vectorbtpro.registries.jit_registry import jit_reg | |
| from vectorbtpro.utils import checks | |
| from vectorbtpro.utils.config import merge_dicts | |
| from vectorbtpro.utils.random_ import set_seed | |
| __all__ = [ | |
| "GBMData", | |
| ] | |
| __pdoc__ = {} | |
| class GBMData(SyntheticData): | |
| """`SyntheticData` for data generated using `vectorbtpro.data.nb.generate_gbm_data_nb`.""" | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.gbm") | |
| @classmethod | |
| def generate_key( | |
| cls, | |
| key: tp.Key, | |
| index: tp.Index, | |
| columns: tp.Union[tp.Hashable, tp.IndexLike] = None, | |
| start_value: tp.Optional[float] = None, | |
| mean: tp.Optional[float] = None, | |
| std: tp.Optional[float] = None, | |
| dt: tp.Optional[float] = None, | |
| seed: tp.Optional[int] = None, | |
| jitted: tp.JittedOption = None, | |
| **kwargs, | |
| ) -> tp.KeyData: | |
| """Generate a feature or symbol. | |
| Args: | |
| key (hashable): Feature or symbol. | |
| index (pd.Index): Pandas index. | |
| columns (hashable or index_like): Column names. | |
| Provide a single value (hashable) to make a Series. | |
| start_value (float): Value at time 0. | |
| Does not appear as the first value in the output data. | |
| mean (float): Drift, or mean of the percentage change. | |
| std (float): Standard deviation of the percentage change. | |
| dt (float): Time change (one period of time). | |
| seed (int): Seed to make output deterministic. | |
| jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`. | |
| For defaults, see `custom.gbm` in `vectorbtpro._settings.data`. | |
| !!! note | |
| When setting a seed, remember to pass a seed per feature/symbol using | |
| `vectorbtpro.data.base.feature_dict`/`vectorbtpro.data.base.symbol_dict` or generally | |
| `vectorbtpro.data.base.key_dict`. | |
| """ | |
| if checks.is_hashable(columns): | |
| columns = [columns] | |
| make_series = True | |
| else: | |
| make_series = False | |
| if not isinstance(columns, pd.Index): | |
| columns = pd.Index(columns) | |
| start_value = cls.resolve_custom_setting(start_value, "start_value") | |
| mean = cls.resolve_custom_setting(mean, "mean") | |
| std = cls.resolve_custom_setting(std, "std") | |
| dt = cls.resolve_custom_setting(dt, "dt") | |
| seed = cls.resolve_custom_setting(seed, "seed") | |
| if seed is not None: | |
| set_seed(seed) | |
| func = jit_reg.resolve_option(nb.generate_gbm_data_nb, jitted) | |
| out = func( | |
| (len(index), len(columns)), | |
| start_value=to_1d_array(start_value), | |
| mean=to_1d_array(mean), | |
| std=to_1d_array(std), | |
| dt=to_1d_array(dt), | |
| ) | |
| if make_series: | |
| return pd.Series(out[:, 0], index=index, name=columns[0]) | |
| return pd.DataFrame(out, index=index, columns=columns) | |
| def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData: | |
| fetch_kwargs = self.select_fetch_kwargs(key) | |
| fetch_kwargs["start"] = self.select_last_index(key) | |
| _ = fetch_kwargs.pop("start_value", None) | |
| start_value = self.data[key].iloc[-2] | |
| fetch_kwargs["seed"] = None | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| if key_is_feature: | |
| return self.fetch_feature(key, start_value=start_value, **kwargs) | |
| return self.fetch_symbol(key, start_value=start_value, **kwargs) | |
| </file> | |
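| `generate_key` above produces one independent GBM path per column. A minimal sketch, with the index arguments again assumed to be resolved by `SyntheticData`: | |
| ```pycon | |
| >>> data = vbt.GBMData.pull( | |
| ...     "GBM", | |
| ...     start="2024-01-01", | |
| ...     end="2024-07-01", | |
| ...     columns=["path_0", "path_1", "path_2"],  # one GBM path per column | |
| ...     start_value=100.0, | |
| ...     mean=0.0, | |
| ...     std=0.01, | |
| ...     seed=42, | |
| ... ) | |
| ``` | |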
| <file path="data/custom/hdf.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `HDFData`.""" | |
| import re | |
| from glob import glob | |
| from pathlib import Path, PurePath | |
| import numpy as np | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.file import FileData | |
| from vectorbtpro.utils import datetime_ as dt | |
| from vectorbtpro.utils.config import merge_dicts | |
| from vectorbtpro.utils.parsing import get_func_arg_names | |
| __all__ = [ | |
| "HDFData", | |
| ] | |
| __pdoc__ = {} | |
| class HDFPathNotFoundError(Exception): | |
| """Gets raised if the path to an HDF file could not be found.""" | |
| pass | |
| class HDFKeyNotFoundError(Exception): | |
| """Gets raised if the key to an HDF object could not be found.""" | |
| pass | |
| HDFDataT = tp.TypeVar("HDFDataT", bound="HDFData") | |
| class HDFData(FileData): | |
| """Data class for fetching HDF data using PyTables.""" | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.hdf") | |
| @classmethod | |
| def is_hdf_file(cls, path: tp.PathLike) -> bool: | |
| """Return whether the path is an HDF file.""" | |
| if not isinstance(path, Path): | |
| path = Path(path) | |
| if path.exists() and path.is_file() and ".hdf" in path.suffixes: | |
| return True | |
| if path.exists() and path.is_file() and ".hdf5" in path.suffixes: | |
| return True | |
| if path.exists() and path.is_file() and ".h5" in path.suffixes: | |
| return True | |
| return False | |
| @classmethod | |
| def is_file_match(cls, path: tp.PathLike) -> bool: | |
| return cls.is_hdf_file(path) | |
| @classmethod | |
| def split_hdf_path( | |
| cls, | |
| path: tp.PathLike, | |
| key: tp.Optional[str] = None, | |
| _full_path: tp.Optional[Path] = None, | |
| ) -> tp.Tuple[Path, tp.Optional[str]]: | |
| """Split the path to an HDF object into the path to the file and the key.""" | |
| path = Path(path) | |
| if _full_path is None: | |
| _full_path = path | |
| if path.exists(): | |
| if path.is_dir(): | |
| raise HDFPathNotFoundError(f"No HDF files could be matched with {_full_path}") | |
| return path, key | |
| new_path = path.parent | |
| if key is None: | |
| new_key = path.name | |
| else: | |
| new_key = str(Path(path.name) / key) | |
| return cls.split_hdf_path(new_path, new_key, _full_path=_full_path) | |
| @classmethod | |
| def match_path( | |
| cls, | |
| path: tp.PathLike, | |
| match_regex: tp.Optional[str] = None, | |
| sort_paths: bool = True, | |
| recursive: bool = True, | |
| **kwargs, | |
| ) -> tp.List[Path]: | |
| """Override `FileData.match_path` to return a list of HDF paths | |
| (path to file + key) matching a path.""" | |
| path = Path(path) | |
| if path.exists(): | |
| if path.is_dir() and not cls.is_dir_match(path): | |
| sub_paths = [] | |
| for p in path.iterdir(): | |
| if p.is_dir() and cls.is_dir_match(p): | |
| sub_paths.append(p) | |
| if p.is_file() and cls.is_file_match(p): | |
| sub_paths.append(p) | |
| key_paths = [p for sub_path in sub_paths for p in cls.match_path(sub_path, sort_paths=False, **kwargs)] | |
| else: | |
| with pd.HDFStore(str(path), mode="r") as store: | |
| keys = [k[1:] for k in store.keys()] | |
| key_paths = [path / k for k in keys] | |
| else: | |
| try: | |
| file_path, key = cls.split_hdf_path(path) | |
| with pd.HDFStore(str(file_path), mode="r") as store: | |
| keys = [k[1:] for k in store.keys()] | |
| if key is None: | |
| key_paths = [file_path / k for k in keys] | |
| elif key in keys: | |
| key_paths = [file_path / key] | |
| else: | |
| matching_keys = [] | |
| for k in keys: | |
| if k.startswith(key) or PurePath("/" + str(k)).match("/" + str(key)): | |
| matching_keys.append(k) | |
| if len(matching_keys) == 0: | |
| raise HDFKeyNotFoundError(f"No HDF keys could be matched with {key}") | |
| key_paths = [file_path / k for k in matching_keys] | |
| except HDFPathNotFoundError: | |
| sub_paths = list([Path(p) for p in glob(str(path), recursive=recursive)]) | |
| if len(sub_paths) == 0 and re.match(r".+\..+", str(path)): | |
| base_path = None | |
| base_ended = False | |
| key_path = None | |
| for part in path.parts: | |
| part = Path(part) | |
| if base_ended: | |
| if key_path is None: | |
| key_path = part | |
| else: | |
| key_path /= part | |
| else: | |
| if re.match(r".+\..+", str(part)): | |
| base_ended = True | |
| if base_path is None: | |
| base_path = part | |
| else: | |
| base_path /= part | |
| sub_paths = list([Path(p) for p in glob(str(base_path), recursive=recursive)]) | |
| if key_path is not None: | |
| sub_paths = [p / key_path for p in sub_paths] | |
| key_paths = [p for sub_path in sub_paths for p in cls.match_path(sub_path, sort_paths=False, **kwargs)] | |
| if match_regex is not None: | |
| key_paths = [p for p in key_paths if re.match(match_regex, str(p))] | |
| if sort_paths: | |
| key_paths = sorted(key_paths) | |
| return key_paths | |
| @classmethod | |
| def path_to_key(cls, path: tp.PathLike, **kwargs) -> str: | |
| return Path(path).name | |
| @classmethod | |
| def resolve_keys_meta( | |
| cls, | |
| keys: tp.Union[None, dict, tp.MaybeKeys] = None, | |
| keys_are_features: tp.Optional[bool] = None, | |
| features: tp.Union[None, dict, tp.MaybeFeatures] = None, | |
| symbols: tp.Union[None, dict, tp.MaybeSymbols] = None, | |
| paths: tp.Any = None, | |
| ) -> tp.Kwargs: | |
| keys_meta = FileData.resolve_keys_meta( | |
| keys=keys, | |
| keys_are_features=keys_are_features, | |
| features=features, | |
| symbols=symbols, | |
| ) | |
| if keys_meta["keys"] is None and paths is None: | |
| keys_meta["keys"] = cls.list_paths() | |
| return keys_meta | |
| @classmethod | |
| def fetch_key( | |
| cls, | |
| key: tp.Key, | |
| path: tp.Any = None, | |
| start: tp.Optional[tp.DatetimeLike] = None, | |
| end: tp.Optional[tp.DatetimeLike] = None, | |
| tz: tp.TimezoneLike = None, | |
| start_row: tp.Optional[int] = None, | |
| end_row: tp.Optional[int] = None, | |
| chunk_func: tp.Optional[tp.Callable] = None, | |
| **read_kwargs, | |
| ) -> tp.KeyData: | |
| """Fetch the HDF object of a feature or symbol. | |
| Args: | |
| key (hashable): Feature or symbol. | |
| path (str): Path. | |
| Will be resolved with `HDFData.split_hdf_path`. | |
| If `path` is None, uses `key` as the path to the HDF file. | |
| start (any): Start datetime. | |
| Will extract the object's index and compare the index to the date. | |
| Will use the timezone of the object. See `vectorbtpro.utils.datetime_.to_timestamp`. | |
| !!! note | |
| Can only be used if the object was saved in the table format! | |
| end (any): End datetime. | |
| Will extract the object's index and compare the index to the date. | |
| Will use the timezone of the object. See `vectorbtpro.utils.datetime_.to_timestamp`. | |
| !!! note | |
| Can only be used if the object was saved in the table format! | |
| tz (any): Target timezone. | |
| See `vectorbtpro.utils.datetime_.to_timezone`. | |
| start_row (int): Start row (inclusive). | |
| Will use it when querying index as well. | |
| end_row (int): End row (exclusive). | |
| Will use it when querying index as well. | |
| chunk_func (callable): Function to select and concatenate chunks from `TableIterator`. | |
| Gets called only if `iterator` or `chunksize` are set. | |
| **read_kwargs: Other keyword arguments passed to `pd.read_hdf`. | |
| See https://pandas.pydata.org/docs/reference/api/pandas.read_hdf.html for other arguments. | |
| For defaults, see `custom.hdf` in `vectorbtpro._settings.data`.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("tables") | |
| from pandas.io.pytables import TableIterator | |
| start = cls.resolve_custom_setting(start, "start") | |
| end = cls.resolve_custom_setting(end, "end") | |
| tz = cls.resolve_custom_setting(tz, "tz") | |
| start_row = cls.resolve_custom_setting(start_row, "start_row") | |
| if start_row is None: | |
| start_row = 0 | |
| end_row = cls.resolve_custom_setting(end_row, "end_row") | |
| read_kwargs = cls.resolve_custom_setting(read_kwargs, "read_kwargs", merge=True) | |
| if path is None: | |
| path = key | |
| path = Path(path) | |
| file_path, file_key = cls.split_hdf_path(path) | |
| if file_key is not None: | |
| key = file_key | |
| if start is not None or end is not None: | |
| hdf_store_arg_names = get_func_arg_names(pd.HDFStore.__init__) | |
| hdf_store_kwargs = dict() | |
| for k, v in read_kwargs.items(): | |
| if k in hdf_store_arg_names: | |
| hdf_store_kwargs[k] = v | |
| with pd.HDFStore(str(file_path), mode="r", **hdf_store_kwargs) as store: | |
| index = store.select_column(key, "index", start=start_row, stop=end_row) | |
| if not isinstance(index, pd.Index): | |
| index = pd.Index(index) | |
| if not isinstance(index, pd.DatetimeIndex): | |
| raise TypeError("Cannot filter index that is not DatetimeIndex") | |
| if tz is None: | |
| tz = index.tz | |
| if index.tz is not None: | |
| if start is not None: | |
| start = dt.to_tzaware_timestamp(start, naive_tz=tz, tz=index.tz) | |
| if end is not None: | |
| end = dt.to_tzaware_timestamp(end, naive_tz=tz, tz=index.tz) | |
| else: | |
| if start is not None: | |
| start = dt.to_naive_timestamp(start, tz=tz) | |
| if end is not None: | |
| end = dt.to_naive_timestamp(end, tz=tz) | |
| mask = True | |
| if start is not None: | |
| mask &= index >= start | |
| if end is not None: | |
| mask &= index < end | |
| mask_indices = np.flatnonzero(mask) | |
| if len(mask_indices) == 0: | |
| return None | |
| start_row += mask_indices[0] | |
| end_row = start_row + mask_indices[-1] - mask_indices[0] + 1 | |
| obj = pd.read_hdf(file_path, key=key, start=start_row, stop=end_row, **read_kwargs) | |
| if isinstance(obj, TableIterator): | |
| if chunk_func is None: | |
| obj = pd.concat(list(obj), axis=0) | |
| else: | |
| obj = chunk_func(obj) | |
| if isinstance(obj.index, pd.DatetimeIndex) and tz is None: | |
| tz = obj.index.tz | |
| return obj, dict(last_row=start_row + len(obj.index) - 1, tz=tz) | |
| @classmethod | |
| def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData: | |
| """Fetch the HDF object of a feature. | |
| Uses `HDFData.fetch_key`.""" | |
| return cls.fetch_key(feature, **kwargs) | |
| @classmethod | |
| def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData: | |
| """Load the HDF object for a symbol. | |
| Uses `HDFData.fetch_key`.""" | |
| return cls.fetch_key(symbol, **kwargs) | |
| def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData: | |
| """Update data of a feature or symbol.""" | |
| fetch_kwargs = self.select_fetch_kwargs(key) | |
| returned_kwargs = self.select_returned_kwargs(key) | |
| fetch_kwargs["start_row"] = returned_kwargs["last_row"] | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| if key_is_feature: | |
| return self.fetch_feature(key, **kwargs) | |
| return self.fetch_symbol(key, **kwargs) | |
| def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData: | |
| """Update data of a feature. | |
| Uses `HDFData.update_key` with `key_is_feature=True`.""" | |
| return self.update_key(feature, key_is_feature=True, **kwargs) | |
| def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData: | |
| """Update data for a symbol. | |
| Uses `HDFData.update_key` with `key_is_feature=False`.""" | |
| return self.update_key(symbol, key_is_feature=False, **kwargs) | |
| </file> | |
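| `HDFData` treats the HDF key as part of the path, so "file.h5/key" addresses a single object inside a store, and `start`/`end` filtering only works for objects saved in the table format because the index must be queried via `select_column`. A minimal sketch with a hypothetical store "data.h5": | |
| ```pycon | |
| >>> import pandas as pd | |
| >>> df = pd.DataFrame( | |
| ...     {"Close": [1.0, 2.0, 3.0]}, | |
| ...     index=pd.date_range("2024-01-01", periods=3, tz="utc"), | |
| ... ) | |
| >>> df.to_hdf("data.h5", key="BTC", format="table") | |
| >>> df.to_hdf("data.h5", key="ETH", format="table") | |
| >>> data = vbt.HDFData.pull("data.h5")  # matches every key in the store | |
| >>> data = vbt.HDFData.pull( | |
| ...     "data.h5/BTC",        # file path + HDF key | |
| ...     start="2024-01-02",   # row filtering via the stored index | |
| ... ) | |
| ``` | |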
| <file path="data/custom/local.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `LocalData`.""" | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.custom import CustomData | |
| __all__ = [ | |
| "LocalData", | |
| ] | |
| __pdoc__ = {} | |
| class LocalData(CustomData): | |
| """Data class for fetching local data.""" | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.local") | |
| </file> | |
| <file path="data/custom/ndl.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `NDLData`.""" | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.remote import RemoteData | |
| from vectorbtpro.utils import datetime_ as dt | |
| from vectorbtpro.utils.config import merge_dicts | |
| __all__ = [ | |
| "NDLData", | |
| ] | |
| __pdoc__ = {} | |
| NDLDataT = tp.TypeVar("NDLDataT", bound="NDLData") | |
| class NDLData(RemoteData): | |
| """Data class for fetching from Nasdaq Data Link. | |
| See https://github.com/Nasdaq/data-link-python for API. | |
| See `NDLData.fetch_symbol` for arguments. | |
| Usage: | |
| * Set up the API key globally (optional): | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> vbt.NDLData.set_custom_settings( | |
| ... api_key="YOUR_KEY" | |
| ... ) | |
| ``` | |
| * Pull a dataset: | |
| ```pycon | |
| >>> data = vbt.NDLData.pull( | |
| ... "FRED/GDP", | |
| ... start="2001-12-31", | |
| ... end="2005-12-31" | |
| ... ) | |
| ``` | |
| * Pull a datatable: | |
| ```pycon | |
| >>> data = vbt.NDLData.pull( | |
| ... "MER/F1", | |
| ... data_format="datatable", | |
| ... compnumber="39102", | |
| ... paginate=True | |
| ... ) | |
| ``` | |
| """ | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.ndl") | |
| @classmethod | |
| def fetch_symbol( | |
| cls, | |
| symbol: str, | |
| api_key: tp.Optional[str] = None, | |
| data_format: tp.Optional[str] = None, | |
| start: tp.Optional[tp.DatetimeLike] = None, | |
| end: tp.Optional[tp.DatetimeLike] = None, | |
| tz: tp.TimezoneLike = None, | |
| column_indices: tp.Optional[tp.MaybeIterable[int]] = None, | |
| **params, | |
| ) -> tp.SymbolData: | |
| """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Nasdaq Data Link. | |
| Args: | |
| symbol (str): Symbol. | |
| api_key (str): API key. | |
| data_format (str): Data format. | |
| Supported are "dataset" and "datatable". | |
| start (any): Retrieve data rows on and after the specified start date. | |
| See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. | |
| end (any): Retrieve data rows up to and including the specified end date. | |
| See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. | |
| tz (any): Timezone. | |
| See `vectorbtpro.utils.datetime_.to_timezone`. | |
| column_indices (int or iterable): Request one or more specific columns. | |
| Column 0 is the date column and is always returned. Data begins at column 1. | |
| **params: Keyword arguments sent as field/value params to Nasdaq Data Link with no interference. | |
| For defaults, see `custom.ndl` in `vectorbtpro._settings.data`. | |
| """ | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("nasdaqdatalink") | |
| import nasdaqdatalink | |
| api_key = cls.resolve_custom_setting(api_key, "api_key") | |
| data_format = cls.resolve_custom_setting(data_format, "data_format") | |
| start = cls.resolve_custom_setting(start, "start") | |
| end = cls.resolve_custom_setting(end, "end") | |
| tz = cls.resolve_custom_setting(tz, "tz") | |
| column_indices = cls.resolve_custom_setting(column_indices, "column_indices") | |
| if column_indices is not None: | |
| if isinstance(column_indices, int): | |
| dataset = symbol + "." + str(column_indices) | |
| else: | |
| dataset = [symbol + "." + str(index) for index in column_indices] | |
| else: | |
| dataset = symbol | |
| params = cls.resolve_custom_setting(params, "params", merge=True) | |
| # Establish the timestamps | |
| if start is not None: | |
| start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc") | |
| start_date = pd.Timestamp(start).isoformat() | |
| if "start_date" not in params: | |
| params["start_date"] = start_date | |
| else: | |
| start_date = None | |
| if end is not None: | |
| end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc") | |
| end_date = pd.Timestamp(end).isoformat() | |
| if "end_date" not in params: | |
| params["end_date"] = end_date | |
| else: | |
| end_date = None | |
| # Collect and format the data | |
| if data_format.lower() == "dataset": | |
| df = nasdaqdatalink.get( | |
| dataset, | |
| api_key=api_key, | |
| **params, | |
| ) | |
| else: | |
| df = nasdaqdatalink.get_table( | |
| dataset, | |
| api_key=api_key, | |
| **params, | |
| ) | |
| new_columns = [] | |
| for c in df.columns: | |
| new_c = c | |
| if isinstance(symbol, str): | |
| new_c = new_c.replace(symbol + " - ", "") | |
| if new_c == "Last": | |
| new_c = "Close" | |
| new_columns.append(new_c) | |
| df = df.rename(columns=dict(zip(df.columns, new_columns))) | |
| if df.index.name == "None": | |
| df.index.name = None | |
| if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None: | |
| df = df.tz_localize("utc") | |
| if isinstance(df.index, pd.DatetimeIndex) and not df.empty: | |
| if start is not None: | |
| start = dt.to_timestamp(start, tz=df.index.tz) | |
| if df.index[0] < start: | |
| df = df[df.index >= start] | |
| if end is not None: | |
| end = dt.to_timestamp(end, tz=df.index.tz) | |
| if df.index[-1] >= end: | |
| df = df[df.index < end] | |
| return df, dict(tz=tz) | |
| def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData: | |
| fetch_kwargs = self.select_fetch_kwargs(symbol) | |
| fetch_kwargs["start"] = self.select_last_index(symbol) | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| return self.fetch_symbol(symbol, **kwargs) | |
| </file> | |
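| As shown at the top of `fetch_symbol`, an integer `column_indices` is appended to the dataset code (e.g. "FRED/GDP.1") so that only that column is requested. A minimal sketch reusing the dataset from the class docstring: | |
| ```pycon | |
| >>> data = vbt.NDLData.pull( | |
| ...     "FRED/GDP", | |
| ...     column_indices=1,     # requests dataset "FRED/GDP.1" (column 1 only) | |
| ...     start="2001-12-31", | |
| ...     end="2005-12-31", | |
| ... ) | |
| ``` | |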
| <file path="data/custom/parquet.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `ParquetData`.""" | |
| import re | |
| from pathlib import Path | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.file import FileData | |
| from vectorbtpro.utils.config import merge_dicts | |
| __all__ = [ | |
| "ParquetData", | |
| ] | |
| __pdoc__ = {} | |
| ParquetDataT = tp.TypeVar("ParquetDataT", bound="ParquetData") | |
| class ParquetData(FileData): | |
| """Data class for fetching Parquet data using PyArrow or FastParquet.""" | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.parquet") | |
| @classmethod | |
| def is_parquet_file(cls, path: tp.PathLike) -> bool: | |
| """Return whether the path is a Parquet file.""" | |
| if not isinstance(path, Path): | |
| path = Path(path) | |
| if path.exists() and path.is_file() and ".parquet" in path.suffixes: | |
| return True | |
| return False | |
| @classmethod | |
| def is_parquet_group_dir(cls, path: tp.PathLike) -> bool: | |
| """Return whether the path is a directory that is a group of Parquet partitions. | |
| !!! note | |
| Assumes the Hive partitioning scheme.""" | |
| if not isinstance(path, Path): | |
| path = Path(path) | |
| if path.exists() and path.is_dir(): | |
| partition_regex = r"^(.+)=(.+)" | |
| if re.match(partition_regex, path.name): | |
| for p in path.iterdir(): | |
| if cls.is_parquet_group_dir(p) or cls.is_parquet_file(p): | |
| return True | |
| return False | |
| @classmethod | |
| def is_parquet_dir(cls, path: tp.PathLike) -> bool: | |
| """Return whether the path is a directory that is a group itself or | |
| contains groups of Parquet partitions.""" | |
| if cls.is_parquet_group_dir(path): | |
| return True | |
| if not isinstance(path, Path): | |
| path = Path(path) | |
| if path.exists() and path.is_dir(): | |
| for p in path.iterdir(): | |
| if cls.is_parquet_group_dir(p): | |
| return True | |
| return False | |
| @classmethod | |
| def is_dir_match(cls, path: tp.PathLike) -> bool: | |
| return cls.is_parquet_dir(path) | |
| @classmethod | |
| def is_file_match(cls, path: tp.PathLike) -> bool: | |
| return cls.is_parquet_file(path) | |
| @classmethod | |
| def list_partition_cols(cls, path: tp.PathLike) -> tp.List[str]: | |
| """List partitioning columns under a path. | |
| !!! note | |
| Assumes the Hive partitioning scheme.""" | |
| if not isinstance(path, Path): | |
| path = Path(path) | |
| partition_cols = [] | |
| found_last_level = False | |
| while not found_last_level: | |
| found_new_level = False | |
| for p in path.iterdir(): | |
| if cls.is_parquet_group_dir(p): | |
| partition_cols.append(p.name.split("=")[0]) | |
| path = p | |
| found_new_level = True | |
| break | |
| if not found_new_level: | |
| found_last_level = True | |
| return partition_cols | |
| @classmethod | |
| def is_default_partition_col(cls, level: str) -> bool: | |
| """Return whether a partitioning column is a default partitioning column.""" | |
| return re.match(r"^(\bgroup\b)|(group_\d+)", level) is not None | |
| @classmethod | |
| def resolve_keys_meta( | |
| cls, | |
| keys: tp.Union[None, dict, tp.MaybeKeys] = None, | |
| keys_are_features: tp.Optional[bool] = None, | |
| features: tp.Union[None, dict, tp.MaybeFeatures] = None, | |
| symbols: tp.Union[None, dict, tp.MaybeSymbols] = None, | |
| paths: tp.Any = None, | |
| ) -> tp.Kwargs: | |
| keys_meta = FileData.resolve_keys_meta( | |
| keys=keys, | |
| keys_are_features=keys_are_features, | |
| features=features, | |
| symbols=symbols, | |
| ) | |
| if keys_meta["keys"] is None and paths is None: | |
| keys_meta["keys"] = cls.list_paths() | |
| return keys_meta | |
| @classmethod | |
| def fetch_key( | |
| cls, | |
| key: tp.Key, | |
| path: tp.Any = None, | |
| tz: tp.TimezoneLike = None, | |
| squeeze: tp.Optional[bool] = None, | |
| keep_partition_cols: tp.Optional[bool] = None, | |
| engine: tp.Optional[str] = None, | |
| **read_kwargs, | |
| ) -> tp.KeyData: | |
| """Fetch the Parquet file of a feature or symbol. | |
| Args: | |
| key (hashable): Feature or symbol. | |
| path (str): Path. | |
| If `path` is None, uses `key` as the path to the Parquet file. | |
| tz (any): Target timezone. | |
| See `vectorbtpro.utils.datetime_.to_timezone`. | |
| squeeze (bool): Whether to squeeze a DataFrame with one column into a Series. | |
| keep_partition_cols (bool): Whether to return partitioning columns (if any). | |
| If None, will remove any partitioning column that is "group" or "group_{index}". | |
| Retrieves the list of partitioning columns with `ParquetData.list_partition_cols`. | |
| engine (str): See `pd.read_parquet`. | |
| **read_kwargs: Other keyword arguments passed to `pd.read_parquet`. | |
| See https://pandas.pydata.org/docs/reference/api/pandas.read_parquet.html for other arguments. | |
| For defaults, see `custom.parquet` in `vectorbtpro._settings.data`.""" | |
| from vectorbtpro.utils.module_ import assert_can_import, assert_can_import_any | |
| tz = cls.resolve_custom_setting(tz, "tz") | |
| squeeze = cls.resolve_custom_setting(squeeze, "squeeze") | |
| keep_partition_cols = cls.resolve_custom_setting(keep_partition_cols, "keep_partition_cols") | |
| engine = cls.resolve_custom_setting(engine, "engine") | |
| read_kwargs = cls.resolve_custom_setting(read_kwargs, "read_kwargs", merge=True) | |
| if engine == "pyarrow": | |
| assert_can_import("pyarrow") | |
| elif engine == "fastparquet": | |
| assert_can_import("fastparquet") | |
| elif engine == "auto": | |
| assert_can_import_any("pyarrow", "fastparquet") | |
| else: | |
| raise ValueError(f"Invalid engine: '{engine}'") | |
| if path is None: | |
| path = key | |
| obj = pd.read_parquet(path, engine=engine, **read_kwargs) | |
| if keep_partition_cols in (None, False): | |
| if cls.is_parquet_dir(path): | |
| drop_columns = [] | |
| partition_cols = cls.list_partition_cols(path) | |
| for col in obj.columns: | |
| if col in partition_cols: | |
| if keep_partition_cols is False or cls.is_default_partition_col(col): | |
| drop_columns.append(col) | |
| obj = obj.drop(drop_columns, axis=1) | |
| if isinstance(obj.index, pd.DatetimeIndex) and tz is None: | |
| tz = obj.index.tz | |
| if isinstance(obj, pd.DataFrame) and squeeze: | |
| obj = obj.squeeze("columns") | |
| if isinstance(obj, pd.Series) and obj.name == "0": | |
| obj.name = None | |
| return obj, dict(tz=tz) | |
| @classmethod | |
| def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData: | |
| """Fetch the Parquet file of a feature. | |
| Uses `ParquetData.fetch_key`.""" | |
| return cls.fetch_key(feature, **kwargs) | |
| @classmethod | |
| def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData: | |
| """Fetch the Parquet file of a symbol. | |
| Uses `ParquetData.fetch_key`.""" | |
| return cls.fetch_key(symbol, **kwargs) | |
| def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData: | |
| """Update data of a feature or symbol.""" | |
| fetch_kwargs = self.select_fetch_kwargs(key) | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| if key_is_feature: | |
| return self.fetch_feature(key, **kwargs) | |
| return self.fetch_symbol(key, **kwargs) | |
| def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData: | |
| """Update data of a feature. | |
| Uses `ParquetData.update_key` with `key_is_feature=True`.""" | |
| return self.update_key(feature, key_is_feature=True, **kwargs) | |
| def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData: | |
| """Update data for a symbol. | |
| Uses `ParquetData.update_key` with `key_is_feature=False`.""" | |
| return self.update_key(symbol, key_is_feature=False, **kwargs) | |
| </file> | |
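| A minimal sketch of pulling local Parquet data with the class above; the paths are placeholders: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> data = vbt.ParquetData.pull("data.parquet")  # single file, key used as path | |
| >>> data = vbt.ParquetData.pull(  # Hive-partitioned directory | |
| ... "data_dir", | |
| ... keep_partition_cols=True | |
| ... ) | |
| ``` | |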
| <file path="data/custom/polygon.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `PolygonData`.""" | |
| import time | |
| import traceback | |
| from functools import wraps, partial | |
| import pandas as pd | |
| import requests | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.remote import RemoteData | |
| from vectorbtpro.utils import datetime_ as dt | |
| from vectorbtpro.utils.config import merge_dicts | |
| from vectorbtpro.utils.pbar import ProgressBar | |
| from vectorbtpro.utils.warnings_ import warn | |
| try: | |
| if not tp.TYPE_CHECKING: | |
| raise ImportError | |
| from polygon import RESTClient as PolygonClientT | |
| except ImportError: | |
| PolygonClientT = "PolygonClient" | |
| __all__ = [ | |
| "PolygonData", | |
| ] | |
| PolygonDataT = tp.TypeVar("PolygonDataT", bound="PolygonData") | |
| class PolygonData(RemoteData): | |
| """Data class for fetching from Polygon. | |
| See https://github.com/polygon-io/client-python for API. | |
| See `PolygonData.fetch_symbol` for arguments. | |
| Usage: | |
| * Set up the API key globally: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> vbt.PolygonData.set_custom_settings( | |
| ... client_config=dict( | |
| ... api_key="YOUR_KEY" | |
| ... ) | |
| ... ) | |
| ``` | |
| * Pull stock data: | |
| ```pycon | |
| >>> data = vbt.PolygonData.pull( | |
| ... "AAPL", | |
| ... start="2021-01-01", | |
| ... end="2022-01-01", | |
| ... timeframe="1 day" | |
| ... ) | |
| ``` | |
| * Pull crypto data: | |
| ```pycon | |
| >>> data = vbt.PolygonData.pull( | |
| ... "X:BTCUSD", | |
| ... start="2021-01-01", | |
| ... end="2022-01-01", | |
| ... timeframe="1 day" | |
| ... ) | |
| ``` | |
| """ | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.polygon") | |
| @classmethod | |
| def list_symbols( | |
| cls, | |
| pattern: tp.Optional[str] = None, | |
| use_regex: bool = False, | |
| sort: bool = True, | |
| client: tp.Optional[PolygonClientT] = None, | |
| client_config: tp.DictLike = None, | |
| **list_tickers_kwargs, | |
| ) -> tp.List[str]: | |
| """List all symbols. | |
| Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`. | |
| For supported keyword arguments, see `polygon.RESTClient.list_tickers`.""" | |
| if client_config is None: | |
| client_config = {} | |
| client = cls.resolve_client(client=client, **client_config) | |
| all_symbols = [] | |
| for ticker in client.list_tickers(**list_tickers_kwargs): | |
| symbol = ticker.ticker | |
| if pattern is not None: | |
| if not cls.key_match(symbol, pattern, use_regex=use_regex): | |
| continue | |
| all_symbols.append(symbol) | |
| if sort: | |
| return sorted(dict.fromkeys(all_symbols)) | |
| return list(dict.fromkeys(all_symbols)) | |
| @classmethod | |
| def resolve_client(cls, client: tp.Optional[PolygonClientT] = None, **client_config) -> PolygonClientT: | |
| """Resolve the client. | |
| If provided, must be of the type `polygon.rest.RESTClient`. | |
| Otherwise, will be created using `client_config`.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("polygon") | |
| from polygon import RESTClient | |
| client = cls.resolve_custom_setting(client, "client") | |
| if client_config is None: | |
| client_config = {} | |
| has_client_config = len(client_config) > 0 | |
| client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True) | |
| if client is None: | |
| client = RESTClient(**client_config) | |
| elif has_client_config: | |
| raise ValueError("Cannot apply client_config to already initialized client") | |
| return client | |
| @classmethod | |
| def fetch_symbol( | |
| cls, | |
| symbol: str, | |
| client: tp.Optional[PolygonClientT] = None, | |
| client_config: tp.DictLike = None, | |
| start: tp.Optional[tp.DatetimeLike] = None, | |
| end: tp.Optional[tp.DatetimeLike] = None, | |
| timeframe: tp.Optional[str] = None, | |
| tz: tp.TimezoneLike = None, | |
| adjusted: tp.Optional[bool] = None, | |
| limit: tp.Optional[int] = None, | |
| params: tp.KwargsLike = None, | |
| delay: tp.Optional[float] = None, | |
| retries: tp.Optional[int] = None, | |
| show_progress: tp.Optional[bool] = None, | |
| pbar_kwargs: tp.KwargsLike = None, | |
| silence_warnings: tp.Optional[bool] = None, | |
| ) -> tp.SymbolData: | |
| """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Polygon. | |
| Args: | |
| symbol (str): Symbol. | |
| Supports the following APIs: | |
| * Stocks and equities | |
| * Currencies - symbol must have the prefix `C:` | |
| * Crypto - symbol must have the prefix `X:` | |
| client (polygon.rest.RESTClient): Client. | |
| See `PolygonData.resolve_client`. | |
| client_config (dict): Client config. | |
| See `PolygonData.resolve_client`. | |
| start (any): The start of the aggregate time window. | |
| See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. | |
| end (any): The end of the aggregate time window. | |
| See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. | |
| timeframe (str): Timeframe. | |
| Allows human-readable strings such as "15 minutes". | |
| tz (any): Timezone. | |
| See `vectorbtpro.utils.datetime_.to_timezone`. | |
| adjusted (bool): Whether the results are adjusted for splits. | |
| By default, results are adjusted. | |
| Set this to False to get results that are NOT adjusted for splits. | |
| limit (int): Limits the number of base aggregates queried to create the aggregate results. | |
| Maximum is 50000 and default is 5000. | |
| params (dict): Any additional query params. | |
| delay (float): Time to sleep after each request (in seconds). | |
| retries (int): The number of retries on failure to fetch data. | |
| show_progress (bool): Whether to show the progress bar. | |
| pbar_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.pbar.ProgressBar`. | |
| silence_warnings (bool): Whether to silence all warnings. | |
| For defaults, see `custom.polygon` in `vectorbtpro._settings.data`. | |
| !!! note | |
| If you're using a free plan that has an API rate limit of several requests per minute, | |
| make sure to set `delay` to a higher number, such as 12 (which makes 5 requests per minute). | |
| """ | |
| if client_config is None: | |
| client_config = {} | |
| client = cls.resolve_client(client=client, **client_config) | |
| start = cls.resolve_custom_setting(start, "start") | |
| end = cls.resolve_custom_setting(end, "end") | |
| timeframe = cls.resolve_custom_setting(timeframe, "timeframe") | |
| tz = cls.resolve_custom_setting(tz, "tz") | |
| adjusted = cls.resolve_custom_setting(adjusted, "adjusted") | |
| limit = cls.resolve_custom_setting(limit, "limit") | |
| params = cls.resolve_custom_setting(params, "params", merge=True) | |
| delay = cls.resolve_custom_setting(delay, "delay") | |
| retries = cls.resolve_custom_setting(retries, "retries") | |
| show_progress = cls.resolve_custom_setting(show_progress, "show_progress") | |
| pbar_kwargs = cls.resolve_custom_setting(pbar_kwargs, "pbar_kwargs", merge=True) | |
| if "bar_id" not in pbar_kwargs: | |
| pbar_kwargs["bar_id"] = "polygon" | |
| silence_warnings = cls.resolve_custom_setting(silence_warnings, "silence_warnings") | |
| # Resolve the timeframe | |
| if not isinstance(timeframe, str): | |
| raise ValueError(f"Invalid timeframe: '{timeframe}'") | |
| split = dt.split_freq_str(timeframe) | |
| if split is None: | |
| raise ValueError(f"Invalid timeframe: '{timeframe}'") | |
| multiplier, unit = split | |
| if unit == "m": | |
| unit = "minute" | |
| elif unit == "h": | |
| unit = "hour" | |
| elif unit == "D": | |
| unit = "day" | |
| elif unit == "W": | |
| unit = "week" | |
| elif unit == "M": | |
| unit = "month" | |
| elif unit == "Q": | |
| unit = "quarter" | |
| elif unit == "Y": | |
| unit = "year" | |
| # Establish the timestamps | |
| if start is not None: | |
| start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")) | |
| else: | |
| start_ts = None | |
| if end is not None: | |
| end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc")) | |
| else: | |
| end_ts = None | |
| prev_end_ts = None | |
| def _retry(method): | |
| @wraps(method) | |
| def retry_method(*args, **kwargs): | |
| for i in range(retries): | |
| try: | |
| return method(*args, **kwargs) | |
| except requests.exceptions.HTTPError as e: | |
| if isinstance(e, requests.exceptions.HTTPError) and e.response.status_code == 429: | |
| if not silence_warnings: | |
| warn(traceback.format_exc()) | |
| # Polygon.io API rate limit is per minute | |
| warn("Waiting 1 minute...") | |
| time.sleep(60) | |
| else: | |
| raise e | |
| except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e: | |
| if i == retries - 1: | |
| raise e | |
| if not silence_warnings: | |
| warn(traceback.format_exc()) | |
| if delay is not None: | |
| time.sleep(delay) | |
| return retry_method | |
| def _postprocess(agg): | |
| return dict( | |
| o=agg.open, | |
| h=agg.high, | |
| l=agg.low, | |
| c=agg.close, | |
| v=agg.volume, | |
| vw=agg.vwap, | |
| t=agg.timestamp, | |
| n=agg.transactions, | |
| ) | |
| @_retry | |
| def _fetch(_start_ts, _limit): | |
| return list( | |
| map( | |
| _postprocess, | |
| client.get_aggs( | |
| ticker=symbol, | |
| multiplier=multiplier, | |
| timespan=unit, | |
| from_=_start_ts, | |
| to=end_ts, | |
| adjusted=adjusted, | |
| sort="asc", | |
| limit=_limit, | |
| params=params, | |
| raw=False, | |
| ), | |
| ) | |
| ) | |
| def _ts_to_str(ts: tp.Optional[int]) -> str: | |
| if ts is None: | |
| return "?" | |
| return dt.readable_datetime(pd.Timestamp(ts, unit="ms", tz="utc"), freq=timeframe) | |
| def _filter_func(d: tp.Dict, _prev_end_ts: tp.Optional[int] = None) -> bool: | |
| if start_ts is not None: | |
| if d["t"] < start_ts: | |
| return False | |
| if _prev_end_ts is not None: | |
| if d["t"] <= _prev_end_ts: | |
| return False | |
| if end_ts is not None: | |
| if d["t"] >= end_ts: | |
| return False | |
| return True | |
| # Iteratively collect the data | |
| data = [] | |
| try: | |
| with ProgressBar(show_progress=show_progress, **pbar_kwargs) as pbar: | |
| pbar.set_description("{} → ?".format(_ts_to_str(start_ts if prev_end_ts is None else prev_end_ts))) | |
| while True: | |
| # Fetch the klines for the next timeframe | |
| next_data = _fetch(start_ts if prev_end_ts is None else prev_end_ts, limit) | |
| next_data = list(filter(partial(_filter_func, _prev_end_ts=prev_end_ts), next_data)) | |
| # Update the timestamps and the progress bar | |
| if not len(next_data): | |
| break | |
| data += next_data | |
| if start_ts is None: | |
| start_ts = next_data[0]["t"] | |
| pbar.set_description("{} → {}".format(_ts_to_str(start_ts), _ts_to_str(next_data[-1]["t"]))) | |
| pbar.update() | |
| prev_end_ts = next_data[-1]["t"] | |
| if end_ts is not None and prev_end_ts >= end_ts: | |
| break | |
| if delay is not None: | |
| time.sleep(delay) # be kind to api | |
| except Exception as e: | |
| if not silence_warnings: | |
| warn(traceback.format_exc()) | |
| warn( | |
| f"Symbol '{str(symbol)}' raised an exception. Returning incomplete data. " | |
| "Use update() method to fetch missing data." | |
| ) | |
| df = pd.DataFrame(data) | |
| df = df[["t", "o", "h", "l", "c", "v", "n", "vw"]] | |
| df = df.rename( | |
| columns={ | |
| "t": "Open time", | |
| "o": "Open", | |
| "h": "High", | |
| "l": "Low", | |
| "c": "Close", | |
| "v": "Volume", | |
| "n": "Trade count", | |
| "vw": "VWAP", | |
| } | |
| ) | |
| df.index = pd.to_datetime(df["Open time"], unit="ms", utc=True) | |
| del df["Open time"] | |
| if "Open" in df.columns: | |
| df["Open"] = df["Open"].astype(float) | |
| if "High" in df.columns: | |
| df["High"] = df["High"].astype(float) | |
| if "Low" in df.columns: | |
| df["Low"] = df["Low"].astype(float) | |
| if "Close" in df.columns: | |
| df["Close"] = df["Close"].astype(float) | |
| if "Volume" in df.columns: | |
| df["Volume"] = df["Volume"].astype(float) | |
| if "Trade count" in df.columns: | |
| df["Trade count"] = df["Trade count"].astype(int, errors="ignore") | |
| if "VWAP" in df.columns: | |
| df["VWAP"] = df["VWAP"].astype(float) | |
| return df, dict(tz=tz, freq=timeframe) | |
| def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData: | |
| fetch_kwargs = self.select_fetch_kwargs(symbol) | |
| fetch_kwargs["start"] = self.select_last_index(symbol) | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| return self.fetch_symbol(symbol, **kwargs) | |
| </file> | |
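| A follow-up sketch for the class above: on a free Polygon plan, `delay` can be raised to respect the per-minute rate limit, and previously pulled data can be extended with `update`; all values are placeholders: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> data = vbt.PolygonData.pull( | |
| ... "X:BTCUSD", | |
| ... start="2021-01-01", | |
| ... end="2021-02-01", | |
| ... timeframe="1 hour", | |
| ... delay=12  # ~5 requests per minute | |
| ... ) | |
| >>> data = data.update(end="2021-03-01")  # fetches only the missing bars | |
| ``` | |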
| <file path="data/custom/random_ohlc.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `RandomOHLCData`.""" | |
| import numpy as np | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.base.reshaping import broadcast_array_to | |
| from vectorbtpro.data import nb | |
| from vectorbtpro.data.custom.synthetic import SyntheticData | |
| from vectorbtpro.ohlcv import nb as ohlcv_nb | |
| from vectorbtpro.registries.jit_registry import jit_reg | |
| from vectorbtpro.utils.config import merge_dicts | |
| from vectorbtpro.utils.random_ import set_seed | |
| from vectorbtpro.utils.template import substitute_templates | |
| __all__ = [ | |
| "RandomOHLCData", | |
| ] | |
| __pdoc__ = {} | |
| class RandomOHLCData(SyntheticData): | |
| """`SyntheticData` for data generated using `vectorbtpro.data.nb.generate_random_data_1d_nb` | |
| and then resampled using `vectorbtpro.ohlcv.nb.ohlc_every_1d_nb`.""" | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.random_ohlc") | |
| @classmethod | |
| def generate_symbol( | |
| cls, | |
| symbol: tp.Symbol, | |
| index: tp.Index, | |
| n_ticks: tp.Optional[tp.ArrayLike] = None, | |
| start_value: tp.Optional[float] = None, | |
| mean: tp.Optional[float] = None, | |
| std: tp.Optional[float] = None, | |
| symmetric: tp.Optional[bool] = None, | |
| seed: tp.Optional[int] = None, | |
| jitted: tp.JittedOption = None, | |
| template_context: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> tp.SymbolData: | |
| """Generate a symbol. | |
| Args: | |
| symbol (hashable): Symbol. | |
| index (pd.Index): Pandas index. | |
| n_ticks (int or array_like): Number of ticks per bar. | |
| Flexible argument. Can be a template with a context containing `symbol` and `index`. | |
| start_value (float): Value at time 0. | |
| Does not appear as the first value in the output data. | |
| mean (float): Drift, or mean of the percentage change. | |
| std (float): Standard deviation of the percentage change. | |
| symmetric (bool): Whether to diminish negative returns and make them symmetric to positive ones. | |
| seed (int): Seed to make output deterministic. | |
| jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`. | |
| template_context (dict): Template context. | |
| For defaults, see `custom.random_ohlc` in `vectorbtpro._settings.data`. | |
| !!! note | |
| When setting a seed, remember to pass a seed per symbol using `vectorbtpro.data.base.symbol_dict`. | |
| """ | |
| n_ticks = cls.resolve_custom_setting(n_ticks, "n_ticks") | |
| template_context = merge_dicts(dict(symbol=symbol, index=index), template_context) | |
| n_ticks = substitute_templates(n_ticks, template_context, eval_id="n_ticks") | |
| n_ticks = broadcast_array_to(n_ticks, len(index)) | |
| start_value = cls.resolve_custom_setting(start_value, "start_value") | |
| mean = cls.resolve_custom_setting(mean, "mean") | |
| std = cls.resolve_custom_setting(std, "std") | |
| symmetric = cls.resolve_custom_setting(symmetric, "symmetric") | |
| seed = cls.resolve_custom_setting(seed, "seed") | |
| if seed is not None: | |
| set_seed(seed) | |
| func = jit_reg.resolve_option(nb.generate_random_data_1d_nb, jitted) | |
| ticks = func(np.sum(n_ticks), start_value=start_value, mean=mean, std=std, symmetric=symmetric) | |
| func = jit_reg.resolve_option(ohlcv_nb.ohlc_every_1d_nb, jitted) | |
| out = func(ticks, n_ticks) | |
| return pd.DataFrame(out, index=index, columns=["Open", "High", "Low", "Close"]) | |
| def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData: | |
| fetch_kwargs = self.select_fetch_kwargs(symbol) | |
| fetch_kwargs["start"] = self.select_last_index(symbol) | |
| _ = fetch_kwargs.pop("start_value", None) | |
| start_value = self.data[symbol]["Open"].iloc[-1] | |
| fetch_kwargs["seed"] = None | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| return self.fetch_symbol(symbol, start_value=start_value, **kwargs) | |
| </file> | |
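| A minimal sketch of generating synthetic OHLC data with the class above, assuming `SyntheticData.fetch_symbol` accepts `start`, `end`, and `freq` for building the index; all values are placeholders: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> data = vbt.RandomOHLCData.pull( | |
| ... "RAND", | |
| ... start="2021-01-01", | |
| ... end="2021-07-01", | |
| ... freq="1D", | |
| ... n_ticks=100, | |
| ... seed=42 | |
| ... ) | |
| ``` | |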
| <file path="data/custom/random.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `RandomData`.""" | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.base.reshaping import to_1d_array | |
| from vectorbtpro.data import nb | |
| from vectorbtpro.data.custom.synthetic import SyntheticData | |
| from vectorbtpro.registries.jit_registry import jit_reg | |
| from vectorbtpro.utils import checks | |
| from vectorbtpro.utils.config import merge_dicts | |
| from vectorbtpro.utils.random_ import set_seed | |
| __all__ = [ | |
| "RandomData", | |
| ] | |
| __pdoc__ = {} | |
| class RandomData(SyntheticData): | |
| """`SyntheticData` for data generated using `vectorbtpro.data.nb.generate_random_data_nb`.""" | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.random") | |
| @classmethod | |
| def generate_key( | |
| cls, | |
| key: tp.Key, | |
| index: tp.Index, | |
| columns: tp.Union[tp.Hashable, tp.IndexLike] = None, | |
| start_value: tp.Optional[float] = None, | |
| mean: tp.Optional[float] = None, | |
| std: tp.Optional[float] = None, | |
| symmetric: tp.Optional[bool] = None, | |
| seed: tp.Optional[int] = None, | |
| jitted: tp.JittedOption = None, | |
| **kwargs, | |
| ) -> tp.KeyData: | |
| """Generate a feature or symbol. | |
| Args: | |
| key (hashable): Feature or symbol. | |
| index (pd.Index): Pandas index. | |
| columns (hashable or index_like): Column names. | |
| Provide a single value (hashable) to make a Series. | |
| start_value (float): Value at time 0. | |
| Does not appear as the first value in the output data. | |
| mean (float): Drift, or mean of the percentage change. | |
| std (float): Standard deviation of the percentage change. | |
| symmetric (bool): Whether to diminish negative returns and make them symmetric to positive ones. | |
| seed (int): Seed to make output deterministic. | |
| jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`. | |
| For defaults, see `custom.random` in `vectorbtpro._settings.data`. | |
| !!! note | |
| When setting a seed, remember to pass a seed per feature/symbol using | |
| `vectorbtpro.data.base.feature_dict`/`vectorbtpro.data.base.symbol_dict` or generally | |
| `vectorbtpro.data.base.key_dict`. | |
| """ | |
| if checks.is_hashable(columns): | |
| columns = [columns] | |
| make_series = True | |
| else: | |
| make_series = False | |
| if not isinstance(columns, pd.Index): | |
| columns = pd.Index(columns) | |
| start_value = cls.resolve_custom_setting(start_value, "start_value") | |
| mean = cls.resolve_custom_setting(mean, "mean") | |
| std = cls.resolve_custom_setting(std, "std") | |
| symmetric = cls.resolve_custom_setting(symmetric, "symmetric") | |
| seed = cls.resolve_custom_setting(seed, "seed") | |
| if seed is not None: | |
| set_seed(seed) | |
| func = jit_reg.resolve_option(nb.generate_random_data_nb, jitted) | |
| out = func( | |
| (len(index), len(columns)), | |
| start_value=to_1d_array(start_value), | |
| mean=to_1d_array(mean), | |
| std=to_1d_array(std), | |
| symmetric=to_1d_array(symmetric), | |
| ) | |
| if make_series: | |
| return pd.Series(out[:, 0], index=index, name=columns[0]) | |
| return pd.DataFrame(out, index=index, columns=columns) | |
| def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData: | |
| fetch_kwargs = self.select_fetch_kwargs(key) | |
| fetch_kwargs["start"] = self.select_last_index(key) | |
| _ = fetch_kwargs.pop("start_value", None) | |
| start_value = self.data[key].iloc[-2] | |
| fetch_kwargs["seed"] = None | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| if key_is_feature: | |
| return self.fetch_feature(key, start_value=start_value, **kwargs) | |
| return self.fetch_symbol(key, start_value=start_value, **kwargs) | |
| </file> | |
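| A minimal sketch for the class above, following the note on per-symbol seeds; the date range is a placeholder: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> data = vbt.RandomData.pull( | |
| ... ["RAND1", "RAND2"], | |
| ... start="2021-01-01", | |
| ... end="2021-07-01", | |
| ... seed=vbt.symbol_dict({"RAND1": 42, "RAND2": 43}) | |
| ... ) | |
| ``` | |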
| <file path="data/custom/remote.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `RemoteData`.""" | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.custom import CustomData | |
| __all__ = [ | |
| "RemoteData", | |
| ] | |
| __pdoc__ = {} | |
| class RemoteData(CustomData): | |
| """Data class for fetching remote data.""" | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.remote") | |
| </file> | |
| <file path="data/custom/sql.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `SQLData`.""" | |
| from typing import Iterator | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.db import DBData | |
| from vectorbtpro.utils import checks, datetime_ as dt | |
| from vectorbtpro.utils.config import merge_dicts | |
| try: | |
| if not tp.TYPE_CHECKING: | |
| raise ImportError | |
| from sqlalchemy import Engine as EngineT, Selectable as SelectableT, Table as TableT | |
| except ImportError: | |
| EngineT = "Engine" | |
| SelectableT = "Selectable" | |
| TableT = "Table" | |
| __all__ = [ | |
| "SQLData", | |
| ] | |
| __pdoc__ = {} | |
| SQLDataT = tp.TypeVar("SQLDataT", bound="SQLData") | |
| class SQLData(DBData): | |
| """Data class for fetching data from a database using SQLAlchemy. | |
| See https://www.sqlalchemy.org/ for the SQLAlchemy API. | |
| See https://pandas.pydata.org/docs/reference/api/pandas.read_sql_query.html for the read method. | |
| See `SQLData.pull` and `SQLData.fetch_key` for arguments. | |
| Usage: | |
| * Set up the engine settings globally (optional): | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> vbt.SQLData.set_engine_settings( | |
| ... engine_name="postgresql", | |
| ... populate_=True, | |
| ... engine="postgresql+psycopg2://...", | |
| ... engine_config=dict(), | |
| ... schema="public" | |
| ... ) | |
| ``` | |
| * Pull tables: | |
| ```pycon | |
| >>> data = vbt.SQLData.pull( | |
| ... ["TABLE1", "TABLE2"], | |
| ... engine="postgresql", | |
| ... start="2020-01-01", | |
| ... end="2021-01-01" | |
| ... ) | |
| ``` | |
| * Pull queries: | |
| ```pycon | |
| >>> data = vbt.SQLData.pull( | |
| ... ["SYMBOL1", "SYMBOL2"], | |
| ... query=vbt.key_dict({ | |
| ... "SYMBOL1": "SELECT * FROM TABLE1", | |
| ... "SYMBOL2": "SELECT * FROM TABLE2" | |
| ... }), | |
| ... engine="postgresql" | |
| ... ) | |
| ``` | |
| """ | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.sql") | |
| @classmethod | |
| def get_engine_settings(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> dict: | |
| """`SQLData.get_custom_settings` with `sub_path=engine_name`.""" | |
| if engine_name is not None: | |
| sub_path = "engines." + engine_name | |
| else: | |
| sub_path = None | |
| return cls.get_custom_settings(*args, sub_path=sub_path, **kwargs) | |
| @classmethod | |
| def has_engine_settings(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> bool: | |
| """`SQLData.has_custom_settings` with `sub_path=engine_name`.""" | |
| if engine_name is not None: | |
| sub_path = "engines." + engine_name | |
| else: | |
| sub_path = None | |
| return cls.has_custom_settings(*args, sub_path=sub_path, **kwargs) | |
| @classmethod | |
| def get_engine_setting(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> tp.Any: | |
| """`SQLData.get_custom_setting` with `sub_path=engine_name`.""" | |
| if engine_name is not None: | |
| sub_path = "engines." + engine_name | |
| else: | |
| sub_path = None | |
| return cls.get_custom_setting(*args, sub_path=sub_path, **kwargs) | |
| @classmethod | |
| def has_engine_setting(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> bool: | |
| """`SQLData.has_custom_setting` with `sub_path=engine_name`.""" | |
| if engine_name is not None: | |
| sub_path = "engines." + engine_name | |
| else: | |
| sub_path = None | |
| return cls.has_custom_setting(*args, sub_path=sub_path, **kwargs) | |
| @classmethod | |
| def resolve_engine_setting(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> tp.Any: | |
| """`SQLData.resolve_custom_setting` with `sub_path=engine_name`.""" | |
| if engine_name is not None: | |
| sub_path = "engines." + engine_name | |
| else: | |
| sub_path = None | |
| return cls.resolve_custom_setting(*args, sub_path=sub_path, **kwargs) | |
| @classmethod | |
| def set_engine_settings(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> None: | |
| """`SQLData.set_custom_settings` with `sub_path=engine_name`.""" | |
| if engine_name is not None: | |
| sub_path = "engines." + engine_name | |
| else: | |
| sub_path = None | |
| cls.set_custom_settings(*args, sub_path=sub_path, **kwargs) | |
| @classmethod | |
| def resolve_engine( | |
| cls, | |
| engine: tp.Union[None, str, EngineT] = None, | |
| engine_name: tp.Optional[str] = None, | |
| return_meta: bool = False, | |
| **engine_config, | |
| ) -> tp.Union[EngineT, dict]: | |
| """Resolve the engine. | |
| Argument `engine` can be | |
| 1) an object of the type `sqlalchemy.engine.base.Engine`, | |
| 2) a URL of the engine as a string, which will be used to create an engine with | |
| `sqlalchemy.engine.create.create_engine` and `engine_config` passed as keyword arguments | |
| (you should not include `url` in the `engine_config`), or | |
| 3) an engine name, which is the name of a sub-config with engine settings under `custom.sql.engines` | |
| in `vectorbtpro._settings.data`. Such a sub-config can then contain the actual engine as an object or a URL. | |
| Argument `engine_name` can be provided instead of `engine`, or also together with `engine` | |
| to pull other settings from a sub-config. URLs can also be used as engine names, but not the | |
| other way around.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("sqlalchemy") | |
| from sqlalchemy import create_engine | |
| if engine is None and engine_name is None: | |
| engine_name = cls.resolve_engine_setting(engine_name, "engine_name") | |
| if engine_name is not None: | |
| engine = cls.resolve_engine_setting(engine, "engine", engine_name=engine_name) | |
| if engine is None: | |
| raise ValueError("Must provide engine or URL (via engine argument)") | |
| else: | |
| engine = cls.resolve_engine_setting(engine, "engine") | |
| if engine is None: | |
| raise ValueError("Must provide engine or URL (via engine argument)") | |
| if isinstance(engine, str): | |
| engine_name = engine | |
| else: | |
| engine_name = None | |
| if engine_name is not None: | |
| if cls.has_engine_setting("engine", engine_name=engine_name, sub_path_only=True): | |
| engine = cls.get_engine_setting("engine", engine_name=engine_name, sub_path_only=True) | |
| has_engine_config = len(engine_config) > 0 | |
| engine_config = cls.resolve_engine_setting(engine_config, "engine_config", merge=True, engine_name=engine_name) | |
| if isinstance(engine, str): | |
| if engine.startswith("duckdb:"): | |
| assert_can_import("duckdb_engine") | |
| engine = create_engine(engine, **engine_config) | |
| should_dispose = True | |
| else: | |
| if has_engine_config: | |
| raise ValueError("Cannot apply engine_config to initialized created engine") | |
| should_dispose = False | |
| if return_meta: | |
| return dict( | |
| engine=engine, | |
| engine_name=engine_name, | |
| should_dispose=should_dispose, | |
| ) | |
| return engine | |
| @classmethod | |
| def list_schemas( | |
| cls, | |
| pattern: tp.Optional[str] = None, | |
| use_regex: bool = False, | |
| sort: bool = True, | |
| engine: tp.Union[None, str, EngineT] = None, | |
| engine_name: tp.Optional[str] = None, | |
| engine_config: tp.KwargsLike = None, | |
| dispose_engine: tp.Optional[bool] = None, | |
| **kwargs, | |
| ) -> tp.List[str]: | |
| """List all schemas. | |
| Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each schema against `pattern`. | |
| Keyword arguments `**kwargs` are passed to `inspector.get_schema_names`. | |
| If `dispose_engine` is None, disposes the engine if it wasn't provided.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("sqlalchemy") | |
| from sqlalchemy import inspect | |
| if engine_config is None: | |
| engine_config = {} | |
| engine_meta = cls.resolve_engine( | |
| engine=engine, | |
| engine_name=engine_name, | |
| return_meta=True, | |
| **engine_config, | |
| ) | |
| engine = engine_meta["engine"] | |
| should_dispose = engine_meta["should_dispose"] | |
| if dispose_engine is None: | |
| dispose_engine = should_dispose | |
| inspector = inspect(engine) | |
| all_schemas = inspector.get_schema_names(**kwargs) | |
| schemas = [] | |
| for schema in all_schemas: | |
| if pattern is not None: | |
| if not cls.key_match(schema, pattern, use_regex=use_regex): | |
| continue | |
| if schema == "information_schema": | |
| continue | |
| schemas.append(schema) | |
| if dispose_engine: | |
| engine.dispose() | |
| if sort: | |
| return sorted(dict.fromkeys(schemas)) | |
| return list(dict.fromkeys(schemas)) | |
| @classmethod | |
| def list_tables( | |
| cls, | |
| *, | |
| schema_pattern: tp.Optional[str] = None, | |
| table_pattern: tp.Optional[str] = None, | |
| use_regex: bool = False, | |
| sort: bool = True, | |
| schema: tp.Optional[str] = None, | |
| incl_views: bool = True, | |
| engine: tp.Union[None, str, EngineT] = None, | |
| engine_name: tp.Optional[str] = None, | |
| engine_config: tp.KwargsLike = None, | |
| dispose_engine: tp.Optional[bool] = None, | |
| **kwargs, | |
| ) -> tp.List[str]: | |
| """List all tables and views. | |
| If `schema` is None, searches for all schema names in the database and prefixes each table | |
| with the respective schema name (unless there's only one schema "main"). If `schema` is False, | |
| sets the schema to None. If `schema` is provided, returns the tables corresponding to this | |
| schema without a prefix. | |
| Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each schema against | |
| `schema_pattern` and each table against `table_pattern`. | |
| Keyword arguments `**kwargs` are passed to `inspector.get_table_names`. | |
| If `dispose_engine` is None, disposes the engine if it wasn't provided.""" | |
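| # A sketch (as comments) of how the schema prefixing described above plays out, | |
| # assuming a database with the placeholder schemas "public" and "archive": | |
| #   SQLData.list_tables()                  # ["archive:OLD_TABLE", "public:TABLE1", ...] | |
| #   SQLData.list_tables(schema="public")   # ["TABLE1", ...] without a prefix | |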
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("sqlalchemy") | |
| from sqlalchemy import inspect | |
| if engine_config is None: | |
| engine_config = {} | |
| engine_meta = cls.resolve_engine( | |
| engine=engine, | |
| engine_name=engine_name, | |
| return_meta=True, | |
| **engine_config, | |
| ) | |
| engine = engine_meta["engine"] | |
| engine_name = engine_meta["engine_name"] | |
| should_dispose = engine_meta["should_dispose"] | |
| if dispose_engine is None: | |
| dispose_engine = should_dispose | |
| schema = cls.resolve_engine_setting(schema, "schema", engine_name=engine_name) | |
| if schema is None: | |
| schemas = cls.list_schemas( | |
| pattern=schema_pattern, | |
| use_regex=use_regex, | |
| sort=sort, | |
| engine=engine, | |
| engine_name=engine_name, | |
| **kwargs, | |
| ) | |
| if len(schemas) == 0: | |
| schemas = [None] | |
| prefix_schema = False | |
| elif len(schemas) == 1 and schemas[0] == "main": | |
| prefix_schema = False | |
| else: | |
| prefix_schema = True | |
| elif schema is False: | |
| schemas = [None] | |
| prefix_schema = False | |
| else: | |
| schemas = [schema] | |
| prefix_schema = False | |
| inspector = inspect(engine) | |
| tables = [] | |
| for schema in schemas: | |
| all_tables = inspector.get_table_names(schema, **kwargs) | |
| if incl_views: | |
| try: | |
| all_tables += inspector.get_view_names(schema, **kwargs) | |
| except NotImplementedError: | |
| pass | |
| try: | |
| all_tables += inspector.get_materialized_view_names(schema, **kwargs) | |
| except NotImplementedError: | |
| pass | |
| for table in all_tables: | |
| if table_pattern is not None: | |
| if not cls.key_match(table, table_pattern, use_regex=use_regex): | |
| continue | |
| if prefix_schema and schema is not None: | |
| table = str(schema) + ":" + table | |
| tables.append(table) | |
| if dispose_engine: | |
| engine.dispose() | |
| if sort: | |
| return sorted(dict.fromkeys(tables)) | |
| return list(dict.fromkeys(tables)) | |
| @classmethod | |
| def has_schema( | |
| cls, | |
| schema: str, | |
| engine: tp.Union[None, str, EngineT] = None, | |
| engine_name: tp.Optional[str] = None, | |
| engine_config: tp.KwargsLike = None, | |
| ) -> bool: | |
| """Check whether the database has a schema.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("sqlalchemy") | |
| from sqlalchemy import inspect | |
| if engine_config is None: | |
| engine_config = {} | |
| engine = cls.resolve_engine( | |
| engine=engine, | |
| engine_name=engine_name, | |
| **engine_config, | |
| ) | |
| return inspect(engine).has_schema(schema) | |
| @classmethod | |
| def create_schema( | |
| cls, | |
| schema: str, | |
| engine: tp.Union[None, str, EngineT] = None, | |
| engine_name: tp.Optional[str] = None, | |
| engine_config: tp.KwargsLike = None, | |
| ) -> None: | |
| """Create a schema if it doesn't exist yet.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("sqlalchemy") | |
| from sqlalchemy.schema import CreateSchema | |
| if engine_config is None: | |
| engine_config = {} | |
| engine = cls.resolve_engine( | |
| engine=engine, | |
| engine_name=engine_name, | |
| **engine_config, | |
| ) | |
| if not cls.has_schema(schema, engine=engine, engine_name=engine_name): | |
| with engine.connect() as connection: | |
| connection.execute(CreateSchema(schema)) | |
| connection.commit() | |
| @classmethod | |
| def has_table( | |
| cls, | |
| table: str, | |
| schema: tp.Optional[str] = None, | |
| engine: tp.Union[None, str, EngineT] = None, | |
| engine_name: tp.Optional[str] = None, | |
| engine_config: tp.KwargsLike = None, | |
| ) -> bool: | |
| """Check whether the database has a table.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("sqlalchemy") | |
| from sqlalchemy import inspect | |
| if engine_config is None: | |
| engine_config = {} | |
| engine = cls.resolve_engine( | |
| engine=engine, | |
| engine_name=engine_name, | |
| **engine_config, | |
| ) | |
| return inspect(engine).has_table(table, schema=schema) | |
| @classmethod | |
| def get_table_relation( | |
| cls, | |
| table: str, | |
| schema: tp.Optional[str] = None, | |
| engine: tp.Union[None, str, EngineT] = None, | |
| engine_name: tp.Optional[str] = None, | |
| engine_config: tp.KwargsLike = None, | |
| ) -> TableT: | |
| """Get table relation.""" | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("sqlalchemy") | |
| from sqlalchemy import MetaData | |
| if engine_config is None: | |
| engine_config = {} | |
| engine = cls.resolve_engine( | |
| engine=engine, | |
| engine_name=engine_name, | |
| **engine_config, | |
| ) | |
| schema = cls.resolve_engine_setting(schema, "schema", engine_name=engine_name) | |
| metadata_obj = MetaData() | |
| metadata_obj.reflect(bind=engine, schema=schema, only=[table], views=True) | |
| if schema is not None and schema + "." + table in metadata_obj.tables: | |
| return metadata_obj.tables[schema + "." + table] | |
| return metadata_obj.tables[table] | |
| @classmethod | |
| def get_last_row_number( | |
| cls, | |
| table: str, | |
| schema: tp.Optional[str] = None, | |
| row_number_column: tp.Optional[str] = None, | |
| engine: tp.Union[None, str, EngineT] = None, | |
| engine_name: tp.Optional[str] = None, | |
| engine_config: tp.KwargsLike = None, | |
| ) -> int: | |
| """Get last row number.""" | |
| if engine_config is None: | |
| engine_config = {} | |
| engine_meta = cls.resolve_engine( | |
| engine=engine, | |
| engine_name=engine_name, | |
| return_meta=True, | |
| **engine_config, | |
| ) | |
| engine = engine_meta["engine"] | |
| engine_name = engine_meta["engine_name"] | |
| row_number_column = cls.resolve_engine_setting( | |
| row_number_column, | |
| "row_number_column", | |
| engine_name=engine_name, | |
| ) | |
| table_relation = cls.get_table_relation(table, schema=schema, engine=engine, engine_name=engine_name) | |
| table_column_names = [] | |
| for column in table_relation.columns: | |
| table_column_names.append(column.name) | |
| if row_number_column not in table_column_names: | |
| raise ValueError(f"Row number column '{row_number_column}' not found") | |
| query = ( | |
| table_relation.select() | |
| .with_only_columns(table_relation.columns.get(row_number_column)) | |
| .order_by(table_relation.columns.get(row_number_column).desc()) | |
| .limit(1) | |
| ) | |
| with engine.connect() as connection: | |
| results = connection.execute(query) | |
| last_row_number = results.first()[0] | |
| connection.commit() | |
| return last_row_number | |
| @classmethod | |
| def resolve_keys_meta( | |
| cls, | |
| keys: tp.Union[None, dict, tp.MaybeKeys] = None, | |
| keys_are_features: tp.Optional[bool] = None, | |
| features: tp.Union[None, dict, tp.MaybeFeatures] = None, | |
| symbols: tp.Union[None, dict, tp.MaybeSymbols] = None, | |
| schema: tp.Optional[str] = None, | |
| list_tables_kwargs: tp.KwargsLike = None, | |
| engine: tp.Union[None, str, EngineT] = None, | |
| engine_name: tp.Optional[str] = None, | |
| engine_config: tp.KwargsLike = None, | |
| ) -> tp.Kwargs: | |
| keys_meta = DBData.resolve_keys_meta( | |
| keys=keys, | |
| keys_are_features=keys_are_features, | |
| features=features, | |
| symbols=symbols, | |
| ) | |
| if keys_meta["keys"] is None: | |
| if cls.has_key_dict(schema): | |
| raise ValueError("Cannot populate keys if schema is defined per key") | |
| if cls.has_key_dict(list_tables_kwargs): | |
| raise ValueError("Cannot populate keys if list_tables_kwargs is defined per key") | |
| if cls.has_key_dict(engine): | |
| raise ValueError("Cannot populate keys if engine is defined per key") | |
| if cls.has_key_dict(engine_name): | |
| raise ValueError("Cannot populate keys if engine_name is defined per key") | |
| if cls.has_key_dict(engine_config): | |
| raise ValueError("Cannot populate keys if engine_config is defined per key") | |
| if list_tables_kwargs is None: | |
| list_tables_kwargs = {} | |
| keys_meta["keys"] = cls.list_tables( | |
| schema=schema, | |
| engine=engine, | |
| engine_name=engine_name, | |
| engine_config=engine_config, | |
| **list_tables_kwargs, | |
| ) | |
| return keys_meta | |
| @classmethod | |
| def pull( | |
| cls: tp.Type[SQLDataT], | |
| keys: tp.Union[tp.MaybeKeys] = None, | |
| *, | |
| keys_are_features: tp.Optional[bool] = None, | |
| features: tp.Union[tp.MaybeFeatures] = None, | |
| symbols: tp.Union[tp.MaybeSymbols] = None, | |
| schema: tp.Optional[str] = None, | |
| list_tables_kwargs: tp.KwargsLike = None, | |
| engine: tp.Union[None, str, EngineT] = None, | |
| engine_name: tp.Optional[str] = None, | |
| engine_config: tp.KwargsLike = None, | |
| dispose_engine: tp.Optional[bool] = None, | |
| share_engine: tp.Optional[bool] = None, | |
| **kwargs, | |
| ) -> SQLDataT: | |
| """Override `vectorbtpro.data.base.Data.pull` to resolve and share the engine among the keys | |
| and use the table names available in the database in case no keys were provided.""" | |
| if share_engine is None: | |
| if ( | |
| not cls.has_key_dict(engine) | |
| and not cls.has_key_dict(engine_name) | |
| and not cls.has_key_dict(engine_config) | |
| ): | |
| share_engine = True | |
| else: | |
| share_engine = False | |
| if share_engine: | |
| if engine_config is None: | |
| engine_config = {} | |
| engine_meta = cls.resolve_engine( | |
| engine=engine, | |
| engine_name=engine_name, | |
| return_meta=True, | |
| **engine_config, | |
| ) | |
| engine = engine_meta["engine"] | |
| engine_name = engine_meta["engine_name"] | |
| should_dispose = engine_meta["should_dispose"] | |
| if dispose_engine is None: | |
| dispose_engine = should_dispose | |
| else: | |
| engine_name = None | |
| keys_meta = cls.resolve_keys_meta( | |
| keys=keys, | |
| keys_are_features=keys_are_features, | |
| features=features, | |
| symbols=symbols, | |
| schema=schema, | |
| list_tables_kwargs=list_tables_kwargs, | |
| engine=engine, | |
| engine_name=engine_name, | |
| engine_config=engine_config, | |
| ) | |
| keys = keys_meta["keys"] | |
| keys_are_features = keys_meta["keys_are_features"] | |
| outputs = super(DBData, cls).pull( | |
| keys, | |
| keys_are_features=keys_are_features, | |
| schema=schema, | |
| engine=engine, | |
| engine_name=engine_name, | |
| engine_config=engine_config, | |
| dispose_engine=False if share_engine else dispose_engine, | |
| **kwargs, | |
| ) | |
| if share_engine and dispose_engine: | |
| engine.dispose() | |
| return outputs | |
| @classmethod | |
| def fetch_key( | |
| cls, | |
| key: str, | |
| table: tp.Union[None, str, TableT] = None, | |
| schema: tp.Optional[str] = None, | |
| query: tp.Union[None, str, SelectableT] = None, | |
| engine: tp.Union[None, str, EngineT] = None, | |
| engine_name: tp.Optional[str] = None, | |
| engine_config: tp.KwargsLike = None, | |
| dispose_engine: tp.Optional[bool] = None, | |
| start: tp.Optional[tp.Any] = None, | |
| end: tp.Optional[tp.Any] = None, | |
| align_dates: tp.Optional[bool] = None, | |
| parse_dates: tp.Union[None, bool, tp.List[tp.IntStr], tp.Dict[tp.IntStr, tp.Any]] = None, | |
| to_utc: tp.Union[None, bool, str, tp.Sequence[str]] = None, | |
| tz: tp.TimezoneLike = None, | |
| start_row: tp.Optional[int] = None, | |
| end_row: tp.Optional[int] = None, | |
| keep_row_number: tp.Optional[bool] = None, | |
| row_number_column: tp.Optional[str] = None, | |
| index_col: tp.Union[None, bool, tp.MaybeList[tp.IntStr]] = None, | |
| columns: tp.Optional[tp.MaybeList[tp.IntStr]] = None, | |
| dtype: tp.Union[None, tp.DTypeLike, tp.Dict[tp.IntStr, tp.DTypeLike]] = None, | |
| chunksize: tp.Optional[int] = None, | |
| chunk_func: tp.Optional[tp.Callable] = None, | |
| squeeze: tp.Optional[bool] = None, | |
| **read_sql_kwargs, | |
| ) -> tp.KeyData: | |
| """Fetch a feature or symbol from a SQL database. | |
| Can use a table name (which defaults to the key) or a custom query. | |
| Args: | |
| key (str): Feature or symbol. | |
| If `table` and `query` are both None, becomes the table name. | |
| Key can be in the `SCHEMA:TABLE` format, in this case `schema` argument will be ignored. | |
| table (str or Table): Table name or actual object. | |
| Cannot be used together with `query`. | |
| schema (str): Schema. | |
| Cannot be used together with `query`. | |
| query (str or Selectable): Custom query. | |
| Cannot be used together with `table` and `schema`. | |
| engine (str or object): See `SQLData.resolve_engine`. | |
| engine_name (str): See `SQLData.resolve_engine`. | |
| engine_config (dict): See `SQLData.resolve_engine`. | |
| dispose_engine (bool): See `SQLData.resolve_engine`. | |
| start (any): Start datetime (if datetime index) or any other start value. | |
| Will parse with `vectorbtpro.utils.datetime_.to_timestamp` if `align_dates` is True | |
| and the index is a datetime index. Otherwise, you must ensure the correct type is provided. | |
| If the index is a multi-index, start value must be a tuple. | |
| Cannot be used together with `query`. Include the condition in the query instead. | |
| end (any): End datetime (if datetime index) or any other end value. | |
| Will parse with `vectorbtpro.utils.datetime_.to_timestamp` if `align_dates` is True | |
| and the index is a datetime index. Otherwise, you must ensure the correct type is provided. | |
| If the index is a multi-index, end value must be a tuple. | |
| Cannot be used together with `query`. Include the condition in the query instead. | |
| align_dates (bool): Whether to align `start` and `end` to the timezone of the index. | |
| Will pull one row (using `LIMIT 1`) and use `SQLData.prepare_dt` to get the index. | |
| parse_dates (bool, list, or dict): Whether to parse dates and how to do it. | |
| If `query` is not used, will get mapped into column names. Otherwise, | |
| integers are not allowed and column names must be used directly. | |
| If enabled, will also try to parse the datetime columns that couldn't be parsed | |
| by Pandas after the object has been fetched. | |
| For dict format, see `pd.read_sql_query`. | |
| to_utc (bool, str, or sequence of str): See `SQLData.prepare_dt`. | |
| tz (any): Timezone. | |
| See `vectorbtpro.utils.datetime_.to_timezone`. | |
| start_row (int): Start row. | |
| Table must contain the column defined in `row_number_column`. | |
| Cannot be used together with `query`. Include the condition in the query instead. | |
| end_row (int): End row. | |
| Table must contain the column defined in `row_number_column`. | |
| Cannot be used together with `query`. Include the condition in the query instead. | |
| keep_row_number (bool): Whether to return the column defined in `row_number_column`. | |
| row_number_column (str): Name of the column with row numbers. | |
| index_col (int, str, or list): One or more columns that should become the index. | |
| If `query` is not used, will get mapped into column names. Otherwise, | |
| integers are not allowed and column names must be used directly. | |
| columns (int, str, or list): One or more columns to select. | |
| Will get mapped into column names. Cannot be used together with `query`. | |
| dtype (dtype_like or dict): Data type of each column. | |
| If `query` is not used, will get mapped into column names. Otherwise, | |
| integers are not allowed and column names must be used directly. | |
| For dict format, see `pd.read_sql_query`. | |
| chunksize (int): See `pd.read_sql_query`. | |
| chunk_func (callable): Function to select and concatenate chunks from `Iterator`. | |
| Gets called only if `chunksize` is set. | |
| squeeze (bool): Whether to squeeze a DataFrame with one column into a Series. | |
| **read_sql_kwargs: Other keyword arguments passed to `pd.read_sql_query`. | |
| See https://pandas.pydata.org/docs/reference/api/pandas.read_sql_query.html for other arguments. | |
| For defaults, see `custom.sql` in `vectorbtpro._settings.data`. | |
| Global settings can be provided per engine name using the `engines` dictionary. | |
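| Usage: | |
| * A minimal, hypothetical sketch (the engine URL, table, and column names below are | |
| placeholders, not objects shipped with this module): | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> data = vbt.SQLData.pull( | |
| ... "Table", | |
| ... engine="sqlite:///my_database.db", | |
| ... index_col="Date", | |
| ... start="2021-01-01", | |
| ... end="2022-01-01" | |
| ... ) | |
| ``` | |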
| """ | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("sqlalchemy") | |
| from sqlalchemy import Selectable, Select, FromClause, and_, text | |
| if engine_config is None: | |
| engine_config = {} | |
| engine_meta = cls.resolve_engine( | |
| engine=engine, | |
| engine_name=engine_name, | |
| return_meta=True, | |
| **engine_config, | |
| ) | |
| engine = engine_meta["engine"] | |
| engine_name = engine_meta["engine_name"] | |
| should_dispose = engine_meta["should_dispose"] | |
| if dispose_engine is None: | |
| dispose_engine = should_dispose | |
| if table is not None and query is not None: | |
| raise ValueError("Must provide either table name or query, not both") | |
| if schema is not None and query is not None: | |
| raise ValueError("Schema cannot be applied to custom queries") | |
| if table is None and query is None: | |
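| # The key itself may encode the schema using the "schema:table" format | |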
| if ":" in key: | |
| schema, table = key.split(":") | |
| else: | |
| table = key | |
| start = cls.resolve_engine_setting(start, "start", engine_name=engine_name) | |
| end = cls.resolve_engine_setting(end, "end", engine_name=engine_name) | |
| align_dates = cls.resolve_engine_setting(align_dates, "align_dates", engine_name=engine_name) | |
| parse_dates = cls.resolve_engine_setting(parse_dates, "parse_dates", engine_name=engine_name) | |
| to_utc = cls.resolve_engine_setting(to_utc, "to_utc", engine_name=engine_name) | |
| tz = cls.resolve_engine_setting(tz, "tz", engine_name=engine_name) | |
| start_row = cls.resolve_engine_setting(start_row, "start_row", engine_name=engine_name) | |
| end_row = cls.resolve_engine_setting(end_row, "end_row", engine_name=engine_name) | |
| keep_row_number = cls.resolve_engine_setting(keep_row_number, "keep_row_number", engine_name=engine_name) | |
| row_number_column = cls.resolve_engine_setting(row_number_column, "row_number_column", engine_name=engine_name) | |
| index_col = cls.resolve_engine_setting(index_col, "index_col", engine_name=engine_name) | |
| columns = cls.resolve_engine_setting(columns, "columns", engine_name=engine_name) | |
| dtype = cls.resolve_engine_setting(dtype, "dtype", engine_name=engine_name) | |
| chunksize = cls.resolve_engine_setting(chunksize, "chunksize", engine_name=engine_name) | |
| chunk_func = cls.resolve_engine_setting(chunk_func, "chunk_func", engine_name=engine_name) | |
| squeeze = cls.resolve_engine_setting(squeeze, "squeeze", engine_name=engine_name) | |
| read_sql_kwargs = cls.resolve_engine_setting( | |
| read_sql_kwargs, "read_sql_kwargs", merge=True, engine_name=engine_name | |
| ) | |
| if query is None or isinstance(query, (Selectable, FromClause)): | |
| if query is None: | |
| if isinstance(table, str): | |
| table = cls.get_table_relation(table, schema=schema, engine=engine, engine_name=engine_name) | |
| else: | |
| table = query | |
| table_column_names = [] | |
| for column in table.columns: | |
| table_column_names.append(column.name) | |
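| # Resolve integer positions and case-insensitively matched names into actual table column names | |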
| def _resolve_columns(c): | |
| if checks.is_int(c): | |
| c = table_column_names[int(c)] | |
| elif not isinstance(c, str): | |
| new_c = [] | |
| for _c in c: | |
| if checks.is_int(_c): | |
| new_c.append(table_column_names[int(_c)]) | |
| else: | |
| if _c not in table_column_names: | |
| for __c in table_column_names: | |
| if _c.lower() == __c.lower(): | |
| _c = __c | |
| break | |
| new_c.append(_c) | |
| c = new_c | |
| else: | |
| if c not in table_column_names: | |
| for _c in table_column_names: | |
| if c.lower() == _c.lower(): | |
| return _c | |
| return c | |
| if index_col is False: | |
| index_col = None | |
| if index_col is not None: | |
| index_col = _resolve_columns(index_col) | |
| if isinstance(index_col, str): | |
| index_col = [index_col] | |
| if columns is not None: | |
| columns = _resolve_columns(columns) | |
| if isinstance(columns, str): | |
| columns = [columns] | |
| if parse_dates is not None: | |
| if not isinstance(parse_dates, bool): | |
| if isinstance(parse_dates, dict): | |
| parse_dates = dict(zip(_resolve_columns(parse_dates.keys()), parse_dates.values())) | |
| else: | |
| parse_dates = _resolve_columns(parse_dates) | |
| if isinstance(parse_dates, str): | |
| parse_dates = [parse_dates] | |
| if dtype is not None: | |
| if isinstance(dtype, dict): | |
| dtype = dict(zip(_resolve_columns(dtype.keys()), dtype.values())) | |
| if not isinstance(table, Select): | |
| query = table.select() | |
| else: | |
| query = table | |
| if index_col is not None: | |
| for col in index_col: | |
| query = query.order_by(col) | |
| if index_col is not None and columns is not None: | |
| pre_columns = [] | |
| for col in index_col: | |
| if col not in columns: | |
| pre_columns.append(col) | |
| columns = pre_columns + columns | |
| if keep_row_number and columns is not None: | |
| if row_number_column in table_column_names and row_number_column not in columns: | |
| columns = [row_number_column] + columns | |
| elif not keep_row_number and columns is None: | |
| if row_number_column in table_column_names: | |
| columns = [col for col in table_column_names if col != row_number_column] | |
| if columns is not None: | |
| query = query.with_only_columns(*[table.columns.get(c) for c in columns]) | |
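| # Convert NumPy scalars to native Python types so SQLAlchemy can compare them | |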
| def _to_native_type(x): | |
| if checks.is_np_scalar(x): | |
| return x.item() | |
| return x | |
| if start_row is not None or end_row is not None: | |
| if start is not None or end is not None: | |
| raise ValueError("Can either filter by row numbers or by index, not both") | |
| _row_number_column = table.columns.get(row_number_column) | |
| if _row_number_column is None: | |
| raise ValueError(f"Row number column '{row_number_column}' not found") | |
| and_list = [] | |
| if start_row is not None: | |
| and_list.append(_row_number_column >= _to_native_type(start_row)) | |
| if end_row is not None: | |
| and_list.append(_row_number_column < _to_native_type(end_row)) | |
| query = query.where(and_(*and_list)) | |
| if start is not None or end is not None: | |
| if index_col is None: | |
| raise ValueError("Must provide index column for filtering by start and end") | |
| if align_dates: | |
| first_obj = pd.read_sql_query( | |
| query.limit(1), | |
| engine.connect(), | |
| index_col=index_col, | |
| parse_dates=None if isinstance(parse_dates, bool) else parse_dates, # bool not accepted | |
| dtype=dtype, | |
| chunksize=None, | |
| **read_sql_kwargs, | |
| ) | |
| first_obj = cls.prepare_dt( | |
| first_obj, | |
| parse_dates=list(parse_dates) if isinstance(parse_dates, dict) else parse_dates, | |
| to_utc=False, | |
| ) | |
| if isinstance(first_obj.index, pd.DatetimeIndex): | |
| if tz is None: | |
| tz = first_obj.index.tz | |
| if first_obj.index.tz is not None: | |
| if start is not None: | |
| start = dt.to_tzaware_datetime(start, naive_tz=tz, tz=first_obj.index.tz) | |
| if end is not None: | |
| end = dt.to_tzaware_datetime(end, naive_tz=tz, tz=first_obj.index.tz) | |
| else: | |
| if start is not None: | |
| if ( | |
| to_utc is True | |
| or (isinstance(to_utc, str) and to_utc.lower() == "index") | |
| or (checks.is_sequence(to_utc) and first_obj.index.name in to_utc) | |
| ): | |
| start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc") | |
| start = dt.to_naive_datetime(start) | |
| else: | |
| start = dt.to_naive_datetime(start, tz=tz) | |
| if end is not None: | |
| if ( | |
| to_utc is True | |
| or (isinstance(to_utc, str) and to_utc.lower() == "index") | |
| or (checks.is_sequence(to_utc) and first_obj.index.name in to_utc) | |
| ): | |
| end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc") | |
| end = dt.to_naive_datetime(end) | |
| else: | |
| end = dt.to_naive_datetime(end, tz=tz) | |
| and_list = [] | |
| if start is not None: | |
| if len(index_col) > 1: | |
| if not isinstance(start, tuple): | |
| raise TypeError("Start must be a tuple if the index is a multi-index") | |
| if len(start) != len(index_col): | |
| raise ValueError("Start tuple must match the number of levels in the multi-index") | |
| for i in range(len(index_col)): | |
| index_column = table.columns.get(index_col[i]) | |
| and_list.append(index_column >= _to_native_type(start[i])) | |
| else: | |
| index_column = table.columns.get(index_col[0]) | |
| and_list.append(index_column >= _to_native_type(start)) | |
| if end is not None: | |
| if len(index_col) > 1: | |
| if not isinstance(end, tuple): | |
| raise TypeError("End must be a tuple if the index is a multi-index") | |
| if len(end) != len(index_col): | |
| raise ValueError("End tuple must match the number of levels in the multi-index") | |
| for i in range(len(index_col)): | |
| index_column = table.columns.get(index_col[i]) | |
| and_list.append(index_column < _to_native_type(end[i])) | |
| else: | |
| index_column = table.columns.get(index_col[0]) | |
| and_list.append(index_column < _to_native_type(end)) | |
| query = query.where(and_(*and_list)) | |
| else: | |
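| # With a custom query there is no table to resolve columns from, so integer positions are rejected | |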
| def _check_columns(c, arg_name): | |
| if checks.is_int(c): | |
| raise ValueError(f"Must provide column as a string for '{arg_name}'") | |
| elif not isinstance(c, str): | |
| for _c in c: | |
| if checks.is_int(_c): | |
| raise ValueError(f"Must provide each column as a string for '{arg_name}'") | |
| if start is not None: | |
| raise ValueError("Start cannot be applied to custom queries") | |
| if end is not None: | |
| raise ValueError("End cannot be applied to custom queries") | |
| if start_row is not None: | |
| raise ValueError("Start row cannot be applied to custom queries") | |
| if end_row is not None: | |
| raise ValueError("End row cannot be applied to custom queries") | |
| if index_col is False: | |
| index_col = None | |
| if index_col is not None: | |
| _check_columns(index_col, "index_col") | |
| if isinstance(index_col, str): | |
| index_col = [index_col] | |
| if columns is not None: | |
| raise ValueError("Columns cannot be applied to custom queries") | |
| if parse_dates is not None: | |
| if not isinstance(parse_dates, bool): | |
| if isinstance(parse_dates, dict): | |
| _check_columns(parse_dates.keys(), "parse_dates") | |
| else: | |
| _check_columns(parse_dates, "parse_dates") | |
| if isinstance(parse_dates, str): | |
| parse_dates = [parse_dates] | |
| if dtype is not None: | |
| _check_columns(dtype.keys(), "dtype") | |
| if isinstance(query, str): | |
| query = text(query) | |
| obj = pd.read_sql_query( | |
| query, | |
| engine.connect(), | |
| index_col=index_col, | |
| parse_dates=None if isinstance(parse_dates, bool) else parse_dates, # bool not accepted | |
| dtype=dtype, | |
| chunksize=chunksize, | |
| **read_sql_kwargs, | |
| ) | |
| if isinstance(obj, Iterator): | |
| if chunk_func is None: | |
| obj = pd.concat(list(obj), axis=0) | |
| else: | |
| obj = chunk_func(obj) | |
| obj = cls.prepare_dt( | |
| obj, | |
| parse_dates=list(parse_dates) if isinstance(parse_dates, dict) else parse_dates, | |
| to_utc=to_utc, | |
| ) | |
| if not isinstance(obj.index, pd.MultiIndex): | |
| if obj.index.name == "index": | |
| obj.index.name = None | |
| if isinstance(obj.index, pd.DatetimeIndex) and tz is None: | |
| tz = obj.index.tz | |
| if isinstance(obj, pd.DataFrame) and squeeze: | |
| obj = obj.squeeze("columns") | |
| if isinstance(obj, pd.Series) and obj.name == "0": | |
| obj.name = None | |
| if dispose_engine: | |
| engine.dispose() | |
| if keep_row_number: | |
| return obj, dict(tz=tz, row_number_column=row_number_column) | |
| return obj, dict(tz=tz) | |
| @classmethod | |
| def fetch_feature(cls, feature: str, **kwargs) -> tp.FeatureData: | |
| """Fetch the table of a feature. | |
| Uses `SQLData.fetch_key`.""" | |
| return cls.fetch_key(feature, **kwargs) | |
| @classmethod | |
| def fetch_symbol(cls, symbol: str, **kwargs) -> tp.SymbolData: | |
| """Fetch the table for a symbol. | |
| Uses `SQLData.fetch_key`.""" | |
| return cls.fetch_key(symbol, **kwargs) | |
| def update_key( | |
| self, | |
| key: str, | |
| from_last_row: tp.Optional[bool] = None, | |
| from_last_index: tp.Optional[bool] = None, | |
| **kwargs, | |
| ) -> tp.KeyData: | |
| """Update data of a feature or symbol.""" | |
| fetch_kwargs = self.select_fetch_kwargs(key) | |
| returned_kwargs = self.select_returned_kwargs(key) | |
| pre_kwargs = merge_dicts(fetch_kwargs, kwargs) | |
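| # Decide how to resume the update: by the last stored row number (from_last_row) or by the | |
| # last index value (from_last_index); a custom query disables both, and explicit start/end | |
| # or row bounds disable the corresponding mode | |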
| if from_last_row is None: | |
| if pre_kwargs.get("query", None) is not None: | |
| from_last_row = False | |
| elif from_last_index is True: | |
| from_last_row = False | |
| elif pre_kwargs.get("start", None) is not None or pre_kwargs.get("end", None) is not None: | |
| from_last_row = False | |
| elif "row_number_column" not in returned_kwargs: | |
| from_last_row = False | |
| elif returned_kwargs["row_number_column"] not in self.wrapper.columns: | |
| from_last_row = False | |
| else: | |
| from_last_row = True | |
| if from_last_index is None: | |
| if pre_kwargs.get("query", None) is not None: | |
| from_last_index = False | |
| elif from_last_row is True: | |
| from_last_index = False | |
| elif pre_kwargs.get("start_row", None) is not None or pre_kwargs.get("end_row", None) is not None: | |
| from_last_index = False | |
| else: | |
| from_last_index = True | |
| if from_last_row: | |
| if "row_number_column" not in returned_kwargs: | |
| raise ValueError("Argument row_number_column must be in returned_kwargs for from_last_row") | |
| row_number_column = returned_kwargs["row_number_column"] | |
| fetch_kwargs["start_row"] = self.data[key][row_number_column].iloc[-1] | |
| if from_last_index: | |
| fetch_kwargs["start"] = self.select_last_index(key) | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| if self.feature_oriented: | |
| return self.fetch_feature(key, **kwargs) | |
| return self.fetch_symbol(key, **kwargs) | |
| def update_feature(self, feature: str, **kwargs) -> tp.FeatureData: | |
| """Update data of a feature. | |
| Uses `SQLData.update_key`.""" | |
| return self.update_key(feature, **kwargs) | |
| def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData: | |
| """Update data for a symbol. | |
| Uses `SQLData.update_key`.""" | |
| return self.update_key(symbol, **kwargs) | |
| </file> | |
| <file path="data/custom/synthetic.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `SyntheticData`.""" | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.custom import CustomData | |
| from vectorbtpro.utils import datetime_ as dt | |
| from vectorbtpro.utils.config import merge_dicts | |
| __all__ = [ | |
| "SyntheticData", | |
| ] | |
| __pdoc__ = {} | |
| class SyntheticData(CustomData): | |
| """Data class for fetching synthetic data. | |
| Exposes an abstract class method `SyntheticData.generate_symbol`. | |
| Everything else is taken care of.""" | |
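| # A minimal, hypothetical subclass sketch (not part of this module) showing how | |
| # `generate_symbol` is meant to be overridden; assumes NumPy and Pandas are imported as np and pd: | |
| # | |
| # class MyNoiseData(SyntheticData): | |
| #     @classmethod | |
| #     def generate_symbol(cls, symbol, index, mean=0.0, std=1.0, **kwargs): | |
| #         # Return data aligned to the generated datetime index | |
| #         return pd.Series(np.random.normal(mean, std, len(index)), index=index) | |
| # | |
| # MyNoiseData.pull("NOISE", start="2020-01-01", end="2020-02-01", timeframe="1 day") | |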
| _settings_path: tp.SettingsPath = dict(custom="data.custom.synthetic") | |
| @classmethod | |
| def generate_key(cls, key: tp.Key, index: tp.Index, key_is_feature: bool = False, **kwargs) -> tp.KeyData: | |
| """Abstract method to generate data of a feature or symbol.""" | |
| raise NotImplementedError | |
| @classmethod | |
| def generate_feature(cls, feature: tp.Feature, index: tp.Index, **kwargs) -> tp.FeatureData: | |
| """Abstract method to generate data of a feature. | |
| Uses `SyntheticData.generate_key` with `key_is_feature=True`.""" | |
| return cls.generate_key(feature, index, key_is_feature=True, **kwargs) | |
| @classmethod | |
| def generate_symbol(cls, symbol: tp.Symbol, index: tp.Index, **kwargs) -> tp.SymbolData: | |
| """Abstract method to generate data for a symbol. | |
| Uses `SyntheticData.generate_key` with `key_is_feature=False`.""" | |
| return cls.generate_key(symbol, index, key_is_feature=False, **kwargs) | |
| @classmethod | |
| def fetch_key( | |
| cls, | |
| key: tp.Symbol, | |
| key_is_feature: bool = False, | |
| start: tp.Optional[tp.DatetimeLike] = None, | |
| end: tp.Optional[tp.DatetimeLike] = None, | |
| periods: tp.Optional[int] = None, | |
| timeframe: tp.Optional[tp.FrequencyLike] = None, | |
| tz: tp.TimezoneLike = None, | |
| normalize: tp.Optional[bool] = None, | |
| inclusive: tp.Optional[str] = None, | |
| **kwargs, | |
| ) -> tp.KeyData: | |
| """Generate data of a feature or symbol. | |
| Generates datetime index using `vectorbtpro.utils.datetime_.date_range` and passes it to | |
| `SyntheticData.generate_key` to fill the Series/DataFrame with generated data. | |
| For defaults, see `custom.synthetic` in `vectorbtpro._settings.data`.""" | |
| start = cls.resolve_custom_setting(start, "start") | |
| end = cls.resolve_custom_setting(end, "end") | |
| timeframe = cls.resolve_custom_setting(timeframe, "timeframe") | |
| tz = cls.resolve_custom_setting(tz, "tz") | |
| normalize = cls.resolve_custom_setting(normalize, "normalize") | |
| inclusive = cls.resolve_custom_setting(inclusive, "inclusive") | |
| index = dt.date_range( | |
| start=start, | |
| end=end, | |
| periods=periods, | |
| freq=timeframe, | |
| normalize=normalize, | |
| inclusive=inclusive, | |
| ) | |
| if tz is None: | |
| tz = index.tz | |
| if len(index) == 0: | |
| raise ValueError("Date range is empty") | |
| if key_is_feature: | |
| return cls.generate_feature(key, index, **kwargs), dict(tz=tz, freq=timeframe) | |
| return cls.generate_symbol(key, index, **kwargs), dict(tz=tz, freq=timeframe) | |
| @classmethod | |
| def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData: | |
| """Generate data of a feature. | |
| Uses `SyntheticData.fetch_key` with `key_is_feature=True`.""" | |
| return cls.fetch_key(feature, key_is_feature=True, **kwargs) | |
| @classmethod | |
| def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData: | |
| """Generate data for a symbol. | |
| Uses `SyntheticData.fetch_key` with `key_is_feature=False`.""" | |
| return cls.fetch_key(symbol, key_is_feature=False, **kwargs) | |
| def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData: | |
| """Update data of a feature or symbol.""" | |
| fetch_kwargs = self.select_fetch_kwargs(key) | |
| fetch_kwargs["start"] = self.select_last_index(key) | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| if key_is_feature: | |
| return self.fetch_feature(key, **kwargs) | |
| return self.fetch_symbol(key, **kwargs) | |
| def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData: | |
| """Update data of a feature. | |
| Uses `SyntheticData.update_key` with `key_is_feature=True`.""" | |
| return self.update_key(feature, key_is_feature=True, **kwargs) | |
| def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData: | |
| """Update data for a symbol. | |
| Uses `SyntheticData.update_key` with `key_is_feature=False`.""" | |
| return self.update_key(symbol, key_is_feature=False, **kwargs) | |
| </file> | |
| <file path="data/custom/tv.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `TVData`.""" | |
| import datetime | |
| import json | |
| import math | |
| import random | |
| import re | |
| import string | |
| import time | |
| import pandas as pd | |
| import requests | |
| from websocket import WebSocket | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.remote import RemoteData | |
| from vectorbtpro.utils import datetime_ as dt | |
| from vectorbtpro.utils.config import merge_dicts, Configured | |
| from vectorbtpro.utils.pbar import ProgressBar | |
| from vectorbtpro.utils.template import CustomTemplate | |
| __all__ = [ | |
| "TVClient", | |
| "TVData", | |
| ] | |
| SIGNIN_URL = "https://www.tradingview.com/accounts/signin/" | |
| """Sign-in URL.""" | |
| SEARCH_URL = ( | |
| "https://symbol-search.tradingview.com/symbol_search/v3/?" | |
| "text={text}&" | |
| "start={start}&" | |
| "hl=1&" | |
| "exchange={exchange}&" | |
| "lang=en&" | |
| "search_type=undefined&" | |
| "domain=production&" | |
| "sort_by_country=US" | |
| ) | |
| """Symbol search URL.""" | |
| SCAN_URL = "https://scanner.tradingview.com/{market}/scan" | |
| """Market scanner URL.""" | |
| ORIGIN_URL = "https://data.tradingview.com" | |
| """Origin URL.""" | |
| REFERER_URL = "https://www.tradingview.com" | |
| """Referer URL.""" | |
| WS_URL = "wss://data.tradingview.com/socket.io/websocket" | |
| """Websocket URL.""" | |
| PRO_WS_URL = "wss://prodata.tradingview.com/socket.io/websocket" | |
| """Websocket URL (Pro).""" | |
| WS_TIMEOUT = 5 | |
| """Websocket timeout.""" | |
| MARKET_LIST = [ | |
| "america", | |
| "argentina", | |
| "australia", | |
| "austria", | |
| "bahrain", | |
| "bangladesh", | |
| "belgium", | |
| "brazil", | |
| "canada", | |
| "chile", | |
| "china", | |
| "colombia", | |
| "cyprus", | |
| "czech", | |
| "denmark", | |
| "egypt", | |
| "estonia", | |
| "euronext", | |
| "finland", | |
| "france", | |
| "germany", | |
| "greece", | |
| "hongkong", | |
| "hungary", | |
| "iceland", | |
| "india", | |
| "indonesia", | |
| "israel", | |
| "italy", | |
| "japan", | |
| "kenya", | |
| "korea", | |
| "ksa", | |
| "kuwait", | |
| "latvia", | |
| "lithuania", | |
| "luxembourg", | |
| "malaysia", | |
| "mexico", | |
| "morocco", | |
| "netherlands", | |
| "newzealand", | |
| "nigeria", | |
| "norway", | |
| "pakistan", | |
| "peru", | |
| "philippines", | |
| "poland", | |
| "portugal", | |
| "qatar", | |
| "romania", | |
| "rsa", | |
| "russia", | |
| "serbia", | |
| "singapore", | |
| "slovakia", | |
| "spain", | |
| "srilanka", | |
| "sweden", | |
| "switzerland", | |
| "taiwan", | |
| "thailand", | |
| "tunisia", | |
| "turkey", | |
| "uae", | |
| "uk", | |
| "venezuela", | |
| "vietnam", | |
| ] | |
| """List of markets supported by the market scanner (list may be incomplete).""" | |
| FIELD_LIST = [ | |
| "name", | |
| "description", | |
| "logoid", | |
| "update_mode", | |
| "type", | |
| "typespecs", | |
| "close", | |
| "pricescale", | |
| "minmov", | |
| "fractional", | |
| "minmove2", | |
| "currency", | |
| "change", | |
| "change_abs", | |
| "Recommend.All", | |
| "volume", | |
| "Value.Traded", | |
| "market_cap_basic", | |
| "fundamental_currency_code", | |
| "Perf.1Y.MarketCap", | |
| "price_earnings_ttm", | |
| "earnings_per_share_basic_ttm", | |
| "number_of_employees_fy", | |
| "sector", | |
| "market", | |
| ] | |
| """List of fields supported by the market scanner (list may be incomplete).""" | |
| USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36" | |
| """User agent.""" | |
| class TVClient(Configured): | |
| """Client for TradingView.""" | |
| def __init__( | |
| self, | |
| username: tp.Optional[str] = None, | |
| password: tp.Optional[str] = None, | |
| auth_token: tp.Optional[str] = None, | |
| **kwargs, | |
| ) -> None: | |
| """Client for TradingView.""" | |
| Configured.__init__( | |
| self, | |
| username=username, | |
| password=password, | |
| auth_token=auth_token, | |
| **kwargs, | |
| ) | |
| if auth_token is None: | |
| auth_token = self.auth(username, password) | |
| elif username is not None or password is not None: | |
| raise ValueError("Must provide either username and password, or auth_token") | |
| self._auth_token = auth_token | |
| self._ws = None | |
| self._session = self.generate_session() | |
| self._chart_session = self.generate_chart_session() | |
| @property | |
| def auth_token(self) -> str: | |
| """Authentication token.""" | |
| return self._auth_token | |
| @property | |
| def ws(self) -> WebSocket: | |
| """Instance of `websocket.Websocket`.""" | |
| return self._ws | |
| @property | |
| def session(self) -> str: | |
| """Session.""" | |
| return self._session | |
| @property | |
| def chart_session(self) -> str: | |
| """Chart session.""" | |
| return self._chart_session | |
| @classmethod | |
| def auth( | |
| cls, | |
| username: tp.Optional[str] = None, | |
| password: tp.Optional[str] = None, | |
| ) -> str: | |
| """Authenticate.""" | |
| if username is not None and password is not None: | |
| data = {"username": username, "password": password, "remember": "on"} | |
| headers = {"Referer": REFERER_URL, "User-Agent": USER_AGENT} | |
| response = requests.post(url=SIGNIN_URL, data=data, headers=headers) | |
| response.raise_for_status() | |
| json = response.json() | |
| if "user" not in json or "auth_token" not in json["user"]: | |
| raise ValueError(json) | |
| return json["user"]["auth_token"] | |
| if username is not None or password is not None: | |
| raise ValueError("Must provide both username and password") | |
| return "unauthorized_user_token" | |
| @classmethod | |
| def generate_session(cls) -> str: | |
| """Generate session.""" | |
| string_length = 12 | |
| letters = string.ascii_lowercase | |
| random_string = "".join(random.choice(letters) for _ in range(string_length)) | |
| return "qs_" + random_string | |
| @classmethod | |
| def generate_chart_session(cls) -> str: | |
| """Generate chart session.""" | |
| string_length = 12 | |
| letters = string.ascii_lowercase | |
| random_string = "".join(random.choice(letters) for _ in range(string_length)) | |
| return "cs_" + random_string | |
| def create_connection(self, pro_data: bool = True) -> None: | |
| """Create a websocket connection.""" | |
| from websocket import create_connection | |
| if pro_data: | |
| self._ws = create_connection( | |
| PRO_WS_URL, | |
| headers=json.dumps({"Origin": ORIGIN_URL}), | |
| timeout=WS_TIMEOUT, | |
| ) | |
| else: | |
| self._ws = create_connection( | |
| WS_URL, | |
| headers=json.dumps({"Origin": ORIGIN_URL}), | |
| timeout=WS_TIMEOUT, | |
| ) | |
| @classmethod | |
| def filter_raw_message(cls, text) -> tp.Tuple[str, str]: | |
| """Filter raw message.""" | |
| found = re.search('"m":"(.+?)",', text).group(1) | |
| found2 = re.search('"p":(.+?"}"])}', text).group(1) | |
| return found, found2 | |
| @classmethod | |
| def prepend_header(cls, st: str) -> str: | |
| """Prepend a header.""" | |
| return "~m~" + str(len(st)) + "~m~" + st | |
| @classmethod | |
| def construct_message(cls, func: str, param_list: tp.List[str]) -> str: | |
| """Construct a message.""" | |
| return json.dumps({"m": func, "p": param_list}, separators=(",", ":")) | |
| def create_message(self, func: str, param_list: tp.List[str]) -> str: | |
| """Create a message.""" | |
| return self.prepend_header(self.construct_message(func, param_list)) | |
| def send_message(self, func: str, param_list: tp.List[str]) -> None: | |
| """Send a message.""" | |
| m = self.create_message(func, param_list) | |
| self.ws.send(m) | |
| @classmethod | |
| def convert_raw_data(cls, raw_data: str, symbol: str) -> pd.DataFrame: | |
| """Process raw data into a DataFrame.""" | |
| search_result = re.search(r'"s":\[(.+?)\}\]', raw_data) | |
| if search_result is None: | |
| raise ValueError("Couldn't parse data returned by TradingView: {}".format(raw_data)) | |
| out = search_result.group(1) | |
| x = out.split(',{"') | |
| data = list() | |
| volume_data = True | |
| for xi in x: | |
| xi = re.split(r"\[|:|,|\]", xi) | |
| ts = datetime.datetime.utcfromtimestamp(float(xi[4])) | |
| row = [ts] | |
| for i in range(5, 10): | |
| # skip converting volume data if it does not exist | |
| if not volume_data and i == 9: | |
| row.append(0.0) | |
| continue | |
| try: | |
| row.append(float(xi[i])) | |
| except ValueError: | |
| volume_data = False | |
| row.append(0.0) | |
| data.append(row) | |
| data = pd.DataFrame(data, columns=["datetime", "open", "high", "low", "close", "volume"]) | |
| data = data.set_index("datetime") | |
| data.insert(0, "symbol", value=symbol) | |
| return data | |
| @classmethod | |
| def format_symbol(cls, symbol: str, exchange: str, fut_contract: tp.Optional[int] = None) -> str: | |
| """Format a symbol.""" | |
| if ":" in symbol: | |
| pass | |
| elif fut_contract is None: | |
| symbol = f"{exchange}:{symbol}" | |
| elif isinstance(fut_contract, int): | |
| symbol = f"{exchange}:{symbol}{fut_contract}!" | |
| else: | |
| raise ValueError(f"Invalid fut_contract: '{fut_contract}'") | |
| return symbol | |
| def get_hist( | |
| self, | |
| symbol: str, | |
| exchange: str = "NSE", | |
| interval: str = "1D", | |
| fut_contract: tp.Optional[int] = None, | |
| adjustment: str = "splits", | |
| extended_session: bool = False, | |
| pro_data: bool = True, | |
| limit: int = 20000, | |
| return_raw: bool = False, | |
| ) -> tp.Union[str, tp.Frame]: | |
| """Get historical data.""" | |
| symbol = self.format_symbol(symbol=symbol, exchange=exchange, fut_contract=fut_contract) | |
| backadjustment = False | |
| if symbol.endswith("!A"): | |
| backadjustment = True | |
| symbol = symbol.replace("!A", "!") | |
| self.create_connection(pro_data=pro_data) | |
| self.send_message("set_auth_token", [self.auth_token]) | |
| self.send_message("chart_create_session", [self.chart_session, ""]) | |
| self.send_message("quote_create_session", [self.session]) | |
| self.send_message( | |
| "quote_set_fields", | |
| [ | |
| self.session, | |
| "ch", | |
| "chp", | |
| "current_session", | |
| "description", | |
| "local_description", | |
| "language", | |
| "exchange", | |
| "fractional", | |
| "is_tradable", | |
| "lp", | |
| "lp_time", | |
| "minmov", | |
| "minmove2", | |
| "original_name", | |
| "pricescale", | |
| "pro_name", | |
| "short_name", | |
| "type", | |
| "update_mode", | |
| "volume", | |
| "currency_code", | |
| "rchp", | |
| "rtc", | |
| ], | |
| ) | |
| self.send_message("quote_add_symbols", [self.session, symbol, {"flags": ["force_permission"]}]) | |
| self.send_message("quote_fast_symbols", [self.session, symbol]) | |
| self.send_message( | |
| "resolve_symbol", | |
| [ | |
| self.chart_session, | |
| "symbol_1", | |
| '={"symbol":"' | |
| + symbol | |
| + '","adjustment":"' | |
| + adjustment | |
| + ("" if not backadjustment else '","backadjustment":"default') | |
| + '","session":' | |
| + ('"regular"' if not extended_session else '"extended"') | |
| + "}", | |
| ], | |
| ) | |
| self.send_message("create_series", [self.chart_session, "s1", "s1", "symbol_1", interval, limit]) | |
| self.send_message("switch_timezone", [self.chart_session, "exchange"]) | |
| raw_data = "" | |
| while True: | |
| try: | |
| result = self.ws.recv() | |
| raw_data += result + "\n" | |
| except Exception as e: | |
| break | |
| if "series_completed" in result: | |
| break | |
| if return_raw: | |
| return raw_data | |
| return self.convert_raw_data(raw_data, symbol) | |
| @classmethod | |
| def search_symbol( | |
| cls, | |
| text: tp.Optional[str] = None, | |
| exchange: tp.Optional[str] = None, | |
| pages: tp.Optional[int] = None, | |
| delay: tp.Optional[int] = None, | |
| retries: int = 3, | |
| show_progress: bool = True, | |
| pbar_kwargs: tp.KwargsLike = None, | |
| ) -> tp.List[dict]: | |
| """Search for a symbol.""" | |
| if text is None: | |
| text = "" | |
| if exchange is None: | |
| exchange = "" | |
| if pbar_kwargs is None: | |
| pbar_kwargs = {} | |
| symbols_list = [] | |
| pbar = None | |
| pages_fetched = 0 | |
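| # Page through the search results, retrying failed requests, until no symbols remain or the | |
| # requested number of pages has been fetched | |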
| while True: | |
| for i in range(retries): | |
| try: | |
| url = SEARCH_URL.format(text=text, exchange=exchange.upper(), start=len(symbols_list)) | |
| headers = {"Referer": REFERER_URL, "Origin": ORIGIN_URL, "User-Agent": USER_AGENT} | |
| resp = requests.get(url, headers=headers) | |
| symbols_data = json.loads(resp.text.replace("</em>", "").replace("<em>", "")) | |
| break | |
| except json.JSONDecodeError as e: | |
| if i == retries - 1: | |
| raise e | |
| if delay is not None: | |
| time.sleep(delay) | |
| symbols_remaining = symbols_data.get("symbols_remaining", 0) | |
| new_symbols = symbols_data.get("symbols", []) | |
| symbols_list.extend(new_symbols) | |
| if pages is None and symbols_remaining > 0: | |
| show_pbar = True | |
| elif pages is not None and pages > 1: | |
| show_pbar = True | |
| else: | |
| show_pbar = False | |
| if pbar is None and show_pbar: | |
| if pages is not None: | |
| total = pages | |
| else: | |
| total = math.ceil((len(new_symbols) + symbols_remaining) / len(new_symbols)) | |
| pbar = ProgressBar( | |
| total=total, | |
| show_progress=show_progress, | |
| **pbar_kwargs, | |
| ) | |
| pbar.enter() | |
| if pbar is not None: | |
| max_symbols = len(symbols_list) + symbols_remaining | |
| if pages is not None: | |
| max_symbols = min(max_symbols, pages * len(new_symbols)) | |
| pbar.set_description(dict(symbols="%d/%d" % (len(symbols_list), max_symbols))) | |
| pbar.update() | |
| if symbols_remaining == 0: | |
| break | |
| pages_fetched += 1 | |
| if pages is not None and pages_fetched >= pages: | |
| break | |
| if delay is not None: | |
| time.sleep(delay) | |
| if pbar is not None: | |
| pbar.exit() | |
| return symbols_list | |
| @classmethod | |
| def scan_symbols(cls, market: tp.Optional[str] = None, **kwargs) -> tp.List[dict]: | |
| """Scan symbols in a region/market.""" | |
| if market is None: | |
| market = "global" | |
| url = SCAN_URL.format(market=market.lower()) | |
| headers = {"Referer": REFERER_URL, "Origin": ORIGIN_URL, "User-Agent": USER_AGENT} | |
| resp = requests.post(url, json.dumps(kwargs), headers=headers) | |
| symbols_list = json.loads(resp.text)["data"] | |
| return symbols_list | |
| TVDataT = tp.TypeVar("TVDataT", bound="TVData") | |
| class TVData(RemoteData): | |
| """Data class for fetching from TradingView. | |
| See `TVData.fetch_symbol` for arguments. | |
| !!! note | |
| If you're getting the error "Please confirm that you are not a robot by clicking the captcha box." | |
| when attempting to authenticate, use `auth_token` instead of `username` and `password`. | |
| To get the authentication token, go to TradingView, log in, visit any chart, open your console's | |
| developer tools, and search for "auth_token". | |
| Usage: | |
| * Set up the credentials globally (optional): | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> vbt.TVData.set_custom_settings( | |
| ... client_config=dict( | |
| ... username="YOUR_USERNAME", | |
| ... password="YOUR_PASSWORD", | |
| ... auth_token="YOUR_AUTH_TOKEN", # optional, instead of username and password | |
| ... ) | |
| ... ) | |
| ``` | |
| * Pull data: | |
| ```pycon | |
| >>> data = vbt.TVData.pull( | |
| ... "NASDAQ:AAPL", | |
| ... timeframe="1 hour" | |
| ... ) | |
| ``` | |
| """ | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.tv") | |
| @classmethod | |
| def list_symbols( | |
| cls, | |
| *, | |
| exchange_pattern: tp.Optional[str] = None, | |
| symbol_pattern: tp.Optional[str] = None, | |
| use_regex: bool = False, | |
| sort: bool = True, | |
| client: tp.Optional[TVClient] = None, | |
| client_config: tp.DictLike = None, | |
| text: tp.Optional[str] = None, | |
| exchange: tp.Optional[str] = None, | |
| pages: tp.Optional[int] = None, | |
| delay: tp.Optional[int] = None, | |
| retries: tp.Optional[int] = None, | |
| show_progress: tp.Optional[bool] = None, | |
| pbar_kwargs: tp.KwargsLike = None, | |
| market: tp.Optional[str] = None, | |
| markets: tp.Optional[tp.List[str]] = None, | |
| fields: tp.Optional[tp.MaybeIterable[str]] = None, | |
| filter_by: tp.Union[None, tp.Callable, CustomTemplate] = None, | |
| groups: tp.Optional[tp.MaybeIterable[tp.Dict[str, tp.MaybeIterable[str]]]] = None, | |
| template_context: tp.KwargsLike = None, | |
| return_field_data: bool = False, | |
| **scanner_kwargs, | |
| ) -> tp.Union[tp.List[str], tp.List[tp.Kwargs]]: | |
| """List all symbols. | |
| Uses symbol search when either `text` or `exchange` is provided (returns a subset of symbols). | |
| Otherwise, uses the market scanner (returns all symbols, big payload). | |
| When using the market scanner, use `market` to filter by one or multiple markets. For the list | |
| of available markets, see `MARKET_LIST`. | |
| Use `fields` to make the market scanner return additional information that can be used for | |
| filtering with `filter_by`. Such information is passed to the function as a dictionary where | |
| fields are keys. Instead of a function, `filter_by` can also be a template that receives the | |
| same information as a context, or a list of values to match against the corresponding fields. | |
| For the list of available fields, see `FIELD_LIST`. Argument `fields` can also be "all". | |
| Set `return_field_data` to True to return a list with (filtered) field data. | |
| Use `groups` to provide a single dictionary or a list of dictionaries with groups. | |
| Each dictionary can be provided either in a compressed format, such as `dict(index=index)`, | |
| or in a full format, such as `dict(type="index", values=[index])`. | |
| Keyword arguments `scanner_kwargs` are encoded and passed directly to the market scanner. | |
| Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each exchange against | |
| `exchange_pattern` and each symbol against `symbol_pattern`. | |
| Usage: | |
| * List all symbols (market scanner): | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> vbt.TVData.list_symbols() | |
| ``` | |
| * Search for symbols matching a pattern (market scanner, client-side): | |
| ```pycon | |
| >>> vbt.TVData.list_symbols(symbol_pattern="BTC*") | |
| ``` | |
| * Search for exchanges matching a pattern (market scanner, client-side): | |
| ```pycon | |
| >>> vbt.TVData.list_symbols(exchange_pattern="NASDAQ") | |
| ``` | |
| * Search for symbols containing a text (symbol search, server-side): | |
| ```pycon | |
| >>> vbt.TVData.list_symbols(text="BTC") | |
| ``` | |
| * List symbols from an exchange (symbol search): | |
| ```pycon | |
| >>> vbt.TVData.list_symbols(exchange="NASDAQ") | |
| ``` | |
| * List symbols from a market (market scanner): | |
| ```pycon | |
| >>> vbt.TVData.list_symbols(market="poland") | |
| ``` | |
| * List index constituents (market scanner): | |
| ```pycon | |
| >>> vbt.TVData.list_symbols(groups=dict(index="NASDAQ:NDX")) | |
| ``` | |
| * Filter symbols by fields using a function (market scanner): | |
| ```pycon | |
| >>> vbt.TVData.list_symbols( | |
| ... market="america", | |
| ... fields=["sector"], | |
| ... filter_by=lambda context: context["sector"] == "Technology Services" | |
| ... ) | |
| ``` | |
| * Filter symbols by fields using a template (market scanner): | |
| ```pycon | |
| >>> vbt.TVData.list_symbols( | |
| ... market="america", | |
| ... fields=["sector"], | |
| ... filter_by=vbt.RepEval("sector == 'Technology Services'") | |
| ... ) | |
| ``` | |
| """ | |
| pages = cls.resolve_custom_setting(pages, "pages", sub_path="search", sub_path_only=True) | |
| delay = cls.resolve_custom_setting(delay, "delay", sub_path="search", sub_path_only=True) | |
| retries = cls.resolve_custom_setting(retries, "retries", sub_path="search", sub_path_only=True) | |
| show_progress = cls.resolve_custom_setting( | |
| show_progress, "show_progress", sub_path="search", sub_path_only=True | |
| ) | |
| pbar_kwargs = cls.resolve_custom_setting( | |
| pbar_kwargs, "pbar_kwargs", sub_path="search", sub_path_only=True, merge=True | |
| ) | |
| markets = cls.resolve_custom_setting(markets, "markets", sub_path="scanner", sub_path_only=True) | |
| fields = cls.resolve_custom_setting(fields, "fields", sub_path="scanner", sub_path_only=True) | |
| filter_by = cls.resolve_custom_setting(filter_by, "filter_by", sub_path="scanner", sub_path_only=True) | |
| groups = cls.resolve_custom_setting(groups, "groups", sub_path="scanner", sub_path_only=True) | |
| template_context = cls.resolve_custom_setting( | |
| template_context, "template_context", sub_path="scanner", sub_path_only=True, merge=True | |
| ) | |
| scanner_kwargs = cls.resolve_custom_setting( | |
| scanner_kwargs, "scanner_kwargs", sub_path="scanner", sub_path_only=True, merge=True | |
| ) | |
| if market is None and text is None and exchange is None: | |
| market = "global" | |
| if market is not None and (text is not None or exchange is not None): | |
| raise ValueError("Please provide either market, or text and/or exchange") | |
| if client_config is None: | |
| client_config = {} | |
| client = cls.resolve_client(client=client, **client_config) | |
| if market is None: | |
| data = client.search_symbol( | |
| text=text, | |
| exchange=exchange, | |
| pages=pages, | |
| delay=delay, | |
| retries=retries, | |
| show_progress=show_progress, | |
| pbar_kwargs=pbar_kwargs, | |
| ) | |
| all_symbols = map(lambda x: x["exchange"] + ":" + x["symbol"], data) | |
| return_field_data = False | |
| else: | |
| if markets is not None: | |
| scanner_kwargs["markets"] = markets | |
| if fields is not None: | |
| if "columns" in scanner_kwargs: | |
| raise ValueError("Use fields instead of columns") | |
| if isinstance(fields, str): | |
| if fields.lower() == "all": | |
| fields = FIELD_LIST | |
| else: | |
| fields = [fields] | |
| scanner_kwargs["columns"] = fields | |
| if groups is not None: | |
| if isinstance(groups, dict): | |
| groups = [groups] | |
| new_groups = [] | |
| for group in groups: | |
| if "type" in group: | |
| new_groups.append(group) | |
| else: | |
| for k, v in group.items(): | |
| if isinstance(v, str): | |
| v = [v] | |
| new_groups.append(dict(type=k, values=v)) | |
| groups = new_groups | |
| if "symbols" in scanner_kwargs: | |
| scanner_kwargs["symbols"] = dict(scanner_kwargs["symbols"]) | |
| else: | |
| scanner_kwargs["symbols"] = dict() | |
| scanner_kwargs["symbols"]["groups"] = groups | |
| if filter_by is not None: | |
| if isinstance(filter_by, str): | |
| filter_by = [filter_by] | |
| data = client.scan_symbols(market.lower(), **scanner_kwargs) | |
| if data is None: | |
| raise ValueError("No data returned by TradingView") | |
| all_symbols = [] | |
| for item in data: | |
| if fields is not None: | |
| item = {"symbol": item["s"], **dict(zip(fields, item["d"]))} | |
| else: | |
| item = {"symbol": item["s"]} | |
| if filter_by is not None: | |
| if fields is not None: | |
| context = merge_dicts(item, template_context) | |
| else: | |
| raise ValueError("Must provide fields for filter_by") | |
| if isinstance(filter_by, CustomTemplate): | |
| if not filter_by.substitute(context, eval_id="filter_by"): | |
| continue | |
| elif callable(filter_by): | |
| if not filter_by(context): | |
| continue | |
| else: | |
| if len(fields) != len(filter_by): | |
| raise ValueError("Fields and filter_by must have the same number of values") | |
| conditions_met = True | |
| for i in range(len(fields)): | |
| if context[fields[i]] != filter_by[i]: | |
| conditions_met = False | |
| break | |
| if not conditions_met: | |
| continue | |
| if return_field_data: | |
| all_symbols.append(item) | |
| else: | |
| all_symbols.append(item["symbol"]) | |
| found_symbols = [] | |
| for symbol in all_symbols: | |
| if return_field_data: | |
| item = symbol | |
| symbol = item["symbol"] | |
| else: | |
| item = symbol | |
| if '"symbol"' in symbol: | |
| continue | |
| if exchange_pattern is not None: | |
| if not cls.key_match(symbol.split(":")[0], exchange_pattern, use_regex=use_regex): | |
| continue | |
| if symbol_pattern is not None: | |
| if not cls.key_match(symbol.split(":")[1], symbol_pattern, use_regex=use_regex): | |
| continue | |
| found_symbols.append(item) | |
| if sort: | |
| if return_field_data: | |
| return sorted(found_symbols, key=lambda x: x["symbol"]) | |
| return sorted(dict.fromkeys(found_symbols)) | |
| if return_field_data: | |
| return found_symbols | |
| return list(dict.fromkeys(found_symbols)) | |
| @classmethod | |
| def resolve_client(cls, client: tp.Optional[TVClient] = None, **client_config) -> TVClient: | |
| """Resolve the client. | |
| If provided, must be of the type `TVClient`. Otherwise, will be created using `client_config`.""" | |
| client = cls.resolve_custom_setting(client, "client") | |
| if client_config is None: | |
| client_config = {} | |
| has_client_config = len(client_config) > 0 | |
| client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True) | |
| if client is None: | |
| client = TVClient(**client_config) | |
| elif has_client_config: | |
| raise ValueError("Cannot apply client_config to already initialized client") | |
| return client | |
| @classmethod | |
| def fetch_symbol( | |
| cls, | |
| symbol: str, | |
| client: tp.Optional[TVClient] = None, | |
| client_config: tp.KwargsLike = None, | |
| exchange: tp.Optional[str] = None, | |
| timeframe: tp.Optional[str] = None, | |
| tz: tp.TimezoneLike = None, | |
| fut_contract: tp.Optional[int] = None, | |
| adjustment: tp.Optional[str] = None, | |
| extended_session: tp.Optional[bool] = None, | |
| pro_data: tp.Optional[bool] = None, | |
| limit: tp.Optional[int] = None, | |
| delay: tp.Optional[int] = None, | |
| retries: tp.Optional[int] = None, | |
| ) -> tp.SymbolData: | |
| """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from TradingView. | |
| Args: | |
| symbol (str): Symbol. | |
| Symbol must be in the `EXCHANGE:SYMBOL` format if `exchange` is None. | |
| client (TVClient): Client. | |
| See `TVData.resolve_client`. | |
| client_config (dict): Client config. | |
| See `TVData.resolve_client`. | |
| exchange (str): Exchange. | |
| Can be omitted if already provided via `symbol`. | |
| timeframe (str): Timeframe. | |
| Allows human-readable strings such as "15 minutes". | |
| tz (any): Timezone. | |
| See `vectorbtpro.utils.datetime_.to_timezone`. | |
| fut_contract (int): None for cash, 1 for continuous current contract in front, | |
| 2 for continuous next contract in front. | |
| adjustment (str): Adjustment. | |
| Either "splits" (default) or "dividends". | |
| extended_session (bool): Regular session if False, extended session if True. | |
| pro_data (bool): Whether to use pro data. | |
| limit (int): The maximum number of returned items. | |
| delay (float): Time to sleep after each request (in seconds). | |
| retries (int): The number of retries on failure to fetch data. | |
| For defaults, see `custom.tv` in `vectorbtpro._settings.data`. | |
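| Usage: | |
| * A hypothetical example; any exchange and symbol pair supported by TradingView | |
| can be used in place of the ones below: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> data = vbt.TVData.pull( | |
| ... "AAPL", | |
| ... exchange="NASDAQ", | |
| ... timeframe="15 minutes" | |
| ... ) | |
| ``` | |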
| """ | |
| if client_config is None: | |
| client_config = {} | |
| client = cls.resolve_client(client=client, **client_config) | |
| exchange = cls.resolve_custom_setting(exchange, "exchange") | |
| timeframe = cls.resolve_custom_setting(timeframe, "timeframe") | |
| tz = cls.resolve_custom_setting(tz, "tz") | |
| fut_contract = cls.resolve_custom_setting(fut_contract, "fut_contract") | |
| adjustment = cls.resolve_custom_setting(adjustment, "adjustment") | |
| extended_session = cls.resolve_custom_setting(extended_session, "extended_session") | |
| pro_data = cls.resolve_custom_setting(pro_data, "pro_data") | |
| limit = cls.resolve_custom_setting(limit, "limit") | |
| delay = cls.resolve_custom_setting(delay, "delay") | |
| retries = cls.resolve_custom_setting(retries, "retries") | |
| freq = timeframe | |
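| # Translate the human-readable timeframe into TradingView's interval notation | |
| # (minutes become plain numbers, other units keep a suffix such as "1D" or "1W") | |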
| if not isinstance(timeframe, str): | |
| raise ValueError(f"Invalid timeframe: '{timeframe}'") | |
| split = dt.split_freq_str(timeframe) | |
| if split is None: | |
| raise ValueError(f"Invalid timeframe: '{timeframe}'") | |
| multiplier, unit = split | |
| if unit == "s": | |
| interval = f"{str(multiplier)}S" | |
| elif unit == "m": | |
| interval = str(multiplier) | |
| elif unit == "h": | |
| interval = f"{str(multiplier)}H" | |
| elif unit == "D": | |
| interval = f"{str(multiplier)}D" | |
| elif unit == "W": | |
| interval = f"{str(multiplier)}W" | |
| elif unit == "M": | |
| interval = f"{str(multiplier)}M" | |
| else: | |
| raise ValueError(f"Invalid timeframe: '{timeframe}'") | |
| for i in range(retries): | |
| try: | |
| df = client.get_hist( | |
| symbol=symbol, | |
| exchange=exchange, | |
| interval=interval, | |
| fut_contract=fut_contract, | |
| adjustment=adjustment, | |
| extended_session=extended_session, | |
| pro_data=pro_data, | |
| limit=limit, | |
| ) | |
| break | |
| except Exception as e: | |
| if i == retries - 1: | |
| raise e | |
| if delay is not None: | |
| time.sleep(delay) | |
| df.rename( | |
| columns={ | |
| "symbol": "Symbol", | |
| "open": "Open", | |
| "high": "High", | |
| "low": "Low", | |
| "close": "Close", | |
| "volume": "Volume", | |
| }, | |
| inplace=True, | |
| ) | |
| if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None: | |
| df = df.tz_localize("utc") | |
| if "Symbol" in df: | |
| del df["Symbol"] | |
| if "Open" in df.columns: | |
| df["Open"] = df["Open"].astype(float) | |
| if "High" in df.columns: | |
| df["High"] = df["High"].astype(float) | |
| if "Low" in df.columns: | |
| df["Low"] = df["Low"].astype(float) | |
| if "Close" in df.columns: | |
| df["Close"] = df["Close"].astype(float) | |
| if "Volume" in df.columns: | |
| df["Volume"] = df["Volume"].astype(float) | |
| return df, dict(tz=tz, freq=freq) | |
| def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData: | |
| fetch_kwargs = self.select_fetch_kwargs(symbol) | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| return self.fetch_symbol(symbol, **kwargs) | |
| </file> | |
| <file path="data/custom/yf.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Module with `YFData`.""" | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.data.custom.remote import RemoteData | |
| from vectorbtpro.generic import nb as generic_nb | |
| from vectorbtpro.utils import datetime_ as dt | |
| from vectorbtpro.utils.config import merge_dicts, Config, HybridConfig | |
| from vectorbtpro.utils.parsing import get_func_kwargs | |
| __all__ = [ | |
| "YFData", | |
| ] | |
| __pdoc__ = {} | |
| class YFData(RemoteData): | |
| """Data class for fetching from Yahoo Finance. | |
| See https://github.com/ranaroussi/yfinance for API. | |
| See `YFData.fetch_symbol` for arguments. | |
| Usage: | |
| ```pycon | |
| >>> from vectorbtpro import * | |
| >>> data = vbt.YFData.pull( | |
| ... "BTC-USD", | |
| ... start="2020-01-01", | |
| ... end="2021-01-01", | |
| ... timeframe="1 day" | |
| ... ) | |
| ``` | |
| """ | |
| _settings_path: tp.SettingsPath = dict(custom="data.custom.yf") | |
| _feature_config: tp.ClassVar[Config] = HybridConfig( | |
| { | |
| "Dividends": dict( | |
| resample_func=lambda self, obj, resampler: obj.vbt.resample_apply( | |
| resampler, | |
| generic_nb.sum_reduce_nb, | |
| ) | |
| ), | |
| "Stock Splits": dict( | |
| resample_func=lambda self, obj, resampler: obj.vbt.resample_apply( | |
| resampler, | |
| generic_nb.nonzero_prod_reduce_nb, | |
| ) | |
| ), | |
| "Capital Gains": dict( | |
| resample_func=lambda self, obj, resampler: obj.vbt.resample_apply( | |
| resampler, | |
| generic_nb.sum_reduce_nb, | |
| ) | |
| ), | |
| } | |
| ) | |
| @property | |
| def feature_config(self) -> Config: | |
| return self._feature_config | |
| @classmethod | |
| def fetch_symbol( | |
| cls, | |
| symbol: str, | |
| period: tp.Optional[str] = None, | |
| start: tp.Optional[tp.DatetimeLike] = None, | |
| end: tp.Optional[tp.DatetimeLike] = None, | |
| timeframe: tp.Optional[str] = None, | |
| tz: tp.TimezoneLike = None, | |
| **history_kwargs, | |
| ) -> tp.SymbolData: | |
| """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Yahoo Finance. | |
| Args: | |
| symbol (str): Symbol. | |
| period (str): Period. | |
| start (any): Start datetime. | |
| See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. | |
| end (any): End datetime. | |
| See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. | |
| timeframe (str): Timeframe. | |
| Allows human-readable strings such as "15 minutes". | |
| tz (any): Timezone. | |
| See `vectorbtpro.utils.datetime_.to_timezone`. | |
| **history_kwargs: Keyword arguments passed to `yfinance.base.TickerBase.history`. | |
| For defaults, see `custom.yf` in `vectorbtpro._settings.data`. | |
| !!! warning | |
| Data coming from Yahoo is not the most stable data out there. Yahoo may manipulate data | |
| however they want, add noise, return missing data points (volume, for example), etc. | |
| It's only used in vectorbt for demonstration purposes. | |
| """ | |
| from vectorbtpro.utils.module_ import assert_can_import | |
| assert_can_import("yfinance") | |
| import yfinance as yf | |
| period = cls.resolve_custom_setting(period, "period") | |
| start = cls.resolve_custom_setting(start, "start") | |
| end = cls.resolve_custom_setting(end, "end") | |
| timeframe = cls.resolve_custom_setting(timeframe, "timeframe") | |
| tz = cls.resolve_custom_setting(tz, "tz") | |
| history_kwargs = cls.resolve_custom_setting(history_kwargs, "history_kwargs", merge=True) | |
| ticker = yf.Ticker(symbol) | |
| def_history_kwargs = get_func_kwargs(yf.Tickers.history) | |
| ticker_tz = ticker._get_ticker_tz( | |
| history_kwargs.get("proxy", def_history_kwargs["proxy"]), | |
| history_kwargs.get("timeout", def_history_kwargs["timeout"]), | |
| ) | |
| if tz is None: | |
| tz = ticker_tz | |
| if start is not None: | |
| start = dt.to_tzaware_datetime(start, naive_tz=tz, tz=ticker_tz) | |
| if end is not None: | |
| end = dt.to_tzaware_datetime(end, naive_tz=tz, tz=ticker_tz) | |
| freq = timeframe | |
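| # Translate vectorbt-style frequency units into yfinance interval suffixes (D -> d, W -> wk, M -> mo) | |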
| split = dt.split_freq_str(timeframe) | |
| if split is not None: | |
| multiplier, unit = split | |
| if unit == "D": | |
| unit = "d" | |
| elif unit == "W": | |
| unit = "wk" | |
| elif unit == "M": | |
| unit = "mo" | |
| timeframe = str(multiplier) + unit | |
| df = ticker.history(period=period, start=start, end=end, interval=timeframe, **history_kwargs) | |
| if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None: | |
| df = df.tz_localize(ticker_tz) | |
| if not df.empty: | |
| if start is not None: | |
| if df.index[0] < start: | |
| df = df[df.index >= start] | |
| if end is not None: | |
| if df.index[-1] >= end: | |
| df = df[df.index < end] | |
| return df, dict(tz=tz, freq=freq) | |
| def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData: | |
| fetch_kwargs = self.select_fetch_kwargs(symbol) | |
| fetch_kwargs["start"] = self.select_last_index(symbol) | |
| kwargs = merge_dicts(fetch_kwargs, kwargs) | |
| return self.fetch_symbol(symbol, **kwargs) | |
| YFData.override_feature_config_doc(__pdoc__) | |
| </file> | |
| <file path="data/__init__.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Modules for working with data sources.""" | |
| from typing import TYPE_CHECKING | |
| if TYPE_CHECKING: | |
| from vectorbtpro.data.base import * | |
| from vectorbtpro.data.custom import * | |
| from vectorbtpro.data.decorators import * | |
| from vectorbtpro.data.nb import * | |
| from vectorbtpro.data.saver import * | |
| from vectorbtpro.data.updater import * | |
| </file> | |
| <file path="data/base.py"> | |
| # ==================================== VBTPROXYZ ==================================== | |
| # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. | |
| # | |
| # This file is part of the proprietary VectorBT® PRO package and is licensed under | |
| # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ | |
| # | |
| # Unauthorized publishing, distribution, sublicensing, or sale of this software | |
| # or its parts is strictly prohibited. | |
| # =================================================================================== | |
| """Base class for working with data sources.""" | |
| import inspect | |
| import string | |
| import traceback | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| from vectorbtpro import _typing as tp | |
| from vectorbtpro.base.indexes import stack_indexes | |
| from vectorbtpro.base.merging import column_stack_arrays, is_merge_func_from_config | |
| from vectorbtpro.base.reshaping import to_any_array, to_pd_array, to_1d_array, to_2d_array, broadcast_to | |
| from vectorbtpro.base.wrapping import ArrayWrapper | |
| from vectorbtpro.data.decorators import attach_symbol_dict_methods | |
| from vectorbtpro.generic import nb as generic_nb | |
| from vectorbtpro.generic.analyzable import Analyzable | |
| from vectorbtpro.generic.drawdowns import Drawdowns | |
| from vectorbtpro.ohlcv.nb import mirror_ohlc_nb | |
| from vectorbtpro.ohlcv.enums import PriceFeature | |
| from vectorbtpro.returns.accessors import ReturnsAccessor | |
| from vectorbtpro.utils import checks, datetime_ as dt | |
| from vectorbtpro.utils.attr_ import get_dict_attr | |
| from vectorbtpro.utils.base import Base | |
| from vectorbtpro.utils.config import merge_dicts, Config, HybridConfig, copy_dict | |
| from vectorbtpro.utils.decorators import cached_property, hybrid_method | |
| from vectorbtpro.utils.enum_ import map_enum_fields | |
| from vectorbtpro.utils.execution import Task, NoResult, NoResultsException, filter_out_no_results, execute | |
| from vectorbtpro.utils.merging import MergeFunc | |
| from vectorbtpro.utils.parsing import get_func_arg_names, extend_args | |
| from vectorbtpro.utils.path_ import check_mkdir | |
| from vectorbtpro.utils.pickling import pdict, RecState | |
| from vectorbtpro.utils.template import Rep, RepEval, CustomTemplate, substitute_templates | |
| from vectorbtpro.utils.warnings_ import warn | |
| from vectorbtpro.registries.ch_registry import ch_reg | |
| from vectorbtpro.registries.jit_registry import jit_reg | |
| try: | |
| if not tp.TYPE_CHECKING: | |
| raise ImportError | |
| from sqlalchemy import Engine as EngineT | |
| except ImportError: | |
| EngineT = "Engine" | |
| try: | |
| if not tp.TYPE_CHECKING: | |
| raise ImportError | |
| from duckdb import DuckDBPyConnection as DuckDBPyConnectionT | |
| except ImportError: | |
| DuckDBPyConnectionT = "DuckDBPyConnection" | |
| __all__ = [ | |
| "key_dict", | |
| "feature_dict", | |
| "symbol_dict", | |
| "run_func_dict", | |
| "run_arg_dict", | |
| "Data", | |
| ] | |
| __pdoc__ = {} | |
| class key_dict(pdict): | |
| """Dict that contains features or symbols as keys.""" | |
| pass | |
| class feature_dict(key_dict): | |
| """Dict that contains features as keys.""" | |
| pass | |
| class symbol_dict(key_dict): | |
| """Dict that contains symbols as keys.""" | |
| pass | |
| class run_func_dict(pdict): | |
| """Dict that contains function names as keys for `Data.run`.""" | |
| pass | |
| class run_arg_dict(pdict): | |
| """Dict that contains argument names as keys for `Data.run`.""" | |
| pass | |
| BaseDataMixinT = tp.TypeVar("BaseDataMixinT", bound="BaseDataMixin") | |
| class BaseDataMixin(Base): | |
| """Base mixin class for working with data.""" | |
| @property | |
| def feature_wrapper(self) -> ArrayWrapper: | |
| """Column wrapper.""" | |
| raise NotImplementedError | |
| @property | |
| def symbol_wrapper(self) -> ArrayWrapper: | |
| """Symbol wrapper.""" | |
| raise NotImplementedError | |
| @property | |
| def features(self) -> tp.List[tp.Feature]: | |
| """List of features.""" | |
| return self.feature_wrapper.columns.tolist() | |
| @property | |
| def symbols(self) -> tp.List[tp.Symbol]: | |
| """List of symbols.""" | |
| return self.symbol_wrapper.columns.tolist() | |
| @classmethod | |
| def has_multiple_keys(cls, keys: tp.MaybeKeys) -> bool: | |
| """Check whether there are one or multiple keys.""" | |
| if checks.is_hashable(keys): | |
| return False | |
| elif checks.is_sequence(keys): | |
| return True | |
| raise TypeError("Keys must be either a hashable or a sequence of hashable") | |
| @classmethod | |
| def prepare_key(cls, key: tp.Key) -> tp.Key: | |
| """Prepare a key.""" | |
| if isinstance(key, tuple): | |
| return tuple([cls.prepare_key(k) for k in key]) | |
| if isinstance(key, str): | |
| return key.lower().strip().replace(" ", "_") | |
| return key | |
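| # A small illustration of the normalization performed by `prepare_key` (values are | |
| # hypothetical): string keys are lower-cased, stripped, and have spaces replaced with | |
| # underscores, while tuple keys are normalized element-wise. | |
| # | |
| #   prepare_key(" Trade Count ")       -> "trade_count" | |
| #   prepare_key(("BTC-USD", "Close"))  -> ("btc-usd", "close") | |
| #   prepare_key(42)                    -> 42  (non-string keys pass through) | |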
| def get_feature_idx(self, feature: tp.Feature, raise_error: bool = False) -> int: | |
| """Return the index of a feature.""" | |
| # shortcut | |
| columns = self.feature_wrapper.columns | |
| if not columns.has_duplicates: | |
| if feature in columns: | |
| return columns.get_loc(feature) | |
| feature = self.prepare_key(feature) | |
| found_indices = [] | |
| for i, c in enumerate(self.features): | |
| c = self.prepare_key(c) | |
| if feature == c: | |
| found_indices.append(i) | |
| if len(found_indices) == 0: | |
| if raise_error: | |
| raise ValueError(f"No features match the feature '{str(feature)}'") | |
| return -1 | |
| if len(found_indices) == 1: | |
| return found_indices[0] | |
| raise ValueError(f"Multiple features match the feature '{str(feature)}'") | |
| def get_symbol_idx(self, symbol: tp.Symbol, raise_error: bool = False) -> int: | |
| """Return the index of a symbol.""" | |
| # shortcut | |
| columns = self.symbol_wrapper.columns | |
| if not columns.has_duplicates: | |
| if symbol in columns: | |
| return columns.get_loc(symbol) | |
| symbol = self.prepare_key(symbol) | |
| found_indices = [] | |
| for i, c in enumerate(self.symbols): | |
| c = self.prepare_key(c) | |
| if symbol == c: | |
| found_indices.append(i) | |
| if len(found_indices) == 0: | |
| if raise_error: | |
| raise ValueError(f"No symbols match the symbol '{str(symbol)}'") | |
| return -1 | |
| if len(found_indices) == 1: | |
| return found_indices[0] | |
| raise ValueError(f"Multiple symbols match the symbol '{str(symbol)}'") | |
| def select_feature_idxs(self: BaseDataMixinT, idxs: tp.MaybeSequence[int], **kwargs) -> BaseDataMixinT: | |
| """Select one or more features by index. | |
| Returns a new instance.""" | |
| raise NotImplementedError | |
| def select_symbol_idxs(self: BaseDataMixinT, idxs: tp.MaybeSequence[int], **kwargs) -> BaseDataMixinT: | |
| """Select one or more symbols by index. | |
| Returns a new instance.""" | |
| raise NotImplementedError | |
| def select_features(self: BaseDataMixinT, features: tp.MaybeFeatures, **kwargs) -> BaseDataMixinT: | |
| """Select one or more features. | |
| Returns a new instance.""" | |
| if self.has_multiple_keys(features): | |
| feature_idxs = [self.get_feature_idx(k, raise_error=True) for k in features] | |
| else: | |
| feature_idxs = self.get_feature_idx(features, raise_error=True) | |
| return self.select_feature_idxs(feature_idxs, **kwargs) | |
| def select_symbols(self: BaseDataMixinT, symbols: tp.MaybeSymbols, **kwargs) -> BaseDataMixinT: | |
| """Select one or more symbols. | |
| Returns a new instance.""" | |
| if self.has_multiple_keys(symbols): | |
| symbol_idxs = [self.get_symbol_idx(k, raise_error=True) for k in symbols] | |
| else: | |
| symbol_idxs = self.get_symbol_idx(symbols, raise_error=True) | |
| return self.select_symbol_idxs(symbol_idxs, **kwargs) | |
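| # A hedged usage sketch for the selection helpers above (the instance `data` and the | |
| # listed symbols are placeholders): both methods accept a single key or a sequence of | |
| # keys and resolve them case-insensitively via `get_feature_idx`/`get_symbol_idx`. | |
| # | |
| #   btc = data.select_symbols("BTC-USD")               # single symbol | |
| #   sub = data.select_symbols(["BTC-USD", "ETH-USD"])  # subset of symbols | |
| #   ohlc = data.select_features(["open", "high", "low", "close"]) | |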
| def get( | |
| self, | |
| features: tp.Optional[tp.MaybeFeatures] = None, | |
| symbols: tp.Optional[tp.MaybeSymbols] = None, | |
| feature: tp.Optional[tp.Feature] = None, | |
| symbol: tp.Optional[tp.Symbol] = None, | |
| **kwargs, | |
| ) -> tp.MaybeTuple[tp.SeriesFrame]: | |
| """Get one or more features of one or more symbols of data.""" | |
| raise NotImplementedError | |
| def has_feature(self, feature: tp.Feature) -> bool: | |
| """Whether feature exists.""" | |
| feature_idx = self.get_feature_idx(feature, raise_error=False) | |
| return feature_idx != -1 | |
| def has_symbol(self, symbol: tp.Symbol) -> bool: | |
| """Whether symbol exists.""" | |
| symbol_idx = self.get_symbol_idx(symbol, raise_error=False) | |
| return symbol_idx != -1 | |
| def assert_has_feature(self, feature: tp.Feature) -> None: | |
| """Assert that feature exists.""" | |
| self.get_feature_idx(feature, raise_error=True) | |
| def assert_has_symbol(self, symbol: tp.Symbol) -> None: | |
| """Assert that symbol exists.""" | |
| self.get_symbol_idx(symbol, raise_error=True) | |
| def get_feature( | |
| self, | |
| feature: tp.Union[int, tp.Feature], | |
| raise_error: bool = False, | |
| ) -> tp.Optional[tp.SeriesFrame]: | |
| """Get feature that match a feature index or label.""" | |
| if checks.is_int(feature): | |
| return self.get(features=self.features[feature]) | |
| feature_idx = self.get_feature_idx(feature, raise_error=raise_error) | |
| if feature_idx == -1: | |
| return None | |
| return self.get(features=self.features[feature_idx]) | |
| def get_symbol( | |
| self, | |
| symbol: tp.Union[int, tp.Symbol], | |
| raise_error: bool = False, | |
| ) -> tp.Optional[tp.SeriesFrame]: | |
| """Get symbol that match a symbol index or label.""" | |
| if checks.is_int(symbol): | |
| return self.get(symbol=self.symbols[symbol]) | |
| symbol_idx = self.get_symbol_idx(symbol, raise_error=raise_error) | |
| if symbol_idx == -1: | |
| return None | |
| return self.get(symbol=self.symbols[symbol_idx]) | |
| OHLCDataMixinT = tp.TypeVar("OHLCDataMixinT", bound="OHLCDataMixin") | |
| class OHLCDataMixin(BaseDataMixin): | |
| """Mixin class for working with OHLC data.""" | |
| @property | |
| def open(self) -> tp.Optional[tp.SeriesFrame]: | |
| """Open.""" | |
| return self.get_feature("Open") | |
| @property | |
| def high(self) -> tp.Optional[tp.SeriesFrame]: | |
| """High.""" | |
| return self.get_feature("High") | |
| @property | |
| def low(self) -> tp.Optional[tp.SeriesFrame]: | |
| """Low.""" | |
| return self.get_feature("Low") | |
| @property | |
| def close(self) -> tp.Optional[tp.SeriesFrame]: | |
| """Close.""" | |
| return self.get_feature("Close") | |
| @property | |
| def volume(self) -> tp.Optional[tp.SeriesFrame]: | |
| """Volume.""" | |
| return self.get_feature("Volume") | |
| @property | |
| def trade_count(self) -> tp.Optional[tp.SeriesFrame]: | |
| """Trade count.""" | |
| return self.get_feature("Trade count") | |
| @property | |
| def vwap(self) -> tp.Optional[tp.SeriesFrame]: | |
| """VWAP.""" | |
| return self.get_feature("VWAP") | |
| @property | |
| def hlc3(self) -> tp.SeriesFrame: | |
| """HLC/3.""" | |
| high = self.get_feature("High", raise_error=True) | |
| low = self.get_feature("Low", raise_error=True) | |
| close = self.get_feature("Close", raise_error=True) | |
| return (high + low + close) / 3 | |
| @property | |
| def ohlc4(self) -> tp.SeriesFrame: | |
| """OHLC/4.""" | |
| open = self.get_feature("Open", raise_error=True) | |
| high = self.get_feature("High", raise_error=True) | |
| low = self.get_feature("Low", raise_error=True) | |
| close = self.get_feature("Close", raise_error=True) | |
| return (open + high + low + close) / 4 | |
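| # The two price aggregates above are plain element-wise averages; for a bar with | |
| # (hypothetical) O=10, H=12, L=9, C=11: | |
| # | |
| #   hlc3  = (12 + 9 + 11) / 3      = 10.666... | |
| #   ohlc4 = (10 + 12 + 9 + 11) / 4 = 10.5 | |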
| @property | |
| def has_any_ohlc(self) -> bool: | |
| """Whether the instance has any of the OHLC features.""" | |
| return ( | |
| self.has_feature("Open") or self.has_feature("High") or self.has_feature("Low") or self.has_feature("Close") | |
| ) | |
| @property | |
| def has_ohlc(self) -> bool: | |
| """Whether the instance has all the OHLC features.""" | |
| return ( | |
| self.has_feature("Open") | |
| and self.has_feature("High") | |
| and self.has_feature("Low") | |
| and self.has_feature("Close") | |
| ) | |
| @property | |
| def has_any_ohlcv(self) -> bool: | |
| """Whether the instance has any of the OHLCV features.""" | |
| return self.has_any_ohlc or self.has_feature("Volume") | |
| @property | |
| def has_ohlcv(self) -> bool: | |
| """Whether the instance has all the OHLCV features.""" | |
| return self.has_ohlc and self.has_feature("Volume") | |
| @property | |
| def ohlc(self: OHLCDataMixinT) -> OHLCDataMixinT: | |
| """Return a `OHLCDataMixin` instance with the OHLC features only.""" | |
| open_idx = self.get_feature_idx("Open", raise_error=True) | |
| high_idx = self.get_feature_idx("High", raise_error=True) | |
| low_idx = self.get_feature_idx("Low", raise_error=True) | |
| close_idx = self.get_feature_idx("Close", raise_error=True) | |
| return self.select_feature_idxs([open_idx, high_idx, low_idx, close_idx]) | |
| @property | |
| def ohlcv(self: OHLCDataMixinT) -> OHLCDataMixinT: | |
| """Return a `OHLCDataMixin` instance with the OHLCV features only.""" | |
| open_idx = self.get_feature_idx("Open", raise_error=True) | |
| high_idx = self.get_feature_idx("High", raise_error=True) | |
| low_idx = self.get_feature_idx("Low", raise_error=True) | |
| close_idx = self.get_feature_idx("Close", raise_error=True) | |
| volume_idx = self.get_feature_idx("Volume", raise_error=True) | |
| return self.select_feature_idxs([open_idx, high_idx, low_idx, close_idx, volume_idx]) | |
| def get_returns_acc(self, **kwargs) -> ReturnsAccessor: | |
| """Return accessor of type `vectorbtpro.returns.accessors.ReturnsAccessor`.""" | |
| return ReturnsAccessor.from_value( | |
| self.get_feature("Close", raise_error=True), | |
| wrapper=self.symbol_wrapper, | |
| return_values=False, | |
| **kwargs, | |
| ) | |
| @property | |
| def returns_acc(self) -> ReturnsAccessor: | |
| """`OHLCDataMixin.get_returns_acc` with default arguments.""" | |
| return self.get_returns_acc() | |
| def get_returns(self, **kwargs) -> tp.SeriesFrame: | |
| """Returns.""" | |
| return ReturnsAccessor.from_value( | |
| self.get_feature("Close", raise_error=True), | |
| wrapper=self.symbol_wrapper, | |
| return_values=True, | |
| **kwargs, | |
| ) | |
| @property | |
| def returns(self) -> tp.SeriesFrame: | |
| """`OHLCDataMixin.get_returns` with default arguments.""" | |
| return self.get_returns() | |
| def get_log_returns(self, **kwargs) -> tp.SeriesFrame: | |
| """Log returns.""" | |
| return ReturnsAccessor.from_value( | |
| self.get_feature("Close", raise_error=True), | |
| wrapper=self.symbol_wrapper, | |
| return_values=True, | |
| log_returns=True, | |
| **kwargs, | |
| ) | |
| @property | |
| def log_returns(self) -> tp.SeriesFrame: | |
| """`OHLCDataMixin.get_log_returns` with default arguments.""" | |
| return self.get_log_returns() | |
| def get_daily_returns(self, **kwargs) -> tp.SeriesFrame: | |
| """Daily returns.""" | |
| return ReturnsAccessor.from_value( | |
| self.get_feature("Close", raise_error=True), | |
| wrapper=self.symbol_wrapper, | |
| return_values=False, | |
| **kwargs, | |
| ).daily() | |
| @property | |
| def daily_returns(self) -> tp.SeriesFrame: | |
| """`OHLCDataMixin.get_daily_returns` with default arguments.""" | |
| return self.get_daily_returns() | |
| def get_daily_log_returns(self, **kwargs) -> tp.SeriesFrame: | |
| """Daily log returns.""" | |
| return ReturnsAccessor.from_value( | |
| self.get_feature("Close", raise_error=True), | |
| wrapper=self.symbol_wrapper, | |
| return_values=False, | |
| log_returns=True, | |
| **kwargs, | |
| ).daily() | |
| @property | |
| def daily_log_returns(self) -> tp.SeriesFrame: | |
| """`OHLCDataMixin.get_daily_log_returns` with default arguments.""" | |
| return self.get_daily_log_returns() | |
| def get_drawdowns(self, **kwargs) -> Drawdowns: | |
| """Generate drawdown records. | |
| See `vectorbtpro.generic.drawdowns.Drawdowns`.""" | |
| return Drawdowns.from_price( | |
| open=self.get_feature("Open", raise_error=True), | |
| high=self.get_feature("High", raise_error=True), | |
| low=self.get_feature("Low", raise_error=True), | |
| close=self.get_feature("Close", raise_error=True), | |
| **kwargs, | |
| ) | |
| @property | |
| def drawdowns(self) -> Drawdowns: | |
| """`OHLCDataMixin.get_drawdowns` with default arguments.""" | |
| return self.get_drawdowns() | |
| DataT = tp.TypeVar("DataT", bound="Data") | |
| class MetaData(type(Analyzable)): | |
| """Metaclass for `Data`.""" | |
| @property | |
| def feature_config(cls) -> Config: | |
| """Feature config.""" | |
| return cls._feature_config | |
| @attach_symbol_dict_methods | |
| class Data(Analyzable, OHLCDataMixin, metaclass=MetaData): | |
| """Class that downloads, updates, and manages data coming from a data source.""" | |
| _settings_path: tp.SettingsPath = dict(base="data") | |
| _writeable_attrs: tp.WriteableAttrs = {"_feature_config"} | |
| _feature_config: tp.ClassVar[Config] = HybridConfig() | |
| _key_dict_attrs = [ | |
| "fetch_kwargs", | |
| "returned_kwargs", | |
| "last_index", | |
| "delisted", | |
| "classes", | |
| ] | |
| """Attributes that subclass either `feature_dict` or `symbol_dict`.""" | |
| _data_dict_type_attrs = [ | |
| "classes", | |
| ] | |
| """Attributes that subclass the data dict type.""" | |
| _updatable_attrs = [ | |
| "fetch_kwargs", | |
| "returned_kwargs", | |
| "classes", | |
| ] | |
| """Attributes that have a method for updating.""" | |
| @property | |
| def feature_config(self) -> Config: | |
| """Column config of `${cls_name}`. | |
| ```python | |
| ${feature_config} | |
| ``` | |
| Returns `${cls_name}._feature_config`, which gets (hybrid-) copied upon creation of each instance. | |
| Thus, changing this config won't affect the class. | |
| To change fields, you can either change the config in-place, override this property, | |
| or overwrite the instance variable `${cls_name}._feature_config`. | |
| """ | |
| return self._feature_config | |
| def use_feature_config_of(self, cls: tp.Type[DataT]) -> None: | |
| """Copy feature config from another `Data` class.""" | |
| self._feature_config = cls.feature_config.copy() | |
| @classmethod | |
| def modify_state(cls, rec_state: RecState) -> RecState: | |
| # Ensure backward compatibility | |
| if "_column_config" in rec_state.attr_dct and "_feature_config" not in rec_state.attr_dct: | |
| new_attr_dct = dict(rec_state.attr_dct) | |
| new_attr_dct["_feature_config"] = new_attr_dct.pop("_column_config") | |
| rec_state = RecState( | |
| init_args=rec_state.init_args, | |
| init_kwargs=rec_state.init_kwargs, | |
| attr_dct=new_attr_dct, | |
| ) | |
| if "single_symbol" in rec_state.init_kwargs and "single_key" not in rec_state.init_kwargs: | |
| new_init_kwargs = dict(rec_state.init_kwargs) | |
| new_init_kwargs["single_key"] = new_init_kwargs.pop("single_symbol") | |
| rec_state = RecState( | |
| init_args=rec_state.init_args, | |
| init_kwargs=new_init_kwargs, | |
| attr_dct=rec_state.attr_dct, | |
| ) | |
| if "symbol_classes" in rec_state.init_kwargs and "classes" not in rec_state.init_kwargs: | |
| new_init_kwargs = dict(rec_state.init_kwargs) | |
| new_init_kwargs["classes"] = new_init_kwargs.pop("symbol_classes") | |
| rec_state = RecState( | |
| init_args=rec_state.init_args, | |
| init_kwargs=new_init_kwargs, | |
| attr_dct=rec_state.attr_dct, | |
| ) | |
| return rec_state | |
| @classmethod | |
| def fix_data_dict_type(cls, data: dict) -> tp.Union[feature_dict, symbol_dict]: | |
| """Fix dict type for data.""" | |
| checks.assert_instance_of(data, dict, arg_name="data") | |
| if not isinstance(data, key_dict): | |
| data = symbol_dict(data) | |
| return data | |
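| # A minimal sketch of the defaulting behavior above (the frame `df` is a placeholder): | |
| # a plain dict is interpreted as symbol-keyed, while passing `feature_dict` or | |
| # `symbol_dict` explicitly preserves the orientation. | |
| # | |
| #   Data.fix_data_dict_type({"BTC-USD": df})          # -> symbol_dict | |
| #   Data.fix_data_dict_type(feature_dict(Close=df))   # -> feature_dict (unchanged) | |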
| @classmethod | |
| def fix_dict_types_in_kwargs( | |
| cls, | |
| data_type: tp.Type[tp.Union[feature_dict, symbol_dict]], | |
| **kwargs: tp.Kwargs, | |
| ) -> tp.Kwargs: | |
| """Fix dict types in keyword arguments.""" | |
| for attr in cls._key_dict_attrs: | |
| if attr in kwargs: | |
| attr_value = kwargs[attr] | |
| if attr_value is None: | |
| attr_value = {} | |
| checks.assert_instance_of(attr_value, dict, arg_name=attr) | |
| if not isinstance(attr_value, key_dict): | |
| attr_value = data_type(attr_value) | |
| if attr in cls._data_dict_type_attrs: | |
| checks.assert_instance_of(attr_value, data_type, arg_name=attr) | |
| kwargs[attr] = attr_value | |
| return kwargs | |
| @hybrid_method | |
| def row_stack( | |
| cls_or_self: tp.MaybeType[DataT], | |
| *objs: tp.MaybeTuple[DataT], | |
| wrapper_kwargs: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> DataT: | |
| """Stack multiple `Data` instances along rows. | |
| Uses `vectorbtpro.base.wrapping.ArrayWrapper.row_stack` to stack the wrappers.""" | |
| if not isinstance(cls_or_self, type): | |
| objs = (cls_or_self, *objs) | |
| cls = type(cls_or_self) | |
| else: | |
| cls = cls_or_self | |
| if len(objs) == 1: | |
| objs = objs[0] | |
| objs = list(objs) | |
| for obj in objs: | |
| if not checks.is_instance_of(obj, Data): | |
| raise TypeError("Each object to be merged must be an instance of Data") | |
| if "wrapper" not in kwargs: | |
| if wrapper_kwargs is None: | |
| wrapper_kwargs = {} | |
| kwargs["wrapper"] = ArrayWrapper.row_stack(*[obj.wrapper for obj in objs], **wrapper_kwargs) | |
| keys = set() | |
| for obj in objs: | |
| keys = keys.union(set(obj.data.keys())) | |
| data_type = None | |
| for obj in objs: | |
| if len(keys.difference(set(obj.data.keys()))) > 0: | |
| if isinstance(obj.data, feature_dict): | |
| raise ValueError("Objects to be merged must have the same features") | |
| else: | |
| raise ValueError("Objects to be merged must have the same symbols") | |
| if data_type is None: | |
| data_type = type(obj.data) | |
| elif not isinstance(obj.data, data_type): | |
| raise TypeError("Objects to be merged must have the same dict type for data") | |
| if "data" not in kwargs: | |
| new_data = data_type() | |
| for k in objs[0].data.keys(): | |
| new_data[k] = kwargs["wrapper"].row_stack_arrs(*[obj.data[k] for obj in objs], group_by=False) | |
| kwargs["data"] = new_data | |
| kwargs["data"] = cls.fix_data_dict_type(kwargs["data"]) | |
| for attr in cls._key_dict_attrs: | |
| if attr not in kwargs: | |
| attr_data_type = None | |
| for obj in objs: | |
| v = getattr(obj, attr) | |
| if attr_data_type is None: | |
| attr_data_type = type(v) | |
| elif not isinstance(v, attr_data_type): | |
| raise TypeError(f"Objects to be merged must have the same dict type for '{attr}'") | |
| kwargs[attr] = getattr(objs[-1], attr) | |
| kwargs = cls.resolve_row_stack_kwargs(*objs, **kwargs) | |
| kwargs = cls.resolve_stack_kwargs(*objs, **kwargs) | |
| kwargs = cls.fix_dict_types_in_kwargs(type(kwargs["data"]), **kwargs) | |
| return cls(**kwargs) | |
| @hybrid_method | |
| def column_stack( | |
| cls_or_self: tp.MaybeType[DataT], | |
| *objs: tp.MaybeTuple[DataT], | |
| wrapper_kwargs: tp.KwargsLike = None, | |
| **kwargs, | |
| ) -> DataT: | |
| """Stack multiple `Data` instances along columns. | |
| Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers.""" | |
| if not isinstance(cls_or_self, type): | |
| objs = (cls_or_self, *objs) | |
| cls = type(cls_or_self) | |
| else: | |
| cls = cls_or_self | |
| if len(objs) == 1: | |
| objs = objs[0] | |
| objs = list(objs) | |
| for obj in objs: | |
| if not checks.is_instance_of(obj, Data): | |
| raise TypeError("Each object to be merged must be an instance of Data") | |
| if "wrapper" not in kwargs: | |
| if wrapper_kwargs is None: | |
| wrapper_kwargs = {} | |
| kwargs["wrapper"] = ArrayWrapper.column_stack( | |
| *[obj.wrapper for obj in objs], | |
| **wrapper_kwargs, | |
| ) | |
| keys = set() | |
| for obj in objs: | |
| keys = keys.union(set(obj.data.keys())) | |
| data_type = None | |
| for obj in objs: | |
| if len(keys.difference(set(obj.data.keys()))) > 0: | |
| if isinstance(obj.data, feature_dict): | |
| raise ValueError("Objects to be merged must have the same features") | |
| else: | |
| raise ValueError("Objects to be merged must have the same symbols") | |
| if data_type is None: | |
| data_type = type(obj.data) | |
| elif not isinstance(obj.data, data_type): | |
| raise TypeError("Objects to be merged must have the same dict type for data") | |
| if "data" not in kwargs: | |
| new_data = data_type() | |
| for k in objs[0].data.keys(): | |
| new_data[k] = kwargs["wrapper"].column_stack_arrs(*[obj.data[k] for obj in objs], group_by=False) | |
| kwargs["data"] = new_data | |
| kwargs["data"] = cls.fix_data_dict_type(kwargs["data"]) | |
| for attr in cls._key_dict_attrs: | |
| if attr not in kwargs: | |
| attr_data_type = None | |
| for obj in objs: | |
| v = getattr(obj, attr) | |
| if attr_data_type is None: | |
| attr_data_type = type(v) | |
| elif not isinstance(v, attr_data_type): | |
| raise TypeError(f"Objects to be merged must have the same dict type for '{attr}'") | |
| if (issubclass(data_type, feature_dict) and issubclass(attr_data_type, symbol_dict)) or ( | |
| issubclass(data_type, symbol_dict) and issubclass(attr_data_type, feature_dict) | |
| ): | |
| kwargs[attr] = attr_data_type() | |
| for obj in objs: | |
| kwargs[attr].update(**getattr(obj, attr)) | |
| kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs) | |
| kwargs = cls.resolve_stack_kwargs(*objs, **kwargs) | |
| kwargs = cls.fix_dict_types_in_kwargs(type(kwargs["data"]), **kwargs) | |
| return cls(**kwargs) | |
| def __init__( | |
| self, | |
| wrapper: ArrayWrapper, | |
| data: tp.Union[feature_dict, symbol_dict], | |
| single_key: bool = True, | |
| classes: tp.Union[None, feature_dict, symbol_dict] = None, | |
| level_name: tp.Union[None, bool, tp.MaybeIterable[tp.Hashable]] = None, | |
| fetch_kwargs: tp.Union[None, feature_dict, symbol_dict] = None, | |
| returned_kwargs: tp.Union[None, feature_dict, symbol_dict] = None, | |
| last_index: tp.Union[None, feature_dict, symbol_dict] = None, | |
| delisted: tp.Union[None, feature_dict, symbol_dict] = None, | |
| tz_localize: tp.Union[None, bool, tp.TimezoneLike] = None, | |
| tz_convert: tp.Union[None, bool, tp.TimezoneLike] = None, | |
| missing_index: tp.Optional[str] = None, | |
| missing_columns: tp.Optional[str] = None, | |
| **kwargs, | |
| ) -> None: | |
| Analyzable.__init__( | |
| self, | |
| wrapper, | |
| data=data, | |
| single_key=single_key, | |
| classes=classes, | |
| level_name=level_name, | |
| fetch_kwargs=fetch_kwargs, | |
| returned_kwargs=returned_kwargs, | |
| last_index=last_index, | |
| delisted=delisted, | |
| tz_localize=tz_localize, | |
| tz_convert=tz_convert, | |
| missing_index=missing_index, | |
| missing_columns=missing_columns, | |
| **kwargs, | |
| ) | |
| if len(set(map(self.prepare_key, data.keys()))) < len(list(map(self.prepare_key, data.keys()))): | |
| raise ValueError("Found duplicate keys in data dictionary") | |
| data = self.fix_data_dict_type(data) | |
| for obj in data.values(): | |
| checks.assert_meta_equal(obj, data[list(data.keys())[0]]) | |
| if len(data) > 1: | |
| single_key = False | |
| self._data = data | |
| self._single_key = single_key | |
| self._level_name = level_name | |
| self._tz_localize = tz_localize | |
| self._tz_convert = tz_convert | |
| self._missing_index = missing_index | |
| self._missing_columns = missing_columns | |
| attr_kwargs = dict() | |
| for attr in self._key_dict_attrs: | |
| attr_value = locals()[attr] | |
| attr_kwargs[attr] = attr_value | |
| attr_kwargs = self.fix_dict_types_in_kwargs(type(data), **attr_kwargs) | |
| for k, v in attr_kwargs.items(): | |
| setattr(self, "_" + k, v) | |
| # Copy writeable attrs | |
| self._feature_config = type(self)._feature_config.copy() | |
| def replace(self: DataT, **kwargs) -> DataT: | |
| """See `vectorbtpro.utils.config.Configured.replace`. | |
| Replaces the data's index and/or columns if they were changed in the wrapper.""" | |
| if "wrapper" in kwargs and "data" not in kwargs: | |
| wrapper = kwargs["wrapper"] | |
| if isinstance(wrapper, dict): | |
| new_index = wrapper.get("index", self.wrapper.index) | |
| new_columns = wrapper.get("columns", self.wrapper.columns) | |
| else: | |
| new_index = wrapper.index | |
| new_columns = wrapper.columns | |
| data = self.config["data"] | |
| new_data = {} | |
| index_changed = False | |
| columns_changed = False | |
| for k, v in data.items(): | |
| if isinstance(v, (pd.Series, pd.DataFrame)): | |
| if not checks.is_index_equal(v.index, new_index): | |
| v = v.copy(deep=False) | |
| v.index = new_index | |
| index_changed = True | |
| if isinstance(v, pd.DataFrame): | |
| if not checks.is_index_equal(v.columns, new_columns): | |
| v = v.copy(deep=False) | |
| v.columns = new_columns | |
| columns_changed = True | |
| new_data[k] = v | |
| if index_changed or columns_changed: | |
| kwargs["data"] = self.fix_data_dict_type(new_data) | |
| if columns_changed: | |
| rename = dict(zip(self.keys, new_columns)) | |
| for attr in self._key_dict_attrs: | |
| if attr not in kwargs: | |
| attr_value = getattr(self, attr) | |
| if (self.feature_oriented and isinstance(attr_value, symbol_dict)) or ( | |
| self.symbol_oriented and isinstance(attr_value, feature_dict) | |
| ): | |
| kwargs[attr] = self.rename_in_dict(getattr(self, attr), rename) | |
| kwargs = self.fix_dict_types_in_kwargs(type(kwargs.get("data", self.data)), **kwargs) | |
| return Analyzable.replace(self, **kwargs) | |
| def indexing_func(self: DataT, *args, replace_kwargs: tp.KwargsLike = None, **kwargs) -> DataT: | |
| """Perform indexing on `Data`.""" | |
| if replace_kwargs is None: | |
| replace_kwargs = {} | |
| wrapper_meta = self.wrapper.indexing_func_meta(*args, **kwargs) | |
| new_wrapper = wrapper_meta["new_wrapper"] | |
| new_data = self.dict_type() | |
| for k, v in self._data.items(): | |
| if wrapper_meta["rows_changed"]: | |
| v = v.iloc[wrapper_meta["row_idxs"]] | |
| if wrapper_meta["columns_changed"]: | |
| v = v.iloc[:, wrapper_meta["col_idxs"]] | |
| new_data[k] = v | |
| attr_dicts = dict() | |
| attr_dicts["last_index"] = type(self.last_index)() | |
| for k in self.last_index: | |
| attr_dicts["last_index"][k] = min([self.last_index[k], new_wrapper.index[-1]]) | |
| if wrapper_meta["columns_changed"]: | |
| new_symbols = new_wrapper.columns | |
| for attr in self._key_dict_attrs: | |
| attr_value = getattr(self, attr) | |
| if (self.feature_oriented and isinstance(attr_value, symbol_dict)) or ( | |
| self.symbol_oriented and isinstance(attr_value, feature_dict) | |
| ): | |
| if attr in attr_dicts: | |
| attr_dicts[attr] = self.select_from_dict(attr_dicts[attr], new_symbols) | |
| else: | |
| attr_dicts[attr] = self.select_from_dict(attr_value, new_symbols) | |
| return self.replace(wrapper=new_wrapper, data=new_data, **attr_dicts, **replace_kwargs) | |
| @property | |
| def data(self) -> tp.Union[feature_dict, symbol_dict]: | |
| """Data dictionary. | |
| Has the type `feature_dict` for feature-oriented data or `symbol_dict` for symbol-oriented data.""" | |
| return self._data | |
| @property | |
| def dict_type(self) -> tp.Type[tp.Union[feature_dict, symbol_dict]]: | |
| """Return the dict type.""" | |
| return type(self.data) | |
| @property | |
| def column_type(self) -> tp.Type[tp.Union[feature_dict, symbol_dict]]: | |
| """Return the column type.""" | |
| if isinstance(self.data, feature_dict): | |
| return symbol_dict | |
| return feature_dict | |
| @property | |
| def feature_oriented(self) -> bool: | |
| """Whether data has features as keys.""" | |
| return issubclass(self.dict_type, feature_dict) | |
| @property | |
| def symbol_oriented(self) -> bool: | |
| """Whether data has symbols as keys.""" | |
| return issubclass(self.dict_type, symbol_dict) | |
| def get_keys(self, dict_type: tp.Type[tp.Union[feature_dict, symbol_dict]]) -> tp.List[tp.Key]: | |
| """Get keys depending on the provided dict type.""" | |
| checks.assert_subclass_of(dict_type, (feature_dict, symbol_dict), arg_name="dict_type") | |
| if issubclass(dict_type, feature_dict): | |
| return self.features | |
| return self.symbols | |
| @property | |
| def keys(self) -> tp.List[tp.Union[tp.Feature, tp.Symbol]]: | |
| """Keys in data. | |
| Features if `feature_dict` and symbols if `symbol_dict`.""" | |
| return list(self.data.keys()) | |
| @property | |
| def single_key(self) -> bool: | |
| """Whether there is only one key in `Data.data`.""" | |
| return self._single_key | |
| @property | |
| def single_feature(self) -> bool: | |
| """Whether there is only one feature in `Data.data`.""" | |
| if self.feature_oriented: | |
| return self.single_key | |
| return self.wrapper.ndim == 1 | |
| @property | |
| def single_symbol(self) -> bool: | |
| """Whether there is only one symbol in `Data.data`.""" | |
| if self.symbol_oriented: | |
| return self.single_key | |
| return self.wrapper.ndim == 1 | |
| @property | |
| def classes(self) -> tp.Union[feature_dict, symbol_dict]: | |
| """Key classes.""" | |
| return self._classes | |
| @property | |
| def feature_classes(self) -> tp.Optional[feature_dict]: | |
| """Feature classes.""" | |
| if self.feature_oriented: | |
| return self.classes | |
| return None | |
| @property | |
| def symbol_classes(self) -> tp.Optional[symbol_dict]: | |
| """Symbol classes.""" | |
| if self.symbol_oriented: | |
| return self.classes | |
| return None | |
| @hybrid_method | |
| def get_level_name( | |
| cls_or_self, | |
| keys: tp.Optional[tp.Keys] = None, | |
| level_name: tp.Union[None, bool, tp.MaybeIterable[tp.Hashable]] = None, | |
| feature_oriented: tp.Optional[bool] = None, | |
| ) -> tp.Optional[tp.MaybeIterable[tp.Hashable]]: | |
| """Get level name(s) for keys.""" | |
| if isinstance(cls_or_self, type): | |
| checks.assert_not_none(keys, arg_name="keys") | |
| checks.assert_not_none(feature_oriented, arg_name="feature_oriented") | |
| else: | |
| if keys is None: | |
| keys = cls_or_self.keys | |
| if level_name is None: | |
| level_name = cls_or_self._level_name | |
| if feature_oriented is None: | |
| feature_oriented = cls_or_self.feature_oriented | |
| first_key = keys[0] | |
| if isinstance(level_name, bool): | |
| if level_name: | |
| level_name = None | |
| else: | |
| return None | |
| if feature_oriented: | |
| key_prefix = "feature" | |
| else: | |
| key_prefix = "symbol" | |
| if isinstance(first_key, tuple): | |
| if level_name is None: | |
| level_name = ["%s_%d" % (key_prefix, i) for i in range(len(first_key))] | |
| if not checks.is_iterable(level_name) or isinstance(level_name, str): | |
| raise TypeError("Level name should be list-like for a MultiIndex") | |
| return tuple(level_name) | |
| if level_name is None: | |
| level_name = key_prefix | |
| return level_name | |
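| # A short illustration of the defaults above (keys are hypothetical): plain keys get a | |
| # single level name, tuple keys get one enumerated name per element, and | |
| # `level_name=False` disables level names entirely. | |
| # | |
| #   keys=["BTC-USD", "ETH-USD"], feature_oriented=False -> "symbol" | |
| #   keys=[("BTC-USD", "1d")],    feature_oriented=False -> ("symbol_0", "symbol_1") | |
| #   level_name=False                                    -> None | |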
| @property | |
| def level_name(self) -> tp.Optional[tp.MaybeIterable[tp.Hashable]]: | |
| """Level name(s) for keys. | |
| Keys are symbols or features depending on the data dict type. | |
| Must be a sequence if keys are tuples, otherwise a hashable. | |
| If False, no level names will be used.""" | |
| return self.get_level_name() | |
| @hybrid_method | |
| def get_key_index( | |
| cls_or_self, | |
| keys: tp.Optional[tp.Keys] = None, | |
| level_name: tp.Union[None, bool, tp.MaybeIterable[tp.Hashable]] = None, | |
| feature_oriented: tp.Optional[bool] = None, | |
| ) -> tp.Index: | |
| """Get key index.""" | |
| if isinstance(cls_or_self, type): | |
| checks.assert_not_none(keys, arg_name="keys") | |
| else: | |
| if keys is None: | |
| keys = cls_or_self.keys | |
| level_name = cls_or_self.get_level_name(keys=keys, level_name=level_name, feature_oriented=feature_oriented) | |
| if isinstance(level_name, tuple): | |
| return pd.MultiIndex.from_tuples(keys, names=level_name) | |
| return pd.Index(keys, name=level_name) | |
| @property | |
| def key_index(self) -> tp.Index: | |
| """Key index.""" | |
| return self.get_key_index() | |
| @property | |
| def fetch_kwargs(self) -> tp.Union[feature_dict, symbol_dict]: | |
| """Keyword arguments of type `symbol_dict` initially passed to `Data.fetch_symbol`.""" | |
| return self._fetch_kwargs | |
| @property | |
| def returned_kwargs(self) -> tp.Union[feature_dict, symbol_dict]: | |
| """Keyword arguments of type `symbol_dict` returned by `Data.fetch_symbol`.""" | |
| return self._returned_kwargs | |
| @property | |
| def last_index(self) -> tp.Union[feature_dict, symbol_dict]: | |
| """Last fetched index per symbol of type `symbol_dict`.""" | |
| return self._last_index | |
| @property | |
| def delisted(self) -> tp.Union[feature_dict, symbol_dict]: | |
| """Delisted flag per symbol of type `symbol_dict`.""" | |
| return self._delisted | |
| @property | |
| def tz_localize(self) -> tp.Union[None, bool, tp.TimezoneLike]: | |
| """Timezone to localize a datetime-naive index to, which is initially passed to `Data.pull`.""" | |
| return self._tz_localize | |
| @property | |
| def tz_convert(self) -> tp.Union[None, bool, tp.TimezoneLike]: | |
| """Timezone to convert a datetime-aware to, which is initially passed to `Data.pull`.""" | |
| return self._tz_convert | |
| @property | |
| def missing_index(self) -> tp.Optional[str]: | |
| """Argument `missing` passed to `Data.align_index`.""" | |
| return self._missing_index | |
| @property | |
| def missing_columns(self) -> tp.Optional[str]: | |
| """Argument `missing` passed to `Data.align_columns`.""" | |
| return self._missing_columns | |
| # ############# Settings ############# # | |
| @classmethod | |
| def get_base_settings(cls, *args, **kwargs) -> dict: | |
| """`CustomData.get_settings` with `path_id="base"`.""" | |
| return cls.get_settings(*args, path_id="base", **kwargs) | |
| @classmethod | |
| def has_base_settings(cls, *args, **kwargs) -> bool: | |
| """`CustomData.has_settings` with `path_id="base"`.""" | |
| return cls.has_settings(*args, path_id="base", **kwargs) | |
| @classmethod | |
| def get_base_setting(cls, *args, **kwargs) -> tp.Any: | |
| """`CustomData.get_setting` with `path_id="base"`.""" | |
| return cls.get_setting(*args, path_id="base", **kwargs) | |
| @classmethod | |
| def has_base_setting(cls, *args, **kwargs) -> bool: | |
| """`CustomData.has_setting` with `path_id="base"`.""" | |
| return cls.has_setting(*args, path_id="base", **kwargs) | |
| @classmethod | |
| def resolve_base_setting(cls, *args, **kwargs) -> tp.Any: | |
| """`CustomData.resolve_setting` with `path_id="base"`.""" | |
| return cls.resolve_setting(*args, path_id="base", **kwargs) | |
| @classmethod | |
| def set_base_settings(cls, *args, **kwargs) -> None: | |
| """`CustomData.set_settings` with `path_id="base"`.""" | |
| cls.set_settings(*args, path_id="base", **kwargs) | |
| # ############# Iteration ############# # | |
| def items( | |
| self, | |
| over: str = "symbols", | |
| group_by: tp.GroupByLike = None, | |
| apply_group_by: bool = False, | |
| keep_2d: bool = False, | |
| key_as_index: bool = False, | |
| ) -> tp.Items: | |
| """Iterate over columns (or groups if grouped and `Wrapping.group_select` is True), keys, | |
| features, or symbols. The respective mode can be selected with `over`. | |
| See `vectorbtpro.base.wrapping.Wrapping.items` for iteration over columns. | |
| Iteration over keys supports `group_by` but doesn't support `apply_group_by`.""" | |
| if ( | |
| over.lower() == "columns" | |
| or (over.lower() == "symbols" and self.feature_oriented) | |
| or (over.lower() == "features" and self.symbol_oriented) | |
| ): | |
| for k, v in Analyzable.items( | |
| self, | |
| group_by=group_by, | |
| apply_group_by=apply_group_by, | |
| keep_2d=keep_2d, | |
| key_as_index=key_as_index, | |
| ): | |
| yield k, v | |
| elif ( | |
| over.lower() == "keys" | |
| or (over.lower() == "features" and self.feature_oriented) | |
| or (over.lower() == "symbols" and self.symbol_oriented) | |
| ): | |
| if apply_group_by: | |
| raise ValueError("Cannot apply grouping to keys") | |
| if group_by is not None: | |
| key_wrapper = self.get_key_wrapper(group_by=group_by) | |
| if key_wrapper.get_ndim() == 1: | |
| if key_as_index: | |
| yield key_wrapper.get_columns(), self | |
| else: | |
| yield key_wrapper.get_columns()[0], self | |
| else: | |
| for group, group_idxs in key_wrapper.grouper.iter_groups(key_as_index=key_as_index): | |
| if keep_2d or len(group_idxs) > 1: | |
| yield group, self.select_keys([self.keys[i] for i in group_idxs]) | |
| else: | |
| yield group, self.select_keys(self.keys[group_idxs[0]]) | |
| else: | |
| key_wrapper = self.get_key_wrapper(attach_classes=False) | |
| if key_wrapper.ndim == 1: | |
| if key_as_index: | |
| yield key_wrapper.columns, self | |
| else: | |
| yield key_wrapper.columns[0], self | |
| else: | |
| for i in range(len(key_wrapper.columns)): | |
| if key_as_index: | |
| key = key_wrapper.columns[[i]] | |
| else: | |
| key = key_wrapper.columns[i] | |
| if keep_2d: | |
| yield key, self.select_keys([key]) | |
| else: | |
| yield key, self.select_keys(key) | |
| else: | |
| raise ValueError(f"Invalid over: '{over}'") | |
| # ############# Getting ############# # | |
| def get_key_wrapper( | |
| self, | |
| keys: tp.Optional[tp.MaybeKeys] = None, | |
| attach_classes: bool = True, | |
| clean_index_kwargs: tp.KwargsLike = None, | |
| group_by: tp.GroupByLike = None, | |
| **kwargs, | |
| ) -> ArrayWrapper: | |
| """Get wrapper with keys as columns. | |
| If `attach_classes` is True, attaches `Data.classes` by stacking them over | |
| the keys using `vectorbtpro.base.indexes.stack_indexes`. | |
| Other keyword arguments are passed to the constructor of the wrapper.""" | |
| if clean_index_kwargs is None: | |
| clean_index_kwargs = {} | |
| if keys is None: | |
| keys = self.keys | |
| ndim = 1 if self.single_key else 2 | |
| else: | |
| if self.has_multiple_keys(keys): | |
| ndim = 2 | |
| else: | |
| keys = [keys] | |
| ndim = 1 | |
| for key in keys: | |
| if self.feature_oriented: | |
| self.assert_has_feature(key) | |
| else: | |
| self.assert_has_symbol(key) | |
| new_columns = self.get_key_index(keys=keys) | |
| wrapper = self.wrapper.replace( | |
| columns=new_columns, | |
| ndim=ndim, | |
| grouper=None, | |
| **kwargs, | |
| ) | |
| if attach_classes: | |
| classes = [] | |
| all_have_classes = True | |
| for key in wrapper.columns: | |
| if key in self.classes: | |
| key_classes = self.classes[key] | |
| if len(key_classes) > 0: | |
| classes.append(key_classes) | |
| else: | |
| all_have_classes = False | |
| else: | |
| all_have_classes = False | |
| if len(classes) > 0 and not all_have_classes: | |
| if self.feature_oriented: | |
| raise ValueError("Some features have classes while others not") | |
| else: | |
| raise ValueError("Some symbols have classes while others not") | |
| if len(classes) > 0: | |
| classes_frame = pd.DataFrame(classes) | |
| if len(classes_frame.columns) == 1: | |
| classes_columns = pd.Index(classes_frame.iloc[:, 0]) | |
| else: | |
| classes_columns = pd.MultiIndex.from_frame(classes_frame) | |
| new_columns = stack_indexes((classes_columns, wrapper.columns), **clean_index_kwargs) | |
| wrapper = wrapper.replace(columns=new_columns) | |
| if group_by is not None: | |
| wrapper = wrapper.replace(group_by=group_by) | |
| return wrapper | |
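| # A hedged sketch of `attach_classes` (keys and class values are placeholders): when | |
| # every key has a non-empty classes dict (e.g. passed as `classes=` on construction), | |
| # the class values are stacked above the key level of the columns. | |
| # | |
| #   classes = symbol_dict({"BTC-USD": dict(asset_class="crypto"), | |
| #                          "SPY": dict(asset_class="equity")}) | |
| #   # with such classes attached, the key wrapper's columns become a MultiIndex | |
| #   # along the lines of [("crypto", "BTC-USD"), ("equity", "SPY")] | |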
| @cached_property | |
| def key_wrapper(self) -> ArrayWrapper: | |
| """Key wrapper.""" | |
| return self.get_key_wrapper() | |
| def get_feature_wrapper(self, features: tp.Optional[tp.MaybeFeatures] = None, **kwargs) -> ArrayWrapper: | |
| """Get wrapper with features as columns.""" | |
| if self.feature_oriented: | |