This file is a merged representation of the entire codebase, combined into a single document by Repomix.
<file_summary>
This section contains a summary of this file.
<purpose>
This file contains a packed representation of the entire repository's contents.
It is designed to be easily consumable by AI systems for analysis, code review,
or other automated processes.
</purpose>
<file_format>
The content is organized as follows:
1. This summary section
2. Repository information
3. Directory structure
4. Repository files, each consisting of:
- File path as an attribute
- Full contents of the file
</file_format>
<usage_guidelines>
- This file should be treated as read-only. Any changes should be made to the
original repository files, not this packed version.
- When processing this file, use the file path to distinguish
between different files in the repository.
- Be aware that this file may contain sensitive information. Handle it with
the same level of security as you would the original repository.
</usage_guidelines>
<notes>
- Some files may have been excluded based on .gitignore rules and Repomix's configuration
- Binary files are not included in this packed representation. Please refer to the Repository Structure section for a complete list of file paths, including binary files
- Files matching patterns in .gitignore are excluded
- Files matching default ignore patterns are excluded
- Files are sorted by Git change count (files with more changes are at the bottom)
</notes>
<additional_info>
</additional_info>
</file_summary>
<directory_structure>
base/
grouping/
__init__.py
base.py
nb.py
resampling/
__init__.py
base.py
nb.py
__init__.py
accessors.py
chunking.py
combining.py
decorators.py
flex_indexing.py
indexes.py
indexing.py
merging.py
preparing.py
reshaping.py
wrapping.py
data/
custom/
__init__.py
alpaca.py
av.py
bento.py
binance.py
ccxt.py
csv.py
custom.py
db.py
duckdb.py
feather.py
file.py
finpy.py
gbm_ohlc.py
gbm.py
hdf.py
local.py
ndl.py
parquet.py
polygon.py
random_ohlc.py
random.py
remote.py
sql.py
synthetic.py
tv.py
yf.py
__init__.py
base.py
decorators.py
nb.py
saver.py
updater.py
generic/
nb/
__init__.py
apply_reduce.py
base.py
iter_.py
patterns.py
records.py
rolling.py
sim_range.py
splitting/
__init__.py
base.py
decorators.py
nb.py
purged.py
sklearn_.py
__init__.py
accessors.py
analyzable.py
decorators.py
drawdowns.py
enums.py
plots_builder.py
plotting.py
price_records.py
ranges.py
sim_range.py
stats_builder.py
indicators/
custom/
__init__.py
adx.py
atr.py
bbands.py
hurst.py
ma.py
macd.py
msd.py
obv.py
ols.py
patsim.py
pivotinfo.py
rsi.py
sigdet.py
stoch.py
supertrend.py
vwap.py
__init__.py
configs.py
enums.py
expr.py
factory.py
nb.py
talib_.py
labels/
generators/
__init__.py
bolb.py
fixlb.py
fmax.py
fmean.py
fmin.py
fstd.py
meanlb.py
pivotlb.py
trendlb.py
__init__.py
enums.py
nb.py
ohlcv/
__init__.py
accessors.py
enums.py
nb.py
portfolio/
nb/
__init__.py
analysis.py
core.py
ctx_helpers.py
from_order_func.py
from_orders.py
from_signals.py
iter_.py
records.py
pfopt/
__init__.py
base.py
nb.py
records.py
__init__.py
base.py
call_seq.py
chunking.py
decorators.py
enums.py
logs.py
orders.py
preparing.py
trades.py
px/
__init__.py
accessors.py
decorators.py
records/
__init__.py
base.py
chunking.py
col_mapper.py
decorators.py
mapped_array.py
nb.py
registries/
__init__.py
ca_registry.py
ch_registry.py
jit_registry.py
pbar_registry.py
returns/
__init__.py
accessors.py
enums.py
nb.py
qs_adapter.py
signals/
generators/
__init__.py
ohlcstcx.py
ohlcstx.py
rand.py
randnx.py
randx.py
rprob.py
rprobcx.py
rprobnx.py
rprobx.py
stcx.py
stx.py
__init__.py
accessors.py
enums.py
factory.py
nb.py
templates/
dark.json
light.json
seaborn.json
utils/
knowledge/
__init__.py
asset_pipelines.py
base_asset_funcs.py
base_assets.py
chatting.py
custom_asset_funcs.py
custom_assets.py
formatting.py
__init__.py
annotations.py
array_.py
attr_.py
base.py
caching.py
chaining.py
checks.py
chunking.py
colors.py
config.py
cutting.py
datetime_.py
datetime_nb.py
decorators.py
enum_.py
eval_.py
execution.py
figure.py
formatting.py
hashing.py
image_.py
jitting.py
magic_decorators.py
mapping.py
math_.py
merging.py
module_.py
params.py
parsing.py
path_.py
pbar.py
pickling.py
profiling.py
random_.py
requests_.py
schedule_.py
search_.py
selection.py
tagging.py
telegram.py
template.py
warnings_.py
__init__.py
_dtypes.py
_opt_deps.py
_settings.py
_typing.py
_version.py
accessors.py
</directory_structure>
<files>
This section contains the contents of the repository's files.
<file path="base/grouping/__init__.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules with classes and utilities for grouping."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.base.grouping.base import *
from vectorbtpro.base.grouping.nb import *
</file>
<file path="base/grouping/base.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base classes and functions for grouping.
Class `Grouper` stores metadata related to grouping an index. It can return, for example,
the number of groups, the start indices of groups, and other information useful for reducing
operations that utilize grouping. It also allows dynamically enabling, disabling, and modifying
groups, and checks whether a certain operation is permitted."""
import numpy as np
import pandas as pd
from pandas.core.groupby import GroupBy as PandasGroupBy
from pandas.core.resample import Resampler as PandasResampler
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import indexes
from vectorbtpro.base.grouping import nb
from vectorbtpro.base.indexes import ExceptLevel
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils.array_ import is_sorted
from vectorbtpro.utils.config import Configured
from vectorbtpro.utils.decorators import cached_method
from vectorbtpro.utils.template import CustomTemplate
__all__ = [
"Grouper",
]
GroupByT = tp.Union[None, bool, tp.Index]
GrouperT = tp.TypeVar("GrouperT", bound="Grouper")
class Grouper(Configured):
"""Class that exposes methods to group index.
`group_by` can be:
* boolean (False for no grouping, True for one group),
* integer (level by position),
* string (level by name),
* sequence of integers or strings that is shorter than `index` (multiple levels),
* any other sequence that has the same length as `index` (group per index).
Set `allow_enable` to False to prohibit grouping if `Grouper.group_by` is None.
Set `allow_disable` to False to prohibit disabling of grouping if `Grouper.group_by` is not None.
Set `allow_modify` to False to prohibit modifying groups (you can still change their labels).
All properties are read-only to enable caching."""
@classmethod
def group_by_to_index(
cls,
index: tp.Index,
group_by: tp.GroupByLike,
def_lvl_name: tp.Hashable = "group",
) -> GroupByT:
"""Convert mapper `group_by` to `pd.Index`.
!!! note
Index and mapper must have the same length."""
if group_by is None or group_by is False:
return group_by
if isinstance(group_by, CustomTemplate):
group_by = group_by.substitute(context=dict(index=index), strict=True, eval_id="group_by")
if group_by is True:
group_by = pd.Index(["group"] * len(index), name=def_lvl_name)
elif isinstance(index, pd.MultiIndex) or isinstance(group_by, (ExceptLevel, int, str)):
if isinstance(group_by, ExceptLevel):
except_levels = group_by.value
if isinstance(except_levels, (int, str)):
except_levels = [except_levels]
new_group_by = []
for i, name in enumerate(index.names):
if i not in except_levels and name not in except_levels:
new_group_by.append(name)
if len(new_group_by) == 0:
group_by = pd.Index(["group"] * len(index), name=def_lvl_name)
else:
if len(new_group_by) == 1:
new_group_by = new_group_by[0]
group_by = indexes.select_levels(index, new_group_by)
elif isinstance(group_by, (int, str)):
group_by = indexes.select_levels(index, group_by)
elif (
isinstance(group_by, (tuple, list))
and not isinstance(group_by[0], pd.Index)
and len(group_by) <= len(index.names)
):
try:
group_by = indexes.select_levels(index, group_by)
except (IndexError, KeyError):
pass
if not isinstance(group_by, pd.Index):
if isinstance(group_by[0], pd.Index):
group_by = pd.MultiIndex.from_arrays(group_by)
else:
group_by = pd.Index(group_by, name=def_lvl_name)
if len(group_by) != len(index):
raise ValueError("group_by and index must have the same length")
return group_by
@classmethod
def group_by_to_groups_and_index(
cls,
index: tp.Index,
group_by: tp.GroupByLike,
def_lvl_name: tp.Hashable = "group",
) -> tp.Tuple[tp.Array1d, tp.Index]:
"""Return array of group indices pointing to the original index, and grouped index."""
if group_by is None or group_by is False:
return np.arange(len(index)), index
group_by = cls.group_by_to_index(index, group_by, def_lvl_name)
codes, uniques = pd.factorize(group_by)
if not isinstance(uniques, pd.Index):
new_index = pd.Index(uniques)
else:
new_index = uniques
if isinstance(group_by, pd.MultiIndex):
new_index.names = group_by.names
elif isinstance(group_by, (pd.Index, pd.Series)):
new_index.name = group_by.name
return codes, new_index
@classmethod
def iter_group_lens(cls, group_lens: tp.GroupLens) -> tp.Iterator[tp.GroupIdxs]:
"""Iterate over indices of each group in group lengths."""
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
for group in range(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
yield np.arange(from_col, to_col)
@classmethod
def iter_group_map(cls, group_map: tp.GroupMap) -> tp.Iterator[tp.GroupIdxs]:
"""Iterate over indices of each group in a group map."""
group_idxs, group_lens = group_map
group_start = 0
group_end = 0
for group in range(len(group_lens)):
group_len = group_lens[group]
group_end += group_len
yield group_idxs[group_start:group_end]
group_start += group_len
@classmethod
def from_pd_group_by(
cls: tp.Type[GrouperT],
pd_group_by: tp.PandasGroupByLike,
**kwargs,
) -> GrouperT:
"""Build a `Grouper` instance from a pandas `GroupBy` object.
Indices are stored under `index` and group labels under `group_by`."""
from vectorbtpro.base.merging import concat_arrays
if not isinstance(pd_group_by, (PandasGroupBy, PandasResampler)):
raise TypeError("pd_group_by must be an instance of GroupBy or Resampler")
indices = list(pd_group_by.indices.values())
group_lens = np.asarray(list(map(len, indices)))
groups = np.full(int(np.sum(group_lens)), 0, dtype=int_)
group_start_idxs = np.cumsum(group_lens)[1:] - group_lens[1:]
groups[group_start_idxs] = 1
groups = np.cumsum(groups)
index = pd.Index(concat_arrays(indices))
group_by = pd.Index(list(pd_group_by.indices.keys()), name="group")[groups]
return cls(
index=index,
group_by=group_by,
**kwargs,
)
def __init__(
self,
index: tp.Index,
group_by: tp.GroupByLike = None,
def_lvl_name: tp.Hashable = "group",
allow_enable: bool = True,
allow_disable: bool = True,
allow_modify: bool = True,
**kwargs,
) -> None:
if not isinstance(index, pd.Index):
index = pd.Index(index)
if group_by is None or group_by is False:
group_by = None
else:
group_by = self.group_by_to_index(index, group_by, def_lvl_name=def_lvl_name)
self._index = index
self._group_by = group_by
self._def_lvl_name = def_lvl_name
self._allow_enable = allow_enable
self._allow_disable = allow_disable
self._allow_modify = allow_modify
Configured.__init__(
self,
index=index,
group_by=group_by,
def_lvl_name=def_lvl_name,
allow_enable=allow_enable,
allow_disable=allow_disable,
allow_modify=allow_modify,
**kwargs,
)
@property
def index(self) -> tp.Index:
"""Original index."""
return self._index
@property
def group_by(self) -> GroupByT:
"""Mapper for grouping."""
return self._group_by
@property
def def_lvl_name(self) -> tp.Hashable:
"""Default level name."""
return self._def_lvl_name
@property
def allow_enable(self) -> bool:
"""Whether to allow enabling grouping."""
return self._allow_enable
@property
def allow_disable(self) -> bool:
"""Whether to allow disabling grouping."""
return self._allow_disable
@property
def allow_modify(self) -> bool:
"""Whether to allow changing groups."""
return self._allow_modify
def is_grouped(self, group_by: tp.GroupByLike = None) -> bool:
"""Check whether index are grouped."""
if group_by is False:
return False
if group_by is None:
group_by = self.group_by
return group_by is not None
def is_grouping_enabled(self, group_by: tp.GroupByLike = None) -> bool:
"""Check whether grouping has been enabled."""
return self.group_by is None and self.is_grouped(group_by=group_by)
def is_grouping_disabled(self, group_by: tp.GroupByLike = None) -> bool:
"""Check whether grouping has been disabled."""
return self.group_by is not None and not self.is_grouped(group_by=group_by)
@cached_method(whitelist=True)
def is_grouping_modified(self, group_by: tp.GroupByLike = None) -> bool:
"""Check whether grouping has been modified.
Doesn't care if grouping labels have been changed."""
if group_by is None or (group_by is False and self.group_by is None):
return False
group_by = self.group_by_to_index(self.index, group_by, def_lvl_name=self.def_lvl_name)
if isinstance(group_by, pd.Index) and isinstance(self.group_by, pd.Index):
if not pd.Index.equals(group_by, self.group_by):
groups1 = self.group_by_to_groups_and_index(
self.index,
group_by,
def_lvl_name=self.def_lvl_name,
)[0]
groups2 = self.group_by_to_groups_and_index(
self.index,
self.group_by,
def_lvl_name=self.def_lvl_name,
)[0]
if not np.array_equal(groups1, groups2):
return True
return False
return True
@cached_method(whitelist=True)
def is_grouping_changed(self, group_by: tp.GroupByLike = None) -> bool:
"""Check whether grouping has been changed in any way."""
if group_by is None or (group_by is False and self.group_by is None):
return False
if isinstance(group_by, pd.Index) and isinstance(self.group_by, pd.Index):
if pd.Index.equals(group_by, self.group_by):
return False
return True
def is_group_count_changed(self, group_by: tp.GroupByLike = None) -> bool:
"""Check whether the number of groups has changed."""
if group_by is None or (group_by is False and self.group_by is None):
return False
if isinstance(group_by, pd.Index) and isinstance(self.group_by, pd.Index):
return len(group_by) != len(self.group_by)
return True
def check_group_by(
self,
group_by: tp.GroupByLike = None,
allow_enable: tp.Optional[bool] = None,
allow_disable: tp.Optional[bool] = None,
allow_modify: tp.Optional[bool] = None,
) -> None:
"""Check passed `group_by` object against restrictions."""
if allow_enable is None:
allow_enable = self.allow_enable
if allow_disable is None:
allow_disable = self.allow_disable
if allow_modify is None:
allow_modify = self.allow_modify
if self.is_grouping_enabled(group_by=group_by):
if not allow_enable:
raise ValueError("Enabling grouping is not allowed")
elif self.is_grouping_disabled(group_by=group_by):
if not allow_disable:
raise ValueError("Disabling grouping is not allowed")
elif self.is_grouping_modified(group_by=group_by):
if not allow_modify:
raise ValueError("Modifying groups is not allowed")
def resolve_group_by(self, group_by: tp.GroupByLike = None, **kwargs) -> GroupByT:
"""Resolve `group_by` from either object variable or keyword argument."""
if group_by is None:
group_by = self.group_by
if group_by is False and self.group_by is None:
group_by = None
self.check_group_by(group_by=group_by, **kwargs)
return self.group_by_to_index(self.index, group_by, def_lvl_name=self.def_lvl_name)
@cached_method(whitelist=True)
def get_groups_and_index(self, group_by: tp.GroupByLike = None, **kwargs) -> tp.Tuple[tp.Array1d, tp.Index]:
"""See `Grouper.group_by_to_groups_and_index`."""
group_by = self.resolve_group_by(group_by=group_by, **kwargs)
return self.group_by_to_groups_and_index(self.index, group_by, def_lvl_name=self.def_lvl_name)
def get_groups(self, **kwargs) -> tp.Array1d:
"""Return groups array."""
return self.get_groups_and_index(**kwargs)[0]
def get_index(self, **kwargs) -> tp.Index:
"""Return grouped index."""
return self.get_groups_and_index(**kwargs)[1]
get_grouped_index = get_index
@property
def grouped_index(self) -> tp.Index:
"""Grouped index."""
return self.get_grouped_index()
def get_stretched_index(self, **kwargs) -> tp.Index:
"""Return stretched index."""
groups, index = self.get_groups_and_index(**kwargs)
return index[groups]
def get_group_count(self, **kwargs) -> int:
"""Get number of groups."""
return len(self.get_index(**kwargs))
@cached_method(whitelist=True)
def is_sorted(self, group_by: tp.GroupByLike = None, **kwargs) -> bool:
"""Return whether groups are monolithic, sorted."""
group_by = self.resolve_group_by(group_by=group_by, **kwargs)
groups = self.get_groups(group_by=group_by)
return is_sorted(groups)
@cached_method(whitelist=True)
def get_group_lens(self, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None, **kwargs) -> tp.GroupLens:
"""See `vectorbtpro.base.grouping.nb.get_group_lens_nb`."""
group_by = self.resolve_group_by(group_by=group_by, **kwargs)
if group_by is None or group_by is False: # no grouping
return np.full(len(self.index), 1)
if not self.is_sorted(group_by=group_by):
raise ValueError("group_by must form monolithic, sorted groups")
groups = self.get_groups(group_by=group_by)
func = jit_reg.resolve_option(nb.get_group_lens_nb, jitted)
return func(groups)
def get_group_start_idxs(self, **kwargs) -> tp.Array1d:
"""Get first index of each group as an array."""
group_lens = self.get_group_lens(**kwargs)
return np.cumsum(group_lens) - group_lens
def get_group_end_idxs(self, **kwargs) -> tp.Array1d:
"""Get end index of each group as an array."""
group_lens = self.get_group_lens(**kwargs)
return np.cumsum(group_lens)
@cached_method(whitelist=True)
def get_group_map(self, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None, **kwargs) -> tp.GroupMap:
"""See get_group_map_nb."""
group_by = self.resolve_group_by(group_by=group_by, **kwargs)
if group_by is None or group_by is False: # no grouping
return np.arange(len(self.index)), np.full(len(self.index), 1)
groups, new_index = self.get_groups_and_index(group_by=group_by)
func = jit_reg.resolve_option(nb.get_group_map_nb, jitted)
return func(groups, len(new_index))
def iter_group_idxs(self, **kwargs) -> tp.Iterator[tp.GroupIdxs]:
"""Iterate over indices of each group."""
group_map = self.get_group_map(**kwargs)
return self.iter_group_map(group_map)
def iter_groups(
self,
key_as_index: bool = False,
**kwargs,
) -> tp.Iterator[tp.Tuple[tp.Union[tp.Hashable, pd.Index], tp.GroupIdxs]]:
"""Iterate over groups and their indices."""
index = self.get_index(**kwargs)
for group, group_idxs in enumerate(self.iter_group_idxs(**kwargs)):
if key_as_index:
yield index[[group]], group_idxs
else:
yield index[group], group_idxs
def select_groups(self, group_idxs: tp.Array1d, jitted: tp.JittedOption = None) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Select groups.
Returns indices and new group array. Automatically decides whether to use group lengths or group map."""
from vectorbtpro.base.reshaping import to_1d_array
if self.is_sorted():
func = jit_reg.resolve_option(nb.group_lens_select_nb, jitted)
new_group_idxs, new_groups = func(self.get_group_lens(), to_1d_array(group_idxs)) # faster
else:
func = jit_reg.resolve_option(nb.group_map_select_nb, jitted)
new_group_idxs, new_groups = func(self.get_group_map(), to_1d_array(group_idxs)) # more flexible
return new_group_idxs, new_groups
</file>
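The following is an illustrative sketch (not part of the repository) of how `Grouper` can be used, assuming `vectorbtpro` is installed; the column and group labels below are made up:

import pandas as pd
from vectorbtpro.base.grouping.base import Grouper

# Group four columns into two groups by label
columns = pd.Index(["a1", "a2", "b1", "b2"], name="symbol")
grouper = Grouper(index=columns, group_by=["g1", "g1", "g2", "g2"])

grouper.get_groups()             # array([0, 0, 1, 1]) -- group number per column
grouper.get_index()              # Index(['g1', 'g2'], name='group') -- grouped index
grouper.get_group_lens()         # array([2, 2]) -- number of columns per group
list(grouper.iter_group_idxs())  # [array([0, 1]), array([2, 3])]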
<file path="base/grouping/nb.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for grouping."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.registries.jit_registry import register_jitted
__all__ = []
GroupByT = tp.Union[None, bool, tp.Index]
@register_jitted(cache=True)
def get_group_lens_nb(groups: tp.Array1d) -> tp.GroupLens:
"""Return the count per group.
!!! note
Columns must form monolithic, sorted groups. For unsorted groups, use `get_group_map_nb`."""
result = np.empty(groups.shape[0], dtype=int_)
j = 0
last_group = -1
group_len = 0
for i in range(groups.shape[0]):
cur_group = groups[i]
if cur_group < last_group:
raise ValueError("Columns must form monolithic, sorted groups")
if cur_group != last_group:
if last_group != -1:
# Process previous group
result[j] = group_len
j += 1
group_len = 0
last_group = cur_group
group_len += 1
if i == groups.shape[0] - 1:
# Process last group
result[j] = group_len
j += 1
group_len = 0
return result[:j]
@register_jitted(cache=True)
def get_group_map_nb(groups: tp.Array1d, n_groups: int) -> tp.GroupMap:
"""Build the map between groups and indices.
Returns an array with indices segmented by group and an array with group lengths.
Works well for unsorted group arrays."""
group_lens_out = np.full(n_groups, 0, dtype=int_)
for g in range(groups.shape[0]):
group = groups[g]
group_lens_out[group] += 1
group_start_idxs = np.cumsum(group_lens_out) - group_lens_out
group_idxs_out = np.empty((groups.shape[0],), dtype=int_)
group_i = np.full(n_groups, 0, dtype=int_)
for g in range(groups.shape[0]):
group = groups[g]
group_idxs_out[group_start_idxs[group] + group_i[group]] = g
group_i[group] += 1
return group_idxs_out, group_lens_out
@register_jitted(cache=True)
def group_lens_select_nb(group_lens: tp.GroupLens, new_groups: tp.Array1d) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Perform indexing on a sorted array using group lengths.
Returns indices of elements corresponding to groups in `new_groups` and a new group array."""
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
n_values = np.sum(group_lens[new_groups])
indices_out = np.empty(n_values, dtype=int_)
group_arr_out = np.empty(n_values, dtype=int_)
j = 0
for c in range(new_groups.shape[0]):
from_r = group_start_idxs[new_groups[c]]
to_r = group_end_idxs[new_groups[c]]
if from_r == to_r:
continue
rang = np.arange(from_r, to_r)
indices_out[j : j + rang.shape[0]] = rang
group_arr_out[j : j + rang.shape[0]] = c
j += rang.shape[0]
return indices_out, group_arr_out
@register_jitted(cache=True)
def group_map_select_nb(group_map: tp.GroupMap, new_groups: tp.Array1d) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Perform indexing using group map."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
total_count = np.sum(group_lens[new_groups])
indices_out = np.empty(total_count, dtype=int_)
group_arr_out = np.empty(total_count, dtype=int_)
j = 0
for new_group_i in range(len(new_groups)):
new_group = new_groups[new_group_i]
group_len = group_lens[new_group]
if group_len == 0:
continue
group_start_idx = group_start_idxs[new_group]
idxs = group_idxs[group_start_idx : group_start_idx + group_len]
indices_out[j : j + group_len] = idxs
group_arr_out[j : j + group_len] = new_group_i
j += group_len
return indices_out, group_arr_out
@register_jitted(cache=True)
def group_by_evenly_nb(n: int, n_splits: int) -> tp.Array1d:
"""Get `group_by` from evenly splitting a space of values."""
out = np.empty(n, dtype=int_)
for i in range(n):
out[i] = i * n_splits // n + n_splits // (2 * n)
return out
</file>
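Below is a minimal sketch (not part of the repository) exercising the Numba helpers above directly with NumPy arrays, assuming `vectorbtpro` and its Numba dependency are installed; the group arrays are made up:

import numpy as np
from vectorbtpro.base.grouping import nb

# Sorted, monolithic groups -> count per group
sorted_groups = np.array([0, 0, 1, 1, 1, 2])
nb.get_group_lens_nb(sorted_groups)          # array([2, 3, 1])

# Unsorted groups -> index positions segmented by group, plus group lengths
unsorted_groups = np.array([1, 0, 1, 2, 0])
group_idxs, group_lens = nb.get_group_map_nb(unsorted_groups, 3)
# group_idxs -> array([1, 4, 0, 2, 3]): positions of group 0, then group 1, then group 2
# group_lens -> array([2, 2, 1])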
<file path="base/resampling/__init__.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules with classes and utilities for resampling."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.base.resampling.base import *
from vectorbtpro.base.resampling.nb import *
</file>
<file path="base/resampling/base.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base classes and functions for resampling."""
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.indexes import repeat_index
from vectorbtpro.base.resampling import nb
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.config import Configured
from vectorbtpro.utils.decorators import cached_property, hybrid_method
from vectorbtpro.utils.warnings_ import warn
__all__ = [
"Resampler",
]
ResamplerT = tp.TypeVar("ResamplerT", bound="Resampler")
class Resampler(Configured):
"""Class that exposes methods to resample index.
Args:
source_index (index_like): Index being resampled.
target_index (index_like): Index resulting from resampling.
source_freq (frequency_like or bool): Frequency or date offset of the source index.
Set to False to force-set the frequency to None.
target_freq (frequency_like or bool): Frequency or date offset of the target index.
Set to False to force-set the frequency to None.
silence_warnings (bool): Whether to silence all warnings."""
def __init__(
self,
source_index: tp.IndexLike,
target_index: tp.IndexLike,
source_freq: tp.Union[None, bool, tp.FrequencyLike] = None,
target_freq: tp.Union[None, bool, tp.FrequencyLike] = None,
silence_warnings: tp.Optional[bool] = None,
**kwargs,
) -> None:
source_index = dt.prepare_dt_index(source_index)
target_index = dt.prepare_dt_index(target_index)
infer_source_freq = True
if isinstance(source_freq, bool):
if not source_freq:
infer_source_freq = False
source_freq = None
infer_target_freq = True
if isinstance(target_freq, bool):
if not target_freq:
infer_target_freq = False
target_freq = None
if infer_source_freq:
source_freq = dt.infer_index_freq(source_index, freq=source_freq)
if infer_target_freq:
target_freq = dt.infer_index_freq(target_index, freq=target_freq)
self._source_index = source_index
self._target_index = target_index
self._source_freq = source_freq
self._target_freq = target_freq
self._silence_warnings = silence_warnings
Configured.__init__(
self,
source_index=source_index,
target_index=target_index,
source_freq=source_freq,
target_freq=target_freq,
silence_warnings=silence_warnings,
**kwargs,
)
@classmethod
def from_pd_resampler(
cls: tp.Type[ResamplerT],
pd_resampler: tp.PandasResampler,
source_freq: tp.Optional[tp.FrequencyLike] = None,
silence_warnings: bool = True,
) -> ResamplerT:
"""Build `Resampler` from
[pandas.core.resample.Resampler](https://pandas.pydata.org/docs/reference/resampling.html).
"""
target_index = pd_resampler.count().index
return cls(
source_index=pd_resampler.obj.index,
target_index=target_index,
source_freq=source_freq,
target_freq=None,
silence_warnings=silence_warnings,
)
@classmethod
def from_pd_resample(
cls: tp.Type[ResamplerT],
source_index: tp.IndexLike,
*args,
source_freq: tp.Optional[tp.FrequencyLike] = None,
silence_warnings: bool = True,
**kwargs,
) -> ResamplerT:
"""Build `Resampler` from
[pandas.DataFrame.resample](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.resample.html).
"""
pd_resampler = pd.Series(index=source_index, dtype=object).resample(*args, **kwargs)
return cls.from_pd_resampler(pd_resampler, source_freq=source_freq, silence_warnings=silence_warnings)
@classmethod
def from_date_range(
cls: tp.Type[ResamplerT],
source_index: tp.IndexLike,
*args,
source_freq: tp.Optional[tp.FrequencyLike] = None,
silence_warnings: tp.Optional[bool] = None,
**kwargs,
) -> ResamplerT:
"""Build `Resampler` from `vectorbtpro.utils.datetime_.date_range`."""
target_index = dt.date_range(*args, **kwargs)
return cls(
source_index=source_index,
target_index=target_index,
source_freq=source_freq,
target_freq=None,
silence_warnings=silence_warnings,
)
@property
def source_index(self) -> tp.Index:
"""Index being resampled."""
return self._source_index
@property
def target_index(self) -> tp.Index:
"""Index resulted from resampling."""
return self._target_index
@property
def source_freq(self) -> tp.AnyPandasFrequency:
"""Frequency or date offset of the source index."""
return self._source_freq
@property
def target_freq(self) -> tp.AnyPandasFrequency:
"""Frequency or date offset of the target index."""
return self._target_freq
@property
def silence_warnings(self) -> bool:
"""Frequency or date offset of the target index."""
from vectorbtpro._settings import settings
resampling_cfg = settings["resampling"]
silence_warnings = self._silence_warnings
if silence_warnings is None:
silence_warnings = resampling_cfg["silence_warnings"]
return silence_warnings
def get_np_source_freq(self, silence_warnings: tp.Optional[bool] = None) -> tp.AnyPandasFrequency:
"""Frequency or date offset of the source index in NumPy format."""
if silence_warnings is None:
silence_warnings = self.silence_warnings
warned = False
source_freq = self.source_freq
if source_freq is not None:
if not isinstance(source_freq, (int, float)):
try:
source_freq = dt.to_timedelta64(source_freq)
except ValueError as e:
if not silence_warnings:
warn(f"Cannot convert {source_freq} to np.timedelta64. Setting to None.")
warned = True
source_freq = None
if source_freq is None:
if not warned and not silence_warnings:
warn("Using right bound of source index without frequency. Set source frequency.")
return source_freq
def get_np_target_freq(self, silence_warnings: tp.Optional[bool] = None) -> tp.AnyPandasFrequency:
"""Frequency or date offset of the target index in NumPy format."""
if silence_warnings is None:
silence_warnings = self.silence_warnings
warned = False
target_freq = self.target_freq
if target_freq is not None:
if not isinstance(target_freq, (int, float)):
try:
target_freq = dt.to_timedelta64(target_freq)
except ValueError as e:
if not silence_warnings:
warn(f"Cannot convert {target_freq} to np.timedelta64. Setting to None.")
warned = True
target_freq = None
if target_freq is None:
if not warned and not silence_warnings:
warn("Using right bound of target index without frequency. Set target frequency.")
return target_freq
@classmethod
def get_lbound_index(cls, index: tp.Index, freq: tp.AnyPandasFrequency = None) -> tp.Index:
"""Get the left bound of a datetime index.
If `freq` is None, calculates the leftmost bound."""
index = dt.prepare_dt_index(index)
checks.assert_instance_of(index, pd.DatetimeIndex)
if freq is not None:
return index.shift(-1, freq=freq) + pd.Timedelta(1, "ns")
min_ts = pd.DatetimeIndex([pd.Timestamp.min.tz_localize(index.tz)])
return (index[:-1] + pd.Timedelta(1, "ns")).append(min_ts)
@classmethod
def get_rbound_index(cls, index: tp.Index, freq: tp.AnyPandasFrequency = None) -> tp.Index:
"""Get the right bound of a datetime index.
If `freq` is None, calculates the rightmost bound."""
index = dt.prepare_dt_index(index)
checks.assert_instance_of(index, pd.DatetimeIndex)
if freq is not None:
return index.shift(1, freq=freq) - pd.Timedelta(1, "ns")
max_ts = pd.DatetimeIndex([pd.Timestamp.max.tz_localize(index.tz)])
return (index[1:] - pd.Timedelta(1, "ns")).append(max_ts)
@cached_property
def source_lbound_index(self) -> tp.Index:
"""Get the left bound of the source datetime index."""
return self.get_lbound_index(self.source_index, freq=self.source_freq)
@cached_property
def source_rbound_index(self) -> tp.Index:
"""Get the right bound of the source datetime index."""
return self.get_rbound_index(self.source_index, freq=self.source_freq)
@cached_property
def target_lbound_index(self) -> tp.Index:
"""Get the left bound of the target datetime index."""
return self.get_lbound_index(self.target_index, freq=self.target_freq)
@cached_property
def target_rbound_index(self) -> tp.Index:
"""Get the right bound of the target datetime index."""
return self.get_rbound_index(self.target_index, freq=self.target_freq)
def map_to_target_index(
self,
before: bool = False,
raise_missing: bool = True,
return_index: bool = True,
jitted: tp.JittedOption = None,
silence_warnings: tp.Optional[bool] = None,
) -> tp.Union[tp.Array1d, tp.Index]:
"""See `vectorbtpro.base.resampling.nb.map_to_target_index_nb`."""
target_freq = self.get_np_target_freq(silence_warnings=silence_warnings)
func = jit_reg.resolve_option(nb.map_to_target_index_nb, jitted)
mapped_arr = func(
self.source_index.values,
self.target_index.values,
target_freq=target_freq,
before=before,
raise_missing=raise_missing,
)
if return_index:
nan_mask = mapped_arr == -1
if nan_mask.any():
mapped_index = self.source_index.to_series().copy()
mapped_index[nan_mask] = np.nan
mapped_index[~nan_mask] = self.target_index[mapped_arr]
mapped_index = pd.Index(mapped_index)
else:
mapped_index = self.target_index[mapped_arr]
return mapped_index
return mapped_arr
def index_difference(
self,
reverse: bool = False,
return_index: bool = True,
jitted: tp.JittedOption = None,
) -> tp.Union[tp.Array1d, tp.Index]:
"""See `vectorbtpro.base.resampling.nb.index_difference_nb`."""
func = jit_reg.resolve_option(nb.index_difference_nb, jitted)
if reverse:
mapped_arr = func(self.target_index.values, self.source_index.values)
else:
mapped_arr = func(self.source_index.values, self.target_index.values)
if return_index:
return self.target_index[mapped_arr]
return mapped_arr
def map_index_to_source_ranges(
self,
before: bool = False,
jitted: tp.JittedOption = None,
silence_warnings: tp.Optional[bool] = None,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""See `vectorbtpro.base.resampling.nb.map_index_to_source_ranges_nb`.
If `Resampler.target_freq` is a date offset, sets it to None and gives a warning.
Raises another warning if `target_freq` is None."""
target_freq = self.get_np_target_freq(silence_warnings=silence_warnings)
func = jit_reg.resolve_option(nb.map_index_to_source_ranges_nb, jitted)
return func(
self.source_index.values,
self.target_index.values,
target_freq=target_freq,
before=before,
)
@hybrid_method
def map_bounds_to_source_ranges(
cls_or_self,
source_index: tp.Optional[tp.IndexLike] = None,
target_lbound_index: tp.Optional[tp.IndexLike] = None,
target_rbound_index: tp.Optional[tp.IndexLike] = None,
closed_lbound: bool = True,
closed_rbound: bool = False,
skip_not_found: bool = False,
jitted: tp.JittedOption = None,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""See `vectorbtpro.base.resampling.nb.map_bounds_to_source_ranges_nb`.
Either `target_lbound_index` or `target_rbound_index` must be set.
Set `target_lbound_index` and `target_rbound_index` to 'pandas' to use
`Resampler.get_lbound_index` and `Resampler.get_rbound_index` respectively.
Also, both allow providing a single datetime string and will automatically broadcast
to the `Resampler.target_index`."""
if not isinstance(cls_or_self, type):
if target_lbound_index is None and target_rbound_index is None:
raise ValueError("Either target_lbound_index or target_rbound_index must be set")
if target_lbound_index is not None:
if isinstance(target_lbound_index, str) and target_lbound_index.lower() == "pandas":
target_lbound_index = cls_or_self.target_lbound_index
else:
target_lbound_index = dt.prepare_dt_index(target_lbound_index)
target_rbound_index = cls_or_self.target_index
if target_rbound_index is not None:
target_lbound_index = cls_or_self.target_index
if isinstance(target_rbound_index, str) and target_rbound_index.lower() == "pandas":
target_rbound_index = cls_or_self.target_rbound_index
else:
target_rbound_index = dt.prepare_dt_index(target_rbound_index)
if len(target_lbound_index) == 1 and len(target_rbound_index) > 1:
target_lbound_index = repeat_index(target_lbound_index, len(target_rbound_index))
elif len(target_lbound_index) > 1 and len(target_rbound_index) == 1:
target_rbound_index = repeat_index(target_rbound_index, len(target_lbound_index))
else:
source_index = dt.prepare_dt_index(source_index)
target_lbound_index = dt.prepare_dt_index(target_lbound_index)
target_rbound_index = dt.prepare_dt_index(target_rbound_index)
checks.assert_len_equal(target_rbound_index, target_lbound_index)
func = jit_reg.resolve_option(nb.map_bounds_to_source_ranges_nb, jitted)
return func(
source_index.values,
target_lbound_index.values,
target_rbound_index.values,
closed_lbound=closed_lbound,
closed_rbound=closed_rbound,
skip_not_found=skip_not_found,
)
def resample_source_mask(
self,
source_mask: tp.ArrayLike,
jitted: tp.JittedOption = None,
silence_warnings: tp.Optional[bool] = None,
) -> tp.Array1d:
"""See `vectorbtpro.base.resampling.nb.resample_source_mask_nb`."""
from vectorbtpro.base.reshaping import broadcast_array_to
if silence_warnings is None:
silence_warnings = self.silence_warnings
source_mask = broadcast_array_to(source_mask, len(self.source_index))
source_freq = self.get_np_source_freq(silence_warnings=silence_warnings)
target_freq = self.get_np_target_freq(silence_warnings=silence_warnings)
func = jit_reg.resolve_option(nb.resample_source_mask_nb, jitted)
return func(
source_mask,
self.source_index.values,
self.target_index.values,
source_freq,
target_freq,
)
def last_before_target_index(self, incl_source: bool = True, jitted: tp.JittedOption = None) -> tp.Array1d:
"""See `vectorbtpro.base.resampling.nb.last_before_target_index_nb`."""
func = jit_reg.resolve_option(nb.last_before_target_index_nb, jitted)
return func(self.source_index.values, self.target_index.values, incl_source=incl_source)
</file>
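An illustrative sketch (not part of the repository) of `Resampler`, assuming `vectorbtpro` is installed; the daily and 2-day indexes below are made up:

import pandas as pd
from vectorbtpro.base.resampling.base import Resampler

# Map a daily index onto a 2-day target index
source_index = pd.date_range("2020-01-01", periods=6, freq="D")
target_index = pd.date_range("2020-01-01", periods=3, freq="2D")
resampler = Resampler(source_index=source_index, target_index=target_index)

resampler.map_to_target_index(return_index=False)   # array([0, 0, 1, 1, 2, 2])
range_starts, range_ends = resampler.map_index_to_source_ranges()
# range_starts -> array([0, 2, 4]); range_ends -> array([2, 4, 6])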
<file path="base/resampling/nb.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for resampling."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils.datetime_nb import d_td
__all__ = []
@register_jitted(cache=True)
def date_range_nb(
start: np.datetime64,
end: np.datetime64,
freq: np.timedelta64 = d_td,
incl_left: bool = True,
incl_right: bool = True,
) -> tp.Array1d:
"""Generate a datetime index with nanosecond precision from a date range.
Inspired by [pandas.date_range](https://pandas.pydata.org/docs/reference/api/pandas.date_range.html)."""
values_len = int(np.floor((end - start) / freq)) + 1
values = np.empty(values_len, dtype="datetime64[ns]")
for i in range(values_len):
values[i] = start + i * freq
if start == end:
if not incl_left and not incl_right:
values = values[1:-1]
else:
if not incl_left or not incl_right:
if not incl_left and len(values) and values[0] == start:
values = values[1:]
if not incl_right and len(values) and values[-1] == end:
values = values[:-1]
return values
@register_jitted(cache=True)
def map_to_target_index_nb(
source_index: tp.Array1d,
target_index: tp.Array1d,
target_freq: tp.Optional[tp.Scalar] = None,
before: bool = False,
raise_missing: bool = True,
) -> tp.Array1d:
"""Get the index of each from `source_index` in `target_index`.
If `before` is True, applied on elements that come before and including that index.
Otherwise, applied on elements that come after and including that index.
If `raise_missing` is True, will throw an error if an index cannot be mapped.
Otherwise, the element for that index becomes -1."""
out = np.empty(len(source_index), dtype=int_)
from_j = 0
for i in range(len(source_index)):
if i > 0 and source_index[i] < source_index[i - 1]:
raise ValueError("Source index must be increasing")
if i > 0 and source_index[i] == source_index[i - 1]:
out[i] = out[i - 1]
found = False
for j in range(from_j, len(target_index)):
if j > 0 and target_index[j] <= target_index[j - 1]:
raise ValueError("Target index must be strictly increasing")
if target_freq is None:
if before and source_index[i] <= target_index[j]:
if j == 0 or target_index[j - 1] < source_index[i]:
out[i] = from_j = j
found = True
break
if not before and target_index[j] <= source_index[i]:
if j == len(target_index) - 1 or source_index[i] < target_index[j + 1]:
out[i] = from_j = j
found = True
break
else:
if before and target_index[j] - target_freq < source_index[i] <= target_index[j]:
out[i] = from_j = j
found = True
break
if not before and target_index[j] <= source_index[i] < target_index[j] + target_freq:
out[i] = from_j = j
found = True
break
if not found:
if raise_missing:
raise ValueError("Resampling failed: cannot map some source indices")
out[i] = -1
return out
@register_jitted(cache=True)
def index_difference_nb(
source_index: tp.Array1d,
target_index: tp.Array1d,
) -> tp.Array1d:
"""Get the elements in `source_index` not present in `target_index`."""
out = np.empty(len(source_index), dtype=int_)
from_j = 0
k = 0
for i in range(len(source_index)):
if i > 0 and source_index[i] <= source_index[i - 1]:
raise ValueError("Array index must be strictly increasing")
found = False
for j in range(from_j, len(target_index)):
if j > 0 and target_index[j] <= target_index[j - 1]:
raise ValueError("Target index must be strictly increasing")
if source_index[i] < target_index[j]:
break
if source_index[i] == target_index[j]:
from_j = j
found = True
break
from_j = j
if not found:
out[k] = i
k += 1
return out[:k]
@register_jitted(cache=True)
def map_index_to_source_ranges_nb(
source_index: tp.Array1d,
target_index: tp.Array1d,
target_freq: tp.Optional[tp.Scalar] = None,
before: bool = False,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Get the source bounds that correspond to each target index.
If `target_freq` is not None, the right bound is limited by the frequency in `target_freq`.
Otherwise, the right bound is the next index in `target_index`.
Returns two arrays: the first contains the absolute start index (inclusive) and
the second contains the absolute end index (exclusive) of each source range.
If an element cannot be mapped, the start and end of the range become -1.
!!! note
Both index arrays must be increasing. Repeating values are allowed."""
range_starts_out = np.empty(len(target_index), dtype=int_)
range_ends_out = np.empty(len(target_index), dtype=int_)
to_j = 0
for i in range(len(target_index)):
if i > 0 and target_index[i] < target_index[i - 1]:
raise ValueError("Target index must be increasing")
from_j = -1
for j in range(to_j, len(source_index)):
if j > 0 and source_index[j] < source_index[j - 1]:
raise ValueError("Array index must be increasing")
found = False
if target_freq is None:
if before:
if i == 0 and source_index[j] <= target_index[i]:
found = True
elif i > 0 and target_index[i - 1] < source_index[j] <= target_index[i]:
found = True
elif source_index[j] > target_index[i]:
break
else:
if i == len(target_index) - 1 and target_index[i] <= source_index[j]:
found = True
elif i < len(target_index) - 1 and target_index[i] <= source_index[j] < target_index[i + 1]:
found = True
elif i < len(target_index) - 1 and source_index[j] >= target_index[i + 1]:
break
else:
if before:
if target_index[i] - target_freq < source_index[j] <= target_index[i]:
found = True
elif source_index[j] > target_index[i]:
break
else:
if target_index[i] <= source_index[j] < target_index[i] + target_freq:
found = True
elif source_index[j] >= target_index[i] + target_freq:
break
if found:
if from_j == -1:
from_j = j
to_j = j + 1
if from_j == -1:
range_starts_out[i] = -1
range_ends_out[i] = -1
else:
range_starts_out[i] = from_j
range_ends_out[i] = to_j
return range_starts_out, range_ends_out
@register_jitted(cache=True)
def map_bounds_to_source_ranges_nb(
source_index: tp.Array1d,
target_lbound_index: tp.Array1d,
target_rbound_index: tp.Array1d,
closed_lbound: bool = True,
closed_rbound: bool = False,
skip_not_found: bool = False,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Get the source bounds that correspond to the target bounds.
Returns two arrays: the first contains the absolute start index (inclusive) and
the second contains the absolute end index (exclusive) of each source range.
If an element cannot be mapped, the start and end of the range become -1.
!!! note
Both index arrays must be increasing. Repeating values are allowed."""
range_starts_out = np.empty(len(target_lbound_index), dtype=int_)
range_ends_out = np.empty(len(target_lbound_index), dtype=int_)
k = 0
to_j = 0
for i in range(len(target_lbound_index)):
if i > 0 and target_lbound_index[i] < target_lbound_index[i - 1]:
raise ValueError("Target left-bound index must be increasing")
if i > 0 and target_rbound_index[i] < target_rbound_index[i - 1]:
raise ValueError("Target right-bound index must be increasing")
from_j = -1
for j in range(len(source_index)):
if j > 0 and source_index[j] < source_index[j - 1]:
raise ValueError("Array index must be increasing")
found = False
if closed_lbound and closed_rbound:
if target_lbound_index[i] <= source_index[j] <= target_rbound_index[i]:
found = True
elif source_index[j] > target_rbound_index[i]:
break
elif closed_lbound:
if target_lbound_index[i] <= source_index[j] < target_rbound_index[i]:
found = True
elif source_index[j] >= target_rbound_index[i]:
break
elif closed_rbound:
if target_lbound_index[i] < source_index[j] <= target_rbound_index[i]:
found = True
elif source_index[j] > target_rbound_index[i]:
break
else:
if target_lbound_index[i] < source_index[j] < target_rbound_index[i]:
found = True
elif source_index[j] >= target_rbound_index[i]:
break
if found:
if from_j == -1:
from_j = j
to_j = j + 1
if skip_not_found:
if from_j != -1:
range_starts_out[k] = from_j
range_ends_out[k] = to_j
k += 1
else:
if from_j == -1:
range_starts_out[i] = -1
range_ends_out[i] = -1
else:
range_starts_out[i] = from_j
range_ends_out[i] = to_j
if skip_not_found:
return range_starts_out[:k], range_ends_out[:k]
return range_starts_out, range_ends_out
@register_jitted(cache=True)
def resample_source_mask_nb(
source_mask: tp.Array1d,
source_index: tp.Array1d,
target_index: tp.Array1d,
source_freq: tp.Optional[tp.Scalar] = None,
target_freq: tp.Optional[tp.Scalar] = None,
) -> tp.Array1d:
"""Resample a source mask to the target index.
An output element becomes True only if the target bar is fully contained in the source bar. The source bar
is represented by an uninterrupted sequence of True values in the source mask."""
out = np.full(len(target_index), False, dtype=np.bool_)
from_j = 0
for i in range(len(target_index)):
if i > 0 and target_index[i] < target_index[i - 1]:
raise ValueError("Target index must be increasing")
target_lbound = target_index[i]
if target_freq is None:
if i + 1 < len(target_index):
target_rbound = target_index[i + 1]
else:
target_rbound = None
else:
target_rbound = target_index[i] + target_freq
found_start = False
for j in range(from_j, len(source_index)):
if j > 0 and source_index[j] < source_index[j - 1]:
raise ValueError("Source index must be increasing")
source_lbound = source_index[j]
if source_freq is None:
if j + 1 < len(source_index):
source_rbound = source_index[j + 1]
else:
source_rbound = None
else:
source_rbound = source_index[j] + source_freq
if target_rbound is not None and target_rbound <= source_lbound:
break
if found_start or (
target_lbound >= source_lbound and (source_rbound is None or target_lbound < source_rbound)
):
if not found_start:
from_j = j
found_start = True
if not source_mask[j]:
break
if source_rbound is None or (target_rbound is not None and target_rbound <= source_rbound):
out[i] = True
break
return out
@register_jitted(cache=True)
def last_before_target_index_nb(
source_index: tp.Array1d,
target_index: tp.Array1d,
incl_source: bool = True,
incl_target: bool = False,
) -> tp.Array1d:
"""For each source index, find the position of the last source index between the original
source index and the corresponding target index."""
out = np.empty(len(source_index), dtype=int_)
last_j = -1
for i in range(len(source_index)):
if i > 0 and source_index[i] < source_index[i - 1]:
raise ValueError("Source index must be increasing")
if i > 0 and target_index[i] < target_index[i - 1]:
raise ValueError("Target index must be increasing")
if source_index[i] > target_index[i]:
raise ValueError("Target index must be equal to or greater than source index")
if last_j == -1:
from_i = i + 1
else:
from_i = last_j
if incl_source:
last_j = i
else:
last_j = -1
for j in range(from_i, len(source_index)):
if source_index[j] < target_index[i]:
last_j = j
elif incl_target and source_index[j] == target_index[i]:
last_j = j
else:
break
out[i] = last_j
return out
</file>
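A minimal sketch (not part of the repository) calling the Numba resampling helpers above on plain datetime64 arrays, assuming `vectorbtpro` and its Numba dependency are installed; the timestamps are made up:

import pandas as pd
from vectorbtpro.base.resampling import nb

source_index = pd.date_range("2020-01-01", periods=6, freq="D").values   # datetime64[ns]
target_index = pd.date_range("2020-01-01", periods=3, freq="2D").values  # datetime64[ns]

# With target_freq=None, each target bar extends to the next target timestamp
nb.map_to_target_index_nb(source_index, target_index)   # array([0, 0, 1, 1, 2, 2])

# Positions of source timestamps that are absent from the target index
nb.index_difference_nb(source_index, target_index)      # array([1, 3, 5])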
<file path="base/__init__.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules with base classes and utilities for pandas objects, such as broadcasting."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.base.grouping import *
from vectorbtpro.base.resampling import *
from vectorbtpro.base.accessors import *
from vectorbtpro.base.chunking import *
from vectorbtpro.base.combining import *
from vectorbtpro.base.decorators import *
from vectorbtpro.base.flex_indexing import *
from vectorbtpro.base.indexes import *
from vectorbtpro.base.indexing import *
from vectorbtpro.base.merging import *
from vectorbtpro.base.preparing import *
from vectorbtpro.base.reshaping import *
from vectorbtpro.base.wrapping import *
</file>
<file path="base/accessors.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Custom Pandas accessors for base operations with Pandas objects."""
import ast
import inspect
import numpy as np
import pandas as pd
from pandas.api.types import is_scalar
from pandas.core.groupby import GroupBy as PandasGroupBy
from pandas.core.resample import Resampler as PandasResampler
from vectorbtpro import _typing as tp
from vectorbtpro.base import combining, reshaping, indexes
from vectorbtpro.base.grouping.base import Grouper
from vectorbtpro.base.indexes import IndexApplier
from vectorbtpro.base.indexing import (
point_idxr_defaults,
range_idxr_defaults,
get_index_points,
get_index_ranges,
)
from vectorbtpro.base.resampling.base import Resampler
from vectorbtpro.base.wrapping import ArrayWrapper, Wrapping
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.chunking import ChunkMeta, iter_chunk_meta, get_chunk_meta_key, ArraySelector, ArraySlicer
from vectorbtpro.utils.config import merge_dicts, resolve_dict, Configured
from vectorbtpro.utils.decorators import hybrid_property, hybrid_method
from vectorbtpro.utils.execution import Task, execute
from vectorbtpro.utils.eval_ import evaluate
from vectorbtpro.utils.magic_decorators import attach_binary_magic_methods, attach_unary_magic_methods
from vectorbtpro.utils.parsing import get_context_vars, get_func_arg_names
from vectorbtpro.utils.template import substitute_templates
from vectorbtpro.utils.warnings_ import warn
if tp.TYPE_CHECKING:
from vectorbtpro.data.base import Data as DataT
else:
DataT = "Data"
if tp.TYPE_CHECKING:
from vectorbtpro.generic.splitting.base import Splitter as SplitterT
else:
SplitterT = "Splitter"
__all__ = ["BaseIDXAccessor", "BaseAccessor", "BaseSRAccessor", "BaseDFAccessor"]
BaseIDXAccessorT = tp.TypeVar("BaseIDXAccessorT", bound="BaseIDXAccessor")
class BaseIDXAccessor(Configured, IndexApplier):
"""Accessor on top of Index.
Accessible via `pd.Index.vbt` and all child accessors."""
def __init__(self, obj: tp.Index, freq: tp.Optional[tp.FrequencyLike] = None, **kwargs) -> None:
checks.assert_instance_of(obj, pd.Index)
Configured.__init__(self, obj=obj, freq=freq, **kwargs)
self._obj = obj
self._freq = freq
@property
def obj(self) -> tp.Index:
"""Pandas object."""
return self._obj
def get(self) -> tp.Index:
"""Get `IDXAccessor.obj`."""
return self.obj
# ############# Index ############# #
def to_ns(self) -> tp.Array1d:
"""Convert index to an 64-bit integer array.
Timestamps will be converted to nanoseconds."""
return dt.to_ns(self.obj)
def to_period(self, freq: tp.FrequencyLike, shift: bool = False) -> pd.PeriodIndex:
"""Convert index to period."""
index = self.obj
if isinstance(index, pd.DatetimeIndex):
index = index.tz_localize(None).to_period(freq)
if shift:
index = index.shift()
if not isinstance(index, pd.PeriodIndex):
raise TypeError(f"Cannot convert index of type {type(index)} to period")
return index
def to_period_ts(self, *args, **kwargs) -> pd.DatetimeIndex:
"""Convert index to period and then to timestamp."""
new_index = self.to_period(*args, **kwargs).to_timestamp()
if self.obj.tz is not None:
new_index = new_index.tz_localize(self.obj.tz)
return new_index
def to_period_ns(self, *args, **kwargs) -> tp.Array1d:
"""Convert index to period and then to an 64-bit integer array.
Timestamps will be converted to nanoseconds."""
return dt.to_ns(self.to_period_ts(*args, **kwargs))
@classmethod
def from_values(cls, *args, **kwargs) -> tp.Index:
"""See `vectorbtpro.base.indexes.index_from_values`."""
return indexes.index_from_values(*args, **kwargs)
def repeat(self, *args, **kwargs) -> tp.Index:
"""See `vectorbtpro.base.indexes.repeat_index`."""
return indexes.repeat_index(self.obj, *args, **kwargs)
def tile(self, *args, **kwargs) -> tp.Index:
"""See `vectorbtpro.base.indexes.tile_index`."""
return indexes.tile_index(self.obj, *args, **kwargs)
@hybrid_method
def stack(
cls_or_self,
*others: tp.Union[tp.IndexLike, "BaseIDXAccessor"],
on_top: bool = False,
**kwargs,
) -> tp.Index:
"""See `vectorbtpro.base.indexes.stack_indexes`.
Set `on_top` to True to stack the second index on top of this one."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
if on_top:
objs = (*others, cls_or_self.obj)
else:
objs = (cls_or_self.obj, *others)
return indexes.stack_indexes(*objs, **kwargs)
@hybrid_method
def combine(
cls_or_self,
*others: tp.Union[tp.IndexLike, "BaseIDXAccessor"],
on_top: bool = False,
**kwargs,
) -> tp.Index:
"""See `vectorbtpro.base.indexes.combine_indexes`.
Set `on_top` to True to stack the second index on top of this one."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
if on_top:
objs = (*others, cls_or_self.obj)
else:
objs = (cls_or_self.obj, *others)
return indexes.combine_indexes(*objs, **kwargs)
@hybrid_method
def concat(cls_or_self, *others: tp.Union[tp.IndexLike, "BaseIDXAccessor"], **kwargs) -> tp.Index:
"""See `vectorbtpro.base.indexes.concat_indexes`."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
objs = (cls_or_self.obj, *others)
return indexes.concat_indexes(*objs, **kwargs)
def apply_to_index(
self: BaseIDXAccessorT,
apply_func: tp.Callable,
*args,
**kwargs,
) -> tp.Index:
return self.replace(obj=apply_func(self.obj, *args, **kwargs)).obj
def align_to(self, *args, **kwargs) -> tp.IndexSlice:
"""See `vectorbtpro.base.indexes.align_index_to`."""
return indexes.align_index_to(self.obj, *args, **kwargs)
@hybrid_method
def align(
cls_or_self,
*others: tp.Union[tp.IndexLike, "BaseIDXAccessor"],
**kwargs,
) -> tp.Tuple[tp.IndexSlice, ...]:
"""See `vectorbtpro.base.indexes.align_indexes`."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
objs = (cls_or_self.obj, *others)
return indexes.align_indexes(*objs, **kwargs)
def cross_with(self, *args, **kwargs) -> tp.Tuple[tp.IndexSlice, tp.IndexSlice]:
"""See `vectorbtpro.base.indexes.cross_index_with`."""
return indexes.cross_index_with(self.obj, *args, **kwargs)
@hybrid_method
def cross(
cls_or_self,
*others: tp.Union[tp.IndexLike, "BaseIDXAccessor"],
**kwargs,
) -> tp.Tuple[tp.IndexSlice, ...]:
"""See `vectorbtpro.base.indexes.cross_indexes`."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
objs = (cls_or_self.obj, *others)
return indexes.cross_indexes(*objs, **kwargs)
x = cross
def find_first_occurrence(self, *args, **kwargs) -> int:
"""See `vectorbtpro.base.indexes.find_first_occurrence`."""
return indexes.find_first_occurrence(self.obj, *args, **kwargs)
# ############# Frequency ############# #
@hybrid_method
def get_freq(
cls_or_self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
**kwargs,
) -> tp.Union[None, float, tp.PandasFrequency]:
"""Index frequency as `pd.Timedelta` or None if it cannot be converted."""
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
if not isinstance(cls_or_self, type):
if index is None:
index = cls_or_self.obj
if freq is None:
freq = cls_or_self._freq
else:
checks.assert_not_none(index, arg_name="index")
if freq is None:
freq = wrapping_cfg["freq"]
try:
return dt.infer_index_freq(index, freq=freq, **kwargs)
except Exception as e:
return None
@property
def freq(self) -> tp.Optional[tp.PandasFrequency]:
"""`BaseIDXAccessor.get_freq` with date offsets and integer frequencies not allowed."""
return self.get_freq(allow_offset=True, allow_numeric=False)
@property
def ns_freq(self) -> tp.Optional[int]:
"""Convert frequency to a 64-bit integer.
Timedelta will be converted to nanoseconds."""
freq = self.get_freq(allow_offset=False, allow_numeric=True)
if freq is not None:
freq = dt.to_ns(dt.to_timedelta64(freq))
return freq
@property
def any_freq(self) -> tp.Union[None, float, tp.PandasFrequency]:
"""Index frequency of any type."""
return self.get_freq()
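# NOTE (editor's sketch, hypothetical): inspecting the inferred frequency of a datetime index:
#
# >>> idx = pd.date_range("2020-01-01", periods=5, freq="D")
# >>> idx.vbt.freq     # pd.Timedelta('1 days'), inferred from the index
# >>> idx.vbt.ns_freq  # the same frequency in nanoseconds (86_400_000_000_000)
# >>> idx.vbt.periods  # plain number of elements, here 5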
@hybrid_method
def get_periods(cls_or_self, index: tp.Optional[tp.Index] = None) -> int:
"""Get the number of periods in the index, without taking into account its datetime-like properties."""
if not isinstance(cls_or_self, type):
if index is None:
index = cls_or_self.obj
else:
checks.assert_not_none(index, arg_name="index")
return len(index)
@property
def periods(self) -> int:
"""`BaseIDXAccessor.get_periods` with default arguments."""
return len(self.obj)
@hybrid_method
def get_dt_periods(
cls_or_self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.PandasFrequency] = None,
) -> float:
"""Get the number of periods in the index, taking into account its datetime-like properties."""
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
if not isinstance(cls_or_self, type):
if index is None:
index = cls_or_self.obj
else:
checks.assert_not_none(index, arg_name="index")
if isinstance(index, pd.DatetimeIndex):
freq = cls_or_self.get_freq(index=index, freq=freq, allow_offset=True, allow_numeric=False)
if freq is not None:
if not isinstance(freq, pd.Timedelta):
freq = dt.to_timedelta(freq, approximate=True)
return (index[-1] - index[0]) / freq + 1
if not wrapping_cfg["silence_warnings"]:
warn(
"Couldn't parse the frequency of index. Pass it as `freq` or "
"define it globally under `settings.wrapping`."
)
if checks.is_number(index[0]) and checks.is_number(index[-1]):
freq = cls_or_self.get_freq(index=index, freq=freq, allow_offset=False, allow_numeric=True)
if checks.is_number(freq):
return (index[-1] - index[0]) / freq + 1
return index[-1] - index[0] + 1
if not wrapping_cfg["silence_warnings"]:
warn("Index is neither datetime-like nor integer")
return cls_or_self.get_periods(index=index)
@property
def dt_periods(self) -> float:
"""`BaseIDXAccessor.get_dt_periods` with default arguments."""
return self.get_dt_periods()
def arr_to_timedelta(
self,
a: tp.MaybeArray,
to_pd: bool = False,
silence_warnings: tp.Optional[bool] = None,
) -> tp.Union[pd.Index, tp.MaybeArray]:
"""Convert array to duration using `BaseIDXAccessor.freq`."""
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
if silence_warnings is None:
silence_warnings = wrapping_cfg["silence_warnings"]
freq = self.freq
if freq is None:
if not silence_warnings:
warn(
"Couldn't parse the frequency of index. Pass it as `freq` or "
"define it globally under `settings.wrapping`."
)
return a
if not isinstance(freq, pd.Timedelta):
freq = dt.to_timedelta(freq, approximate=True)
if to_pd:
out = pd.to_timedelta(a * freq)
else:
out = a * freq
return out
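# NOTE (editor's sketch, hypothetical): turning bar counts into durations via the index frequency:
#
# >>> import numpy as np
# >>> idx = pd.date_range("2020-01-01", periods=5, freq="D")
# >>> idx.vbt.arr_to_timedelta(np.array([1, 2, 3]))              # 1, 2, and 3 days as timedeltas
# >>> idx.vbt.arr_to_timedelta(np.array([1, 2, 3]), to_pd=True)  # the same as a pandas object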
# ############# Grouping ############# #
def get_grouper(self, by: tp.AnyGroupByLike, groupby_kwargs: tp.KwargsLike = None, **kwargs) -> Grouper:
"""Get an index grouper of type `vectorbtpro.base.grouping.base.Grouper`.
Argument `by` can be a grouper itself, an instance of Pandas `GroupBy`,
an instance of Pandas `Resampler`, but also any supported input to any of them
such as a frequency or an array of indices.
Keyword arguments `groupby_kwargs` are passed to the Pandas methods `groupby` and `resample`,
while `**kwargs` are passed to initialize `vectorbtpro.base.grouping.base.Grouper`."""
if groupby_kwargs is None:
groupby_kwargs = {}
if isinstance(by, Grouper):
if len(kwargs) > 0:
return by.replace(**kwargs)
return by
if isinstance(by, (PandasGroupBy, PandasResampler)):
return Grouper.from_pd_group_by(by, **kwargs)
try:
return Grouper(index=self.obj, group_by=by, **kwargs)
except Exception as e:
pass
if isinstance(self.obj, pd.DatetimeIndex):
try:
return Grouper(index=self.obj, group_by=self.to_period(dt.to_freq(by)), **kwargs)
except Exception as e:
pass
try:
pd_group_by = pd.Series(index=self.obj, dtype=object).resample(dt.to_freq(by), **groupby_kwargs)
return Grouper.from_pd_group_by(pd_group_by, **kwargs)
except Exception as e:
pass
pd_group_by = pd.Series(index=self.obj, dtype=object).groupby(by, axis=0, **groupby_kwargs)
return Grouper.from_pd_group_by(pd_group_by, **kwargs)
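# NOTE (editor's sketch, hypothetical): building a grouper from a frequency-like rule or
# from an array of group labels:
#
# >>> idx = pd.date_range("2020-01-01", periods=14, freq="D")
# >>> idx.vbt.get_grouper("W")        # one group per calendar week
# >>> idx.vbt.get_grouper(idx.month)  # or group by any array-like of labels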
def get_resampler(
self,
rule: tp.AnyRuleLike,
freq: tp.Optional[tp.FrequencyLike] = None,
resample_kwargs: tp.KwargsLike = None,
return_pd_resampler: bool = False,
silence_warnings: tp.Optional[bool] = None,
) -> tp.Union[Resampler, tp.PandasResampler]:
"""Get an index resampler of type `vectorbtpro.base.resampling.base.Resampler`."""
if checks.is_frequency_like(rule):
try:
rule = dt.to_freq(rule)
is_td = True
except Exception as e:
is_td = False
if is_td:
resample_kwargs = merge_dicts(
dict(closed="left", label="left"),
resample_kwargs,
)
rule = pd.Series(index=self.obj, dtype=object).resample(rule, **resolve_dict(resample_kwargs))
if isinstance(rule, PandasResampler):
if return_pd_resampler:
return rule
if silence_warnings is None:
silence_warnings = True
rule = Resampler.from_pd_resampler(rule, source_freq=self.freq, silence_warnings=silence_warnings)
if return_pd_resampler:
raise TypeError("Cannot convert Resampler to Pandas Resampler")
if checks.is_dt_like(rule) or checks.is_iterable(rule):
rule = dt.prepare_dt_index(rule)
rule = Resampler(
source_index=self.obj,
target_index=rule,
source_freq=self.freq,
target_freq=freq,
silence_warnings=silence_warnings,
)
if isinstance(rule, Resampler):
if freq is not None:
rule = rule.replace(target_freq=freq)
return rule
raise ValueError(f"Cannot build Resampler from {rule}")
# ############# Points and ranges ############# #
def get_points(self, *args, **kwargs) -> tp.Array1d:
"""See `vectorbtpro.base.indexing.get_index_points`."""
return get_index_points(self.obj, *args, **kwargs)
def get_ranges(self, *args, **kwargs) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""See `vectorbtpro.base.indexing.get_index_ranges`."""
return get_index_ranges(self.obj, self.any_freq, *args, **kwargs)
# ############# Splitting ############# #
def split(self, *args, splitter_cls: tp.Optional[tp.Type[SplitterT]] = None, **kwargs) -> tp.Any:
"""Split using `vectorbtpro.generic.splitting.base.Splitter.split_and_take`.
!!! note
Splits Pandas object, not accessor!"""
from vectorbtpro.generic.splitting.base import Splitter
if splitter_cls is None:
splitter_cls = Splitter
return splitter_cls.split_and_take(self.obj, self.obj, *args, **kwargs)
def split_apply(
self,
apply_func: tp.Callable,
*args,
splitter_cls: tp.Optional[tp.Type[SplitterT]] = None,
**kwargs,
) -> tp.Any:
"""Split using `vectorbtpro.generic.splitting.base.Splitter.split_and_apply`.
!!! note
Splits Pandas object, not accessor!"""
from vectorbtpro.generic.splitting.base import Splitter, Takeable
if splitter_cls is None:
splitter_cls = Splitter
return splitter_cls.split_and_apply(self.obj, apply_func, Takeable(self.obj), *args, **kwargs)
# ############# Chunking ############# #
def chunk(
self: BaseIDXAccessorT,
min_size: tp.Optional[int] = None,
n_chunks: tp.Union[None, int, str] = None,
chunk_len: tp.Union[None, int, str] = None,
chunk_meta: tp.Optional[tp.Iterable[ChunkMeta]] = None,
select: bool = False,
return_chunk_meta: bool = False,
) -> tp.Iterator[tp.Union[tp.Index, tp.Tuple[ChunkMeta, tp.Index]]]:
"""Chunk this instance.
If `axis` is None, becomes 0 if the instance is one-dimensional and 1 otherwise.
For arguments related to chunking meta, see `vectorbtpro.utils.chunking.iter_chunk_meta`.
!!! note
Splits Pandas object, not accessor!"""
if chunk_meta is None:
chunk_meta = iter_chunk_meta(
size=len(self.obj),
min_size=min_size,
n_chunks=n_chunks,
chunk_len=chunk_len
)
for _chunk_meta in chunk_meta:
if select:
array_taker = ArraySelector()
else:
array_taker = ArraySlicer()
if return_chunk_meta:
yield _chunk_meta, array_taker.take(self.obj, _chunk_meta)
else:
yield array_taker.take(self.obj, _chunk_meta)
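# NOTE (editor's sketch, hypothetical): splitting an index into chunks:
#
# >>> idx = pd.date_range("2020-01-01", periods=6, freq="D")
# >>> [len(chunk) for chunk in idx.vbt.chunk(n_chunks=2)]  # two chunks of 3 elements each
# >>> list(idx.vbt.chunk(chunk_len=4))                     # chunks of length 4 and 2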
def chunk_apply(
self: BaseIDXAccessorT,
apply_func: tp.Union[str, tp.Callable],
*args,
chunk_kwargs: tp.KwargsLike = None,
execute_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MergeableResults:
"""Chunk this instance and apply a function to each chunk.
If `apply_func` is a string, becomes the method name.
For arguments related to chunking, see `BaseIDXAccessor.chunk`.
!!! note
Splits Pandas object, not accessor!"""
if isinstance(apply_func, str):
apply_func = getattr(type(self), apply_func)
if chunk_kwargs is None:
chunk_arg_names = set(get_func_arg_names(self.chunk))
chunk_kwargs = {}
for k in list(kwargs.keys()):
if k in chunk_arg_names:
chunk_kwargs[k] = kwargs.pop(k)
if execute_kwargs is None:
execute_kwargs = {}
chunks = self.chunk(return_chunk_meta=True, **chunk_kwargs)
tasks = []
keys = []
for _chunk_meta, chunk in chunks:
tasks.append(Task(apply_func, chunk, *args, **kwargs))
keys.append(get_chunk_meta_key(_chunk_meta))
keys = pd.Index(keys, name="chunk_indices")
return execute(tasks, size=len(tasks), keys=keys, **execute_kwargs)
BaseAccessorT = tp.TypeVar("BaseAccessorT", bound="BaseAccessor")
@attach_binary_magic_methods(lambda self, other, np_func: self.combine(other, combine_func=np_func))
@attach_unary_magic_methods(lambda self, np_func: self.apply(apply_func=np_func))
class BaseAccessor(Wrapping):
"""Accessor on top of Series and DataFrames.
Accessible via `pd.Series.vbt` and `pd.DataFrame.vbt`, and all child accessors.
Series is just a DataFrame with one column, hence to avoid defining methods exclusively for 1-dim data,
we will convert any Series to a DataFrame and perform matrix computation on it. Afterwards,
by using `BaseAccessor.wrapper`, we will convert the 2-dim output back to a Series.
`**kwargs` will be passed to `vectorbtpro.base.wrapping.ArrayWrapper`.
!!! note
When using magic methods, ensure that `.vbt` is called on the operand on the left
if the other operand is an array.
Accessors do not utilize caching.
Grouping is only supported by the methods that accept the `group_by` argument.
Usage:
* Build a symmetric matrix:
```pycon
>>> from vectorbtpro import *
>>> # vectorbtpro.base.accessors.BaseAccessor.make_symmetric
>>> pd.Series([1, 2, 3]).vbt.make_symmetric()
0 1 2
0 1.0 2.0 3.0
1 2.0 NaN NaN
2 3.0 NaN NaN
```
* Broadcast pandas objects:
```pycon
>>> sr = pd.Series([1])
>>> df = pd.DataFrame([1, 2, 3])
>>> vbt.base.reshaping.broadcast_to(sr, df)
0
0 1
1 1
2 1
>>> sr.vbt.broadcast_to(df)
0
0 1
1 1
2 1
```
* Many methods such as `BaseAccessor.broadcast` are both class and instance methods:
```pycon
>>> from vectorbtpro.base.accessors import BaseAccessor
>>> # Same as sr.vbt.broadcast(df)
>>> new_sr, new_df = BaseAccessor.broadcast(sr, df)
>>> new_sr
0
0 1
1 1
2 1
>>> new_df
0
0 1
1 2
2 3
```
* Instead of explicitly importing `BaseAccessor` or any other accessor, we can use `pd_acc` instead:
```pycon
>>> new_sr, new_df = vbt.pd_acc.broadcast(sr, df)
>>> new_sr
0
0 1
1 1
2 1
>>> new_df
0
0 1
1 2
2 3
```
* `BaseAccessor` implements arithmetic (such as `+`), comparison (such as `>`) and
logical operators (such as `&`) by forwarding the operation to `BaseAccessor.combine`:
```pycon
>>> sr.vbt + df
0
0 2
1 3
2 4
```
Many interesting use cases can be implemented this way.
* For example, let's compare an array with 3 different thresholds:
```pycon
>>> df.vbt > vbt.Param(np.arange(3), name='threshold')
threshold 0 1 2
a2 b2 c2 a2 b2 c2 a2 b2 c2
x2 True True True False True True False False True
y2 True True True True True True True True True
z2 True True True True True True True True True
```
"""
@classmethod
def resolve_row_stack_kwargs(
cls: tp.Type[BaseAccessorT],
*objs: tp.MaybeTuple[BaseAccessorT],
**kwargs,
) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `BaseAccessor` after stacking along rows."""
if "obj" not in kwargs:
kwargs["obj"] = kwargs["wrapper"].row_stack_arrs(
*[obj.obj for obj in objs],
group_by=False,
wrap=False,
)
return kwargs
@classmethod
def resolve_column_stack_kwargs(
cls: tp.Type[BaseAccessorT],
*objs: tp.MaybeTuple[BaseAccessorT],
reindex_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `BaseAccessor` after stacking along columns."""
if "obj" not in kwargs:
kwargs["obj"] = kwargs["wrapper"].column_stack_arrs(
*[obj.obj for obj in objs],
reindex_kwargs=reindex_kwargs,
group_by=False,
wrap=False,
)
return kwargs
@hybrid_method
def row_stack(
cls_or_self: tp.MaybeType[BaseAccessorT],
*objs: tp.MaybeTuple[BaseAccessorT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> BaseAccessorT:
"""Stack multiple `BaseAccessor` instances along rows.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.row_stack` to stack the wrappers."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, BaseAccessor):
raise TypeError("Each object to be merged must be an instance of BaseAccessor")
if wrapper_kwargs is None:
wrapper_kwargs = {}
if "wrapper" in kwargs and kwargs["wrapper"] is not None:
wrapper = kwargs["wrapper"]
if len(wrapper_kwargs) > 0:
wrapper = wrapper.replace(**wrapper_kwargs)
else:
wrapper = ArrayWrapper.row_stack(*[obj.wrapper for obj in objs], **wrapper_kwargs)
kwargs["wrapper"] = wrapper
kwargs = cls.resolve_row_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
if kwargs["wrapper"].ndim == 1:
return cls.sr_accessor_cls(**kwargs)
return cls.df_accessor_cls(**kwargs)
@hybrid_method
def column_stack(
cls_or_self: tp.MaybeType[BaseAccessorT],
*objs: tp.MaybeTuple[BaseAccessorT],
wrapper_kwargs: tp.KwargsLike = None,
reindex_kwargs: tp.KwargsLike = None,
**kwargs,
) -> BaseAccessorT:
"""Stack multiple `BaseAccessor` instances along columns.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, BaseAccessor):
raise TypeError("Each object to be merged must be an instance of BaseAccessor")
if wrapper_kwargs is None:
wrapper_kwargs = {}
if "wrapper" in kwargs and kwargs["wrapper"] is not None:
wrapper = kwargs["wrapper"]
if len(wrapper_kwargs) > 0:
wrapper = wrapper.replace(**wrapper_kwargs)
else:
wrapper = ArrayWrapper.column_stack(*[obj.wrapper for obj in objs], **wrapper_kwargs)
kwargs["wrapper"] = wrapper
kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
return cls.df_accessor_cls(**kwargs)
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
**kwargs,
) -> None:
if len(kwargs) > 0:
wrapper_kwargs, kwargs = ArrayWrapper.extract_init_kwargs(**kwargs)
else:
wrapper_kwargs, kwargs = {}, {}
if not isinstance(wrapper, ArrayWrapper):
if obj is not None:
raise ValueError("Must either provide wrapper and object, or only object")
wrapper, obj = ArrayWrapper.from_obj(wrapper, **wrapper_kwargs), wrapper
else:
if obj is None:
raise ValueError("Must either provide wrapper and object, or only object")
if len(wrapper_kwargs) > 0:
wrapper = wrapper.replace(**wrapper_kwargs)
Wrapping.__init__(self, wrapper, obj=obj, **kwargs)
self._obj = obj
def __call__(self: BaseAccessorT, **kwargs) -> BaseAccessorT:
"""Allows passing arguments to the initializer."""
return self.replace(**kwargs)
@hybrid_property
def sr_accessor_cls(cls_or_self) -> tp.Type["BaseSRAccessor"]:
"""Accessor class for `pd.Series`."""
return BaseSRAccessor
@hybrid_property
def df_accessor_cls(cls_or_self) -> tp.Type["BaseDFAccessor"]:
"""Accessor class for `pd.DataFrame`."""
return BaseDFAccessor
def indexing_func(self: BaseAccessorT, *args, wrapper_meta: tp.DictLike = None, **kwargs) -> BaseAccessorT:
"""Perform indexing on `BaseAccessor`."""
if wrapper_meta is None:
wrapper_meta = self.wrapper.indexing_func_meta(*args, **kwargs)
new_obj = ArrayWrapper.select_from_flex_array(
self._obj,
row_idxs=wrapper_meta["row_idxs"],
col_idxs=wrapper_meta["col_idxs"],
rows_changed=wrapper_meta["rows_changed"],
columns_changed=wrapper_meta["columns_changed"],
)
if checks.is_series(new_obj):
return self.replace(cls_=self.sr_accessor_cls, wrapper=wrapper_meta["new_wrapper"], obj=new_obj)
return self.replace(cls_=self.df_accessor_cls, wrapper=wrapper_meta["new_wrapper"], obj=new_obj)
def indexing_setter_func(self, pd_indexing_setter_func: tp.Callable, **kwargs) -> None:
"""Perform indexing setter on `BaseAccessor`."""
pd_indexing_setter_func(self._obj)
@property
def obj(self) -> tp.SeriesFrame:
"""Pandas object."""
if isinstance(self._obj, (pd.Series, pd.DataFrame)):
if self._obj.shape == self.wrapper.shape:
if self._obj.index is self.wrapper.index:
if isinstance(self._obj, pd.Series) and self._obj.name == self.wrapper.name:
return self._obj
if isinstance(self._obj, pd.DataFrame) and self._obj.columns is self.wrapper.columns:
return self._obj
return self.wrapper.wrap(self._obj, group_by=False)
def get(self, key: tp.Optional[tp.Hashable] = None, default: tp.Optional[tp.Any] = None) -> tp.SeriesFrame:
"""Get `BaseAccessor.obj`."""
if key is None:
return self.obj
return self.obj.get(key, default=default)
@property
def unwrapped(self) -> tp.SeriesFrame:
return self.obj
@hybrid_method
def should_wrap(cls_or_self) -> bool:
return False
@hybrid_property
def ndim(cls_or_self) -> tp.Optional[int]:
"""Number of dimensions in the object.
1 -> Series, 2 -> DataFrame."""
if isinstance(cls_or_self, type):
return None
return cls_or_self.obj.ndim
@hybrid_method
def is_series(cls_or_self) -> bool:
"""Whether the object is a Series."""
if isinstance(cls_or_self, type):
raise NotImplementedError
return isinstance(cls_or_self.obj, pd.Series)
@hybrid_method
def is_frame(cls_or_self) -> bool:
"""Whether the object is a DataFrame."""
if isinstance(cls_or_self, type):
raise NotImplementedError
return isinstance(cls_or_self.obj, pd.DataFrame)
@classmethod
def resolve_shape(cls, shape: tp.ShapeLike) -> tp.Shape:
"""Resolve shape."""
shape_2d = reshaping.to_2d_shape(shape)
try:
if cls.is_series() and shape_2d[1] > 1:
raise ValueError("Use DataFrame accessor")
except NotImplementedError:
pass
return shape_2d
# ############# Creation ############# #
@classmethod
def empty(cls, shape: tp.Shape, fill_value: tp.Scalar = np.nan, **kwargs) -> tp.SeriesFrame:
"""Generate an empty Series/DataFrame of shape `shape` and fill with `fill_value`."""
if not isinstance(shape, tuple) or (isinstance(shape, tuple) and len(shape) == 1):
return pd.Series(np.full(shape, fill_value), **kwargs)
return pd.DataFrame(np.full(shape, fill_value), **kwargs)
@classmethod
def empty_like(cls, other: tp.SeriesFrame, fill_value: tp.Scalar = np.nan, **kwargs) -> tp.SeriesFrame:
"""Generate an empty Series/DataFrame like `other` and fill with `fill_value`."""
if checks.is_series(other):
return cls.empty(other.shape, fill_value=fill_value, index=other.index, name=other.name, **kwargs)
return cls.empty(other.shape, fill_value=fill_value, index=other.index, columns=other.columns, **kwargs)
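# NOTE (editor's sketch, hypothetical): creating pre-filled objects through the accessor,
# assuming `vbt.pd_acc` points at `BaseAccessor` as described in the class docstring above:
#
# >>> import vectorbtpro as vbt
# >>> vbt.pd_acc.empty((3,), fill_value=0)                        # a Series of zeros
# >>> vbt.pd_acc.empty((3, 2), fill_value=0, columns=["a", "b"])  # a DataFrame of zeros
# >>> vbt.pd_acc.empty_like(pd.Series([1, 2, 3]))                 # NaNs with the same index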
# ############# Indexes ############# #
def apply_to_index(
self: BaseAccessorT,
*args,
wrap: bool = False,
**kwargs,
) -> tp.Union[BaseAccessorT, tp.SeriesFrame]:
"""See `vectorbtpro.base.wrapping.Wrapping.apply_to_index`.
!!! note
If `wrap` is False, returns Pandas object, not accessor!"""
result = Wrapping.apply_to_index(self, *args, **kwargs)
if wrap:
return result
return result.obj
# ############# Setting ############# #
def set(
self,
value_or_func: tp.Union[tp.MaybeArray, tp.Callable],
*args,
inplace: bool = False,
columns: tp.Optional[tp.MaybeSequence[tp.Hashable]] = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.Optional[tp.SeriesFrame]:
"""Set value at each index point using `vectorbtpro.base.indexing.get_index_points`.
If `value_or_func` is a function, selects all keyword arguments that were not passed
to the `get_index_points` method, substitutes any templates, and passes everything to the function.
As context uses `kwargs`, `template_context`, and various variables such as `i` (iteration index),
`index_point` (absolute position in the index), `wrapper`, and `obj`."""
if inplace:
obj = self.obj
else:
obj = self.obj.copy()
index_points = get_index_points(self.wrapper.index, **kwargs)
if callable(value_or_func):
func_kwargs = {k: v for k, v in kwargs.items() if k not in point_idxr_defaults}
template_context = merge_dicts(kwargs, template_context)
else:
func_kwargs = None
if callable(value_or_func):
for i in range(len(index_points)):
_template_context = merge_dicts(
dict(
i=i,
index_point=index_points[i],
index_points=index_points,
wrapper=self.wrapper,
obj=self.obj,
columns=columns,
args=args,
kwargs=kwargs,
),
template_context,
)
_func_args = substitute_templates(args, _template_context, eval_id="func_args")
_func_kwargs = substitute_templates(func_kwargs, _template_context, eval_id="func_kwargs")
v = value_or_func(*_func_args, **_func_kwargs)
if self.is_series() or columns is None:
obj.iloc[index_points[i]] = v
elif is_scalar(columns):
obj.iloc[index_points[i], obj.columns.get_indexer([columns])[0]] = v
else:
obj.iloc[index_points[i], obj.columns.get_indexer(columns)] = v
elif checks.is_sequence(value_or_func) and not is_scalar(value_or_func):
if self.is_series():
obj.iloc[index_points] = reshaping.to_1d_array(value_or_func)
elif columns is None:
obj.iloc[index_points] = reshaping.to_2d_array(value_or_func)
elif is_scalar(columns):
obj.iloc[index_points, obj.columns.get_indexer([columns])[0]] = reshaping.to_1d_array(value_or_func)
else:
obj.iloc[index_points, obj.columns.get_indexer(columns)] = reshaping.to_2d_array(value_or_func)
else:
if self.is_series() or columns is None:
obj.iloc[index_points] = value_or_func
elif is_scalar(columns):
obj.iloc[index_points, obj.columns.get_indexer([columns])[0]] = value_or_func
else:
obj.iloc[index_points, obj.columns.get_indexer(columns)] = value_or_func
if inplace:
return None
return obj
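# NOTE (editor's sketch, hypothetical): writing values at selected index points; the keyword
# `every` below is assumed to be forwarded to `get_index_points` (see its documentation):
#
# >>> sr = pd.Series(0, index=pd.date_range("2020-01-01", periods=6, freq="D"))
# >>> sr.vbt.set(1, every=2)                # returns a copy with every second row set to 1
# >>> sr.vbt.set(1, every=2, inplace=True)  # or modifies `sr` in place and returns None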
def set_between(
self,
value_or_func: tp.Union[tp.MaybeArray, tp.Callable],
*args,
inplace: bool = False,
columns: tp.Optional[tp.MaybeSequence[tp.Hashable]] = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.Optional[tp.SeriesFrame]:
"""Set value at each index range using `vectorbtpro.base.indexing.get_index_ranges`.
If `value_or_func` is a function, selects all keyword arguments that were not passed
to the `get_index_ranges` method, substitutes any templates, and passes everything to the function.
As context uses `kwargs`, `template_context`, and various variables such as `i` (iteration index),
`index_slice` (absolute slice of the index), `wrapper`, and `obj`."""
if inplace:
obj = self.obj
else:
obj = self.obj.copy()
index_ranges = get_index_ranges(self.wrapper.index, **kwargs)
if callable(value_or_func):
func_kwargs = {k: v for k, v in kwargs.items() if k not in range_idxr_defaults}
template_context = merge_dicts(kwargs, template_context)
else:
func_kwargs = None
for i in range(len(index_ranges[0])):
if callable(value_or_func):
_template_context = merge_dicts(
dict(
i=i,
index_slice=slice(index_ranges[0][i], index_ranges[1][i]),
range_starts=index_ranges[0],
range_ends=index_ranges[1],
wrapper=self.wrapper,
obj=self.obj,
columns=columns,
args=args,
kwargs=kwargs,
),
template_context,
)
_func_args = substitute_templates(args, _template_context, eval_id="func_args")
_func_kwargs = substitute_templates(func_kwargs, _template_context, eval_id="func_kwargs")
v = value_or_func(*_func_args, **_func_kwargs)
elif checks.is_sequence(value_or_func) and not isinstance(value_or_func, str):
v = value_or_func[i]
else:
v = value_or_func
if self.is_series() or columns is None:
obj.iloc[index_ranges[0][i] : index_ranges[1][i]] = v
elif is_scalar(columns):
obj.iloc[index_ranges[0][i] : index_ranges[1][i], obj.columns.get_indexer([columns])[0]] = v
else:
obj.iloc[index_ranges[0][i] : index_ranges[1][i], obj.columns.get_indexer(columns)] = v
if inplace:
return None
return obj
# ############# Reshaping ############# #
def to_1d_array(self) -> tp.Array1d:
"""See `vectorbtpro.base.reshaping.to_1d` with `raw` set to True."""
return reshaping.to_1d_array(self.obj)
def to_2d_array(self) -> tp.Array2d:
"""See `vectorbtpro.base.reshaping.to_2d` with `raw` set to True."""
return reshaping.to_2d_array(self.obj)
def tile(
self,
n: int,
keys: tp.Optional[tp.IndexLike] = None,
axis: int = 1,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.base.reshaping.tile`.
Set `axis` to 1 for columns and 0 for index.
Use `keys` as the outermost level."""
tiled = reshaping.tile(self.obj, n, axis=axis)
if keys is not None:
if axis == 1:
new_columns = indexes.combine_indexes([keys, self.wrapper.columns])
return ArrayWrapper.from_obj(tiled).wrap(
tiled.values,
**merge_dicts(dict(columns=new_columns), wrap_kwargs),
)
else:
new_index = indexes.combine_indexes([keys, self.wrapper.index])
return ArrayWrapper.from_obj(tiled).wrap(
tiled.values,
**merge_dicts(dict(index=new_index), wrap_kwargs),
)
return tiled
def repeat(
self,
n: int,
keys: tp.Optional[tp.IndexLike] = None,
axis: int = 1,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.base.reshaping.repeat`.
Set `axis` to 1 for columns and 0 for index.
Use `keys` as the outermost level."""
repeated = reshaping.repeat(self.obj, n, axis=axis)
if keys is not None:
if axis == 1:
new_columns = indexes.combine_indexes([self.wrapper.columns, keys])
return ArrayWrapper.from_obj(repeated).wrap(
repeated.values,
**merge_dicts(dict(columns=new_columns), wrap_kwargs),
)
else:
new_index = indexes.combine_indexes([self.wrapper.index, keys])
return ArrayWrapper.from_obj(repeated).wrap(
repeated.values,
**merge_dicts(dict(index=new_index), wrap_kwargs),
)
return repeated
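# NOTE (editor's sketch, hypothetical): tiling and repeating columns with an extra column level:
#
# >>> df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
# >>> df.vbt.tile(2, keys=pd.Index(["x", "y"], name="key"))    # columns (x, a), (x, b), (y, a), (y, b)
# >>> df.vbt.repeat(2, keys=pd.Index(["x", "y"], name="key"))  # columns (a, x), (a, y), (b, x), (b, y)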
def align_to(self, other: tp.SeriesFrame, wrap_kwargs: tp.KwargsLike = None, **kwargs) -> tp.SeriesFrame:
"""Align to `other` on their axes using `vectorbtpro.base.indexes.align_index_to`.
Usage:
```pycon
>>> df1 = pd.DataFrame(
... [[1, 2], [3, 4]],
... index=['x', 'y'],
... columns=['a', 'b']
... )
>>> df1
a b
x 1 2
y 3 4
>>> df2 = pd.DataFrame(
... [[5, 6, 7, 8], [9, 10, 11, 12]],
... index=['x', 'y'],
... columns=pd.MultiIndex.from_arrays([[1, 1, 2, 2], ['a', 'b', 'a', 'b']])
... )
>>> df2
1 2
a b a b
x 5 6 7 8
y 9 10 11 12
>>> df1.vbt.align_to(df2)
1 2
a b a b
x 1 2 1 2
y 3 4 3 4
```
"""
checks.assert_instance_of(other, (pd.Series, pd.DataFrame))
obj = reshaping.to_2d(self.obj)
other = reshaping.to_2d(other)
aligned_index = indexes.align_index_to(obj.index, other.index, **kwargs)
aligned_columns = indexes.align_index_to(obj.columns, other.columns, **kwargs)
obj = obj.iloc[aligned_index, aligned_columns]
return self.wrapper.wrap(
obj.values,
group_by=False,
**merge_dicts(dict(index=other.index, columns=other.columns), wrap_kwargs),
)
@hybrid_method
def align(
cls_or_self,
*others: tp.Union[tp.SeriesFrame, "BaseAccessor"],
**kwargs,
) -> tp.Tuple[tp.SeriesFrame, ...]:
"""Align objects using `vectorbtpro.base.indexes.align_indexes`."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
objs = (cls_or_self.obj, *others)
objs_2d = list(map(reshaping.to_2d, objs))
index_slices, new_index = indexes.align_indexes(
*map(lambda x: x.index, objs_2d),
return_new_index=True,
**kwargs,
)
column_slices, new_columns = indexes.align_indexes(
*map(lambda x: x.columns, objs_2d),
return_new_index=True,
**kwargs,
)
new_objs = []
for i in range(len(objs_2d)):
new_obj = objs_2d[i].iloc[index_slices[i], column_slices[i]].copy(deep=False)
if objs[i].ndim == 1 and new_obj.shape[1] == 1:
new_obj = new_obj.iloc[:, 0].rename(objs[i].name)
new_obj.index = new_index
new_obj.columns = new_columns
new_objs.append(new_obj)
return tuple(new_objs)
def cross_with(self, other: tp.SeriesFrame, wrap_kwargs: tp.KwargsLike = None) -> tp.SeriesFrame:
"""Align to `other` on their axes using `vectorbtpro.base.indexes.cross_index_with`.
Usage:
```pycon
>>> df1 = pd.DataFrame(
... [[1, 2, 3, 4], [5, 6, 7, 8]],
... index=['x', 'y'],
... columns=pd.MultiIndex.from_arrays([[1, 1, 2, 2], ['a', 'b', 'a', 'b']])
... )
>>> df1
1 2
a b a b
x 1 2 3 4
y 5 6 7 8
>>> df2 = pd.DataFrame(
... [[9, 10, 11, 12], [13, 14, 15, 16]],
... index=['x', 'y'],
... columns=pd.MultiIndex.from_arrays([[3, 3, 4, 4], ['a', 'b', 'a', 'b']])
... )
>>> df2
3 4
a b a b
x 9 10 11 12
y 13 14 15 16
>>> df1.vbt.cross_with(df2)
1 2
3 4 3 4
a b a b a b a b
x 1 2 1 2 3 4 3 4
y 5 6 5 6 7 8 7 8
```
"""
checks.assert_instance_of(other, (pd.Series, pd.DataFrame))
obj = reshaping.to_2d(self.obj)
other = reshaping.to_2d(other)
index_slices, new_index = indexes.cross_index_with(
obj.index,
other.index,
return_new_index=True,
)
column_slices, new_columns = indexes.cross_index_with(
obj.columns,
other.columns,
return_new_index=True,
)
obj = obj.iloc[index_slices[0], column_slices[0]]
return self.wrapper.wrap(
obj.values,
group_by=False,
**merge_dicts(dict(index=new_index, columns=new_columns), wrap_kwargs),
)
@hybrid_method
def cross(cls_or_self, *others: tp.Union[tp.SeriesFrame, "BaseAccessor"]) -> tp.Tuple[tp.SeriesFrame, ...]:
"""Align objects using `vectorbtpro.base.indexes.cross_indexes`."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
objs = (cls_or_self.obj, *others)
objs_2d = list(map(reshaping.to_2d, objs))
index_slices, new_index = indexes.cross_indexes(
*map(lambda x: x.index, objs_2d),
return_new_index=True,
)
column_slices, new_columns = indexes.cross_indexes(
*map(lambda x: x.columns, objs_2d),
return_new_index=True,
)
new_objs = []
for i in range(len(objs_2d)):
new_obj = objs_2d[i].iloc[index_slices[i], column_slices[i]].copy(deep=False)
if objs[i].ndim == 1 and new_obj.shape[1] == 1:
new_obj = new_obj.iloc[:, 0].rename(objs[i].name)
new_obj.index = new_index
new_obj.columns = new_columns
new_objs.append(new_obj)
return tuple(new_objs)
x = cross
@hybrid_method
def broadcast(cls_or_self, *others: tp.Union[tp.ArrayLike, "BaseAccessor"], **kwargs) -> tp.Any:
"""See `vectorbtpro.base.reshaping.broadcast`."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
objs = (cls_or_self.obj, *others)
return reshaping.broadcast(*objs, **kwargs)
def broadcast_to(self, other: tp.Union[tp.ArrayLike, "BaseAccessor"], **kwargs) -> tp.Any:
"""See `vectorbtpro.base.reshaping.broadcast_to`."""
if isinstance(other, BaseAccessor):
other = other.obj
return reshaping.broadcast_to(self.obj, other, **kwargs)
@hybrid_method
def broadcast_combs(cls_or_self, *others: tp.Union[tp.ArrayLike, "BaseAccessor"], **kwargs) -> tp.Any:
"""See `vectorbtpro.base.reshaping.broadcast_combs`."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
objs = (cls_or_self.obj, *others)
return reshaping.broadcast_combs(*objs, **kwargs)
def make_symmetric(self, *args, **kwargs) -> tp.Frame:
"""See `vectorbtpro.base.reshaping.make_symmetric`."""
return reshaping.make_symmetric(self.obj, *args, **kwargs)
def unstack_to_array(self, *args, **kwargs) -> tp.Array:
"""See `vectorbtpro.base.reshaping.unstack_to_array`."""
return reshaping.unstack_to_array(self.obj, *args, **kwargs)
def unstack_to_df(self, *args, **kwargs) -> tp.Frame:
"""See `vectorbtpro.base.reshaping.unstack_to_df`."""
return reshaping.unstack_to_df(self.obj, *args, **kwargs)
def to_dict(self, *args, **kwargs) -> tp.Mapping:
"""See `vectorbtpro.base.reshaping.to_dict`."""
return reshaping.to_dict(self.obj, *args, **kwargs)
# ############# Conversion ############# #
def to_data(
self,
data_cls: tp.Optional[tp.Type[DataT]] = None,
columns_are_symbols: bool = True,
**kwargs,
) -> DataT:
"""Convert to a `vectorbtpro.data.base.Data` instance."""
if data_cls is None:
from vectorbtpro.data.base import Data
data_cls = Data
return data_cls.from_data(self.obj, columns_are_symbols=columns_are_symbols, **kwargs)
# ############# Combining ############# #
def apply(
self,
apply_func: tp.Callable,
*args,
keep_pd: bool = False,
to_2d: bool = False,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Apply a function `apply_func`.
Set `keep_pd` to True to keep inputs as pandas objects, otherwise convert to NumPy arrays.
Set `to_2d` to True to reshape inputs to 2-dim arrays, otherwise keep as-is.
`*args` and `**kwargs` are passed to `apply_func`.
!!! note
The resulting array must have the same shape as the original array.
Usage:
* Using instance method:
```pycon
>>> sr = pd.Series([1, 2], index=['x', 'y'])
>>> sr.vbt.apply(lambda x: x ** 2)
x 1
y 4
dtype: int64
```
* Using class method, templates, and broadcasting:
```pycon
>>> sr.vbt.apply(
... lambda x, y: x + y,
... vbt.Rep('y'),
... broadcast_named_args=dict(
... y=pd.DataFrame([[3, 4]], columns=['a', 'b'])
... )
... )
a b
x 4 5
y 5 6
```
"""
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
broadcast_named_args = {"obj": self.obj, **broadcast_named_args}
if len(broadcast_named_args) > 1:
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
else:
wrapper = self.wrapper
if to_2d:
broadcast_named_args = {k: reshaping.to_2d(v, raw=not keep_pd) for k, v in broadcast_named_args.items()}
elif not keep_pd:
broadcast_named_args = {k: np.asarray(v) for k, v in broadcast_named_args.items()}
template_context = merge_dicts(broadcast_named_args, template_context)
args = substitute_templates(args, template_context, eval_id="args")
kwargs = substitute_templates(kwargs, template_context, eval_id="kwargs")
out = apply_func(broadcast_named_args["obj"], *args, **kwargs)
return wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
@hybrid_method
def concat(
cls_or_self,
*others: tp.ArrayLike,
broadcast_kwargs: tp.KwargsLike = None,
keys: tp.Optional[tp.IndexLike] = None,
) -> tp.Frame:
"""Concatenate with `others` along columns.
Usage:
```pycon
>>> sr = pd.Series([1, 2], index=['x', 'y'])
>>> df = pd.DataFrame([[3, 4], [5, 6]], index=['x', 'y'], columns=['a', 'b'])
>>> sr.vbt.concat(df, keys=['c', 'd'])
c d
a b a b
x 1 1 3 4
y 2 2 5 6
```
"""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
objs = (cls_or_self.obj,) + others
if broadcast_kwargs is None:
broadcast_kwargs = {}
broadcasted = reshaping.broadcast(*objs, **broadcast_kwargs)
broadcasted = tuple(map(reshaping.to_2d, broadcasted))
out = pd.concat(broadcasted, axis=1, keys=keys)
if not isinstance(out.columns, pd.MultiIndex) and np.all(out.columns == 0):
out.columns = pd.RangeIndex(start=0, stop=len(out.columns), step=1)
return out
def apply_and_concat(
self,
ntimes: int,
apply_func: tp.Callable,
*args,
keep_pd: bool = False,
to_2d: bool = False,
keys: tp.Optional[tp.IndexLike] = None,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeTuple[tp.Frame]:
"""Apply `apply_func` `ntimes` times and concatenate the results along columns.
See `vectorbtpro.base.combining.apply_and_concat`.
`ntimes` is the number of times to call `apply_func`, while `n_outputs` is the number of outputs to expect.
`*args` and `**kwargs` are passed to `vectorbtpro.base.combining.apply_and_concat`.
!!! note
The resulting arrays to be concatenated must have the same shape as the broadcast input arrays.
Usage:
* Using instance method:
```pycon
>>> df = pd.DataFrame([[3, 4], [5, 6]], index=['x', 'y'], columns=['a', 'b'])
>>> df.vbt.apply_and_concat(
... 3,
... lambda i, a, b: a * b[i],
... [1, 2, 3],
... keys=['c', 'd', 'e']
... )
c d e
a b a b a b
x 3 4 6 8 9 12
y 5 6 10 12 15 18
```
* Using class method, templates, and broadcasting:
```pycon
>>> sr = pd.Series([1, 2, 3], index=['x', 'y', 'z'])
>>> sr.vbt.apply_and_concat(
... 3,
... lambda i, a, b: a * b + i,
... vbt.Rep('df'),
... broadcast_named_args=dict(
... df=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
... )
... )
apply_idx 0 1 2
a b c a b c a b c
x 1 2 3 2 3 4 3 4 5
y 2 4 6 3 5 7 4 6 8
z 3 6 9 4 7 10 5 8 11
```
* To change the execution engine or specify other engine-related arguments, use `execute_kwargs`:
```pycon
>>> import time
>>> def apply_func(i, a):
... time.sleep(1)
... return a
>>> sr = pd.Series([1, 2, 3])
>>> %timeit sr.vbt.apply_and_concat(3, apply_func)
3.02 s ± 3.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit sr.vbt.apply_and_concat(3, apply_func, execute_kwargs=dict(engine='dask'))
1.02 s ± 927 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
"""
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
broadcast_named_args = {"obj": self.obj, **broadcast_named_args}
if len(broadcast_named_args) > 1:
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
else:
wrapper = self.wrapper
if to_2d:
broadcast_named_args = {k: reshaping.to_2d(v, raw=not keep_pd) for k, v in broadcast_named_args.items()}
elif not keep_pd:
broadcast_named_args = {k: np.asarray(v) for k, v in broadcast_named_args.items()}
template_context = merge_dicts(broadcast_named_args, dict(ntimes=ntimes), template_context)
args = substitute_templates(args, template_context, eval_id="args")
kwargs = substitute_templates(kwargs, template_context, eval_id="kwargs")
out = combining.apply_and_concat(ntimes, apply_func, broadcast_named_args["obj"], *args, **kwargs)
if keys is not None:
new_columns = indexes.combine_indexes([keys, wrapper.columns])
else:
top_columns = pd.Index(np.arange(ntimes), name="apply_idx")
new_columns = indexes.combine_indexes([top_columns, wrapper.columns])
if out is None:
return None
wrap_kwargs = merge_dicts(dict(columns=new_columns), wrap_kwargs)
if isinstance(out, list):
return tuple(map(lambda x: wrapper.wrap(x, group_by=False, **wrap_kwargs), out))
return wrapper.wrap(out, group_by=False, **wrap_kwargs)
@hybrid_method
def combine(
cls_or_self,
obj: tp.MaybeTupleList[tp.Union[tp.ArrayLike, "BaseAccessor"]],
combine_func: tp.Callable,
*args,
allow_multiple: bool = True,
keep_pd: bool = False,
to_2d: bool = False,
concat: tp.Optional[bool] = None,
keys: tp.Optional[tp.IndexLike] = None,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Combine with `other` using `combine_func`.
Args:
obj (array_like): Object(s) to combine this array with.
combine_func (callable): Function to combine two arrays.
Can be Numba-compiled.
*args: Variable arguments passed to `combine_func`.
allow_multiple (bool): Whether a tuple/list/Index will be considered as multiple objects in `obj`.
Takes effect only when using the instance method.
keep_pd (bool): Whether to keep inputs as pandas objects, otherwise convert to NumPy arrays.
to_2d (bool): Whether to reshape inputs to 2-dim arrays, otherwise keep as-is.
concat (bool): Whether to concatenate the results along the column axis.
Otherwise, pairwise combine into a Series/DataFrame of the same shape.
If True, see `vectorbtpro.base.combining.combine_and_concat`.
If False, see `vectorbtpro.base.combining.combine_multiple`.
If None, becomes True if there are multiple objects to combine.
Can only concatenate using the instance method.
keys (index_like): Outermost column level.
broadcast_named_args (dict): Dictionary with arguments to broadcast against each other.
broadcast_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.reshaping.broadcast`.
template_context (dict): Context used to substitute templates in `args` and `kwargs`.
wrap_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.wrapping.ArrayWrapper.wrap`.
**kwargs: Keyword arguments passed to `combine_func`.
!!! note
If `combine_func` is Numba-compiled, will broadcast using `WRITEABLE` and `C_CONTIGUOUS`
flags, which can lead to an expensive computation overhead if passed objects are large and
have different shape/memory order. You also must ensure that all objects have the same data type.
Also remember to bring each in `*args` to a Numba-compatible format.
Usage:
* Using instance method:
```pycon
>>> sr = pd.Series([1, 2], index=['x', 'y'])
>>> df = pd.DataFrame([[3, 4], [5, 6]], index=['x', 'y'], columns=['a', 'b'])
>>> # using instance method
>>> sr.vbt.combine(df, np.add)
a b
x 4 5
y 7 8
>>> sr.vbt.combine([df, df * 2], np.add, concat=False)
a b
x 10 13
y 17 20
>>> sr.vbt.combine([df, df * 2], np.add)
combine_idx 0 1
a b a b
x 4 5 7 9
y 7 8 12 14
>>> sr.vbt.combine([df, df * 2], np.add, keys=['c', 'd'])
c d
a b a b
x 4 5 7 9
y 7 8 12 14
>>> sr.vbt.combine(vbt.Param([1, 2], name='param'), np.add)
param 1 2
x 2 3
y 3 4
```
* Using class method, templates, and broadcasting:
```pycon
>>> sr = pd.Series([1, 2, 3], index=['x', 'y', 'z'])
>>> sr.vbt.combine(
... [1, 2, 3],
... lambda x, y, z: x + y + z,
... vbt.Rep('df'),
... broadcast_named_args=dict(
... df=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
... )
... )
combine_idx 0 1 2
a b c a b c a b c
x 3 4 5 4 5 6 5 6 7
y 4 5 6 5 6 7 6 7 8
z 5 6 7 6 7 8 7 8 9
```
* To change the execution engine or specify other engine-related arguments, use `execute_kwargs`:
```pycon
>>> import time
>>> def combine_func(a, b):
... time.sleep(1)
... return a + b
>>> sr = pd.Series([1, 2, 3])
>>> %timeit sr.vbt.combine([1, 1, 1], combine_func)
3.01 s ± 2.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit sr.vbt.combine([1, 1, 1], combine_func, execute_kwargs=dict(engine='dask'))
1.02 s ± 2.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
"""
from vectorbtpro.indicators.factory import IndicatorBase
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
if isinstance(cls_or_self, type):
objs = obj
else:
if allow_multiple and isinstance(obj, (tuple, list)):
objs = obj
if concat is None:
concat = True
else:
objs = (obj,)
new_objs = []
for obj in objs:
if isinstance(obj, BaseAccessor):
obj = obj.obj
elif isinstance(obj, IndicatorBase):
obj = obj.main_output
new_objs.append(obj)
objs = tuple(new_objs)
if not isinstance(cls_or_self, type):
objs = (cls_or_self.obj,) + objs
if checks.is_numba_func(combine_func):
# Numba requires writeable arrays and in the same order
broadcast_kwargs = merge_dicts(dict(require_kwargs=dict(requirements=["W", "C"])), broadcast_kwargs)
# Broadcast and substitute templates
broadcast_named_args = {**{"obj_" + str(i): obj for i, obj in enumerate(objs)}, **broadcast_named_args}
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
if to_2d:
broadcast_named_args = {k: reshaping.to_2d(v, raw=not keep_pd) for k, v in broadcast_named_args.items()}
elif not keep_pd:
broadcast_named_args = {k: np.asarray(v) for k, v in broadcast_named_args.items()}
template_context = merge_dicts(broadcast_named_args, template_context)
args = substitute_templates(args, template_context, eval_id="args")
kwargs = substitute_templates(kwargs, template_context, eval_id="kwargs")
inputs = [broadcast_named_args["obj_" + str(i)] for i in range(len(objs))]
if concat is None:
concat = len(inputs) > 2
if concat:
# Concat the results horizontally
if isinstance(cls_or_self, type):
raise TypeError("Use instance method to concatenate")
out = combining.combine_and_concat(inputs[0], inputs[1:], combine_func, *args, **kwargs)
if keys is not None:
new_columns = indexes.combine_indexes([keys, wrapper.columns])
else:
top_columns = pd.Index(np.arange(len(objs) - 1), name="combine_idx")
new_columns = indexes.combine_indexes([top_columns, wrapper.columns])
return wrapper.wrap(out, **merge_dicts(dict(columns=new_columns, force_2d=True), wrap_kwargs))
else:
# Combine arguments pairwise into one object
out = combining.combine_multiple(inputs, combine_func, *args, **kwargs)
return wrapper.wrap(out, **resolve_dict(wrap_kwargs))
@classmethod
def eval(
cls,
expr: str,
frames_back: int = 1,
use_numexpr: bool = False,
numexpr_kwargs: tp.KwargsLike = None,
local_dict: tp.Optional[tp.Mapping] = None,
global_dict: tp.Optional[tp.Mapping] = None,
broadcast_kwargs: tp.KwargsLike = None,
wrap_kwargs: tp.KwargsLike = None,
):
"""Evaluate a simple array expression element-wise using NumExpr or NumPy.
If NumExpr is enabled, only one-line statements are supported. Otherwise, uses
`vectorbtpro.utils.eval_.evaluate`.
!!! note
All required variables will broadcast against each other prior to the evaluation.
Usage:
```pycon
>>> sr = pd.Series([1, 2, 3], index=['x', 'y', 'z'])
>>> df = pd.DataFrame([[4, 5, 6]], index=['x', 'y', 'z'], columns=['a', 'b', 'c'])
>>> vbt.pd_acc.eval('sr + df')
a b c
x 5 6 7
y 6 7 8
z 7 8 9
```
"""
if numexpr_kwargs is None:
numexpr_kwargs = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if wrap_kwargs is None:
wrap_kwargs = {}
expr = inspect.cleandoc(expr)
parsed = ast.parse(expr)
body_nodes = list(parsed.body)
load_vars = set()
store_vars = set()
for body_node in body_nodes:
for child_node in ast.walk(body_node):
if type(child_node) is ast.Name:
if isinstance(child_node.ctx, ast.Load):
if child_node.id not in store_vars:
load_vars.add(child_node.id)
if isinstance(child_node.ctx, ast.Store):
store_vars.add(child_node.id)
load_vars = list(load_vars)
objs = get_context_vars(load_vars, frames_back=frames_back, local_dict=local_dict, global_dict=global_dict)
objs = dict(zip(load_vars, objs))
objs, wrapper = reshaping.broadcast(objs, return_wrapper=True, **broadcast_kwargs)
objs = {k: np.asarray(v) for k, v in objs.items()}
if use_numexpr:
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("numexpr")
import numexpr
out = numexpr.evaluate(expr, local_dict=objs, **numexpr_kwargs)
else:
out = evaluate(expr, context=objs)
return wrapper.wrap(out, **wrap_kwargs)
class BaseSRAccessor(BaseAccessor):
"""Accessor on top of Series.
Accessible via `pd.Series.vbt` and all child accessors."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
_full_init: bool = True,
**kwargs,
) -> None:
if _full_init:
if isinstance(wrapper, ArrayWrapper):
if wrapper.ndim == 2:
if wrapper.shape[1] == 1:
wrapper = wrapper.replace(ndim=1)
else:
raise TypeError("Series accessors work only one one-dimensional data")
BaseAccessor.__init__(self, wrapper, obj=obj, **kwargs)
@hybrid_property
def ndim(cls_or_self) -> int:
return 1
@hybrid_method
def is_series(cls_or_self) -> bool:
return True
@hybrid_method
def is_frame(cls_or_self) -> bool:
return False
class BaseDFAccessor(BaseAccessor):
"""Accessor on top of DataFrames.
Accessible via `pd.DataFrame.vbt` and all child accessors."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
_full_init: bool = True,
**kwargs,
) -> None:
if _full_init:
if isinstance(wrapper, ArrayWrapper):
if wrapper.ndim == 1:
wrapper = wrapper.replace(ndim=2)
BaseAccessor.__init__(self, wrapper, obj=obj, **kwargs)
@hybrid_property
def ndim(cls_or_self) -> int:
return 2
@hybrid_method
def is_series(cls_or_self) -> bool:
return False
@hybrid_method
def is_frame(cls_or_self) -> bool:
return True
</file>
<file path="base/chunking.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Extensions for chunking of base operations."""
import uuid
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro.utils import checks
from vectorbtpro.utils.attr_ import DefineMixin, define
from vectorbtpro.utils.chunking import (
ArgGetter,
ArgSizer,
ArraySizer,
ChunkMeta,
ChunkMapper,
ChunkSlicer,
ShapeSlicer,
ArraySelector,
ArraySlicer,
Chunked,
)
from vectorbtpro.utils.parsing import Regex
__all__ = [
"GroupLensSizer",
"GroupLensSlicer",
"ChunkedGroupLens",
"GroupLensMapper",
"GroupMapSlicer",
"ChunkedGroupMap",
"GroupIdxsMapper",
"FlexArraySizer",
"FlexArraySelector",
"FlexArraySlicer",
"ChunkedFlexArray",
"shape_gl_slicer",
"flex_1d_array_gl_slicer",
"flex_array_gl_slicer",
"array_gl_slicer",
]
class GroupLensSizer(ArgSizer):
"""Class for getting the size from group lengths.
Argument can be either a group map tuple or a group lengths array."""
@classmethod
def get_obj_size(cls, obj: tp.Union[tp.GroupLens, tp.GroupMap], single_type: tp.Optional[type] = None) -> int:
"""Get size of an object."""
if single_type is not None:
if checks.is_instance_of(obj, single_type):
return 1
if isinstance(obj, tuple):
return len(obj[1])
return len(obj)
def get_size(self, ann_args: tp.AnnArgs, **kwargs) -> int:
return self.get_obj_size(self.get_arg(ann_args), single_type=self.single_type)
class GroupLensSlicer(ChunkSlicer):
"""Class for slicing multiple elements from group lengths based on the chunk range."""
def get_size(self, obj: tp.Union[tp.GroupLens, tp.GroupMap], **kwargs) -> int:
return GroupLensSizer.get_obj_size(obj, single_type=self.single_type)
def take(self, obj: tp.Union[tp.GroupLens, tp.GroupMap], chunk_meta: ChunkMeta, **kwargs) -> tp.GroupMap:
if isinstance(obj, tuple):
return obj[1][chunk_meta.start : chunk_meta.end]
return obj[chunk_meta.start : chunk_meta.end]
class ChunkedGroupLens(Chunked):
"""Class representing chunkable group lengths."""
def resolve_take_spec(self) -> tp.TakeSpec:
if self.take_spec_missing:
if self.select:
raise ValueError("Selection is not supported")
return GroupLensSlicer
return self.take_spec
def get_group_lens_slice(group_lens: tp.GroupLens, chunk_meta: ChunkMeta) -> slice:
"""Get slice of each chunk in group lengths."""
group_lens_cumsum = np.cumsum(group_lens[: chunk_meta.end])
start = group_lens_cumsum[chunk_meta.start] - group_lens[chunk_meta.start]
end = group_lens_cumsum[-1]
return slice(start, end)
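# NOTE (editor's sketch, hypothetical numbers): how `get_group_lens_slice` maps a chunk of
# groups onto a slice of columns. With group lengths [2, 3, 4] and a chunk covering groups
# 1..2 (start=1, end=3), the cumulative sums yield columns 2 through 8:
#
# >>> import numpy as np
# >>> cm = ChunkMeta(uuid="", idx=0, start=1, end=3, indices=None)
# >>> get_group_lens_slice(np.array([2, 3, 4]), cm)  # -> slice(2, 9)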
@define
class GroupLensMapper(ChunkMapper, ArgGetter, DefineMixin):
"""Class for mapping chunk metadata to per-group column lengths.
Argument can be either a group map tuple or a group lengths array."""
def map(self, chunk_meta: ChunkMeta, ann_args: tp.Optional[tp.AnnArgs] = None, **kwargs) -> ChunkMeta:
group_lens = self.get_arg(ann_args)
if isinstance(group_lens, tuple):
group_lens = group_lens[1]
group_lens_slice = get_group_lens_slice(group_lens, chunk_meta)
return ChunkMeta(
uuid=str(uuid.uuid4()),
idx=chunk_meta.idx,
start=group_lens_slice.start,
end=group_lens_slice.stop,
indices=None,
)
group_lens_mapper = GroupLensMapper(arg_query=Regex(r"(group_lens|group_map)"))
"""Default instance of `GroupLensMapper`."""
class GroupMapSlicer(ChunkSlicer):
"""Class for slicing multiple elements from a group map based on the chunk range."""
def get_size(self, obj: tp.GroupMap, **kwargs) -> int:
return GroupLensSizer.get_obj_size(obj, single_type=self.single_type)
def take(self, obj: tp.GroupMap, chunk_meta: ChunkMeta, **kwargs) -> tp.GroupMap:
group_idxs, group_lens = obj
group_lens = group_lens[chunk_meta.start : chunk_meta.end]
return np.arange(np.sum(group_lens)), group_lens
class ChunkedGroupMap(Chunked):
"""Class representing a chunkable group map."""
def resolve_take_spec(self) -> tp.TakeSpec:
if self.take_spec_missing:
if self.select:
raise ValueError("Selection is not supported")
return GroupMapSlicer
return self.take_spec
@define
class GroupIdxsMapper(ChunkMapper, ArgGetter, DefineMixin):
"""Class for mapping chunk metadata to per-group column indices.
Argument must be a group map tuple."""
def map(self, chunk_meta: ChunkMeta, ann_args: tp.Optional[tp.AnnArgs] = None, **kwargs) -> ChunkMeta:
group_map = self.get_arg(ann_args)
group_idxs, group_lens = group_map
group_lens_slice = get_group_lens_slice(group_lens, chunk_meta)
return ChunkMeta(
uuid=str(uuid.uuid4()),
idx=chunk_meta.idx,
start=None,
end=None,
indices=group_idxs[group_lens_slice],
)
group_idxs_mapper = GroupIdxsMapper(arg_query="group_map")
"""Default instance of `GroupIdxsMapper`."""
class FlexArraySizer(ArraySizer):
"""Class for getting the size from the length of an axis in a flexible array."""
@classmethod
def get_obj_size(cls, obj: tp.AnyArray, axis: int, single_type: tp.Optional[type] = None) -> int:
"""Get size of an object."""
if single_type is not None:
if checks.is_instance_of(obj, single_type):
return 1
obj = np.asarray(obj)
if len(obj.shape) == 0:
return 1
if axis is None:
if len(obj.shape) == 1:
axis = 0
checks.assert_not_none(axis, arg_name="axis")
checks.assert_in(axis, (0, 1), arg_name="axis")
if len(obj.shape) == 1:
if axis == 1:
return 1
return obj.shape[0]
if len(obj.shape) == 2:
if axis == 1:
return obj.shape[1]
return obj.shape[0]
raise ValueError(f"FlexArraySizer supports max 2 dimensions, not {len(obj.shape)}")
@define
class FlexArraySelector(ArraySelector, DefineMixin):
"""Class for selecting one element from a NumPy array's axis flexibly based on the chunk index.
The result is intended to be used together with `vectorbtpro.base.flex_indexing.flex_select_1d_nb`
and `vectorbtpro.base.flex_indexing.flex_select_nb`."""
def get_size(self, obj: tp.ArrayLike, **kwargs) -> int:
return FlexArraySizer.get_obj_size(obj, self.axis, single_type=self.single_type)
def suggest_size(self, obj: tp.ArrayLike, **kwargs) -> tp.Optional[int]:
return None
def take(
self,
obj: tp.ArrayLike,
chunk_meta: ChunkMeta,
ann_args: tp.Optional[tp.AnnArgs] = None,
**kwargs,
) -> tp.ArrayLike:
if np.isscalar(obj):
return obj
obj = np.asarray(obj)
if len(obj.shape) == 0:
return obj
axis = self.axis
if axis is None:
if len(obj.shape) == 1:
axis = 0
checks.assert_not_none(axis, arg_name="axis")
checks.assert_in(axis, (0, 1), arg_name="axis")
if len(obj.shape) == 1:
if axis == 1 or obj.shape[0] == 1:
return obj
if self.keep_dims:
return obj[chunk_meta.idx : chunk_meta.idx + 1]
return obj[chunk_meta.idx]
if len(obj.shape) == 2:
if axis == 1:
if obj.shape[1] == 1:
return obj
if self.keep_dims:
return obj[:, chunk_meta.idx : chunk_meta.idx + 1]
return obj[:, chunk_meta.idx]
if obj.shape[0] == 1:
return obj
if self.keep_dims:
return obj[chunk_meta.idx : chunk_meta.idx + 1, :]
return obj[chunk_meta.idx, :]
raise ValueError(f"FlexArraySelector supports max 2 dimensions, not {len(obj.shape)}")
@define
class FlexArraySlicer(ArraySlicer, DefineMixin):
"""Class for selecting one element from a NumPy array's axis flexibly based on the chunk index.
The result is intended to be used together with `vectorbtpro.base.flex_indexing.flex_select_1d_nb`
and `vectorbtpro.base.flex_indexing.flex_select_nb`."""
def get_size(self, obj: tp.ArrayLike, **kwargs) -> int:
return FlexArraySizer.get_obj_size(obj, self.axis, single_type=self.single_type)
def suggest_size(self, obj: tp.ArrayLike, **kwargs) -> tp.Optional[int]:
return None
def take(
self,
obj: tp.ArrayLike,
chunk_meta: ChunkMeta,
ann_args: tp.Optional[tp.AnnArgs] = None,
**kwargs,
) -> tp.ArrayLike:
if np.isscalar(obj):
return obj
obj = np.asarray(obj)
if len(obj.shape) == 0:
return obj
axis = self.axis
if axis is None:
if len(obj.shape) == 1:
axis = 0
checks.assert_not_none(axis, arg_name="axis")
checks.assert_in(axis, (0, 1), arg_name="axis")
if len(obj.shape) == 1:
if axis == 1 or obj.shape[0] == 1:
return obj
return obj[chunk_meta.start : chunk_meta.end]
if len(obj.shape) == 2:
if axis == 1:
if obj.shape[1] == 1:
return obj
return obj[:, chunk_meta.start : chunk_meta.end]
if obj.shape[0] == 1:
return obj
return obj[chunk_meta.start : chunk_meta.end, :]
raise ValueError(f"FlexArraySlicer supports max 2 dimensions, not {len(obj.shape)}")
class ChunkedFlexArray(Chunked):
"""Class representing a chunkable flexible array."""
def resolve_take_spec(self) -> tp.TakeSpec:
if self.take_spec_missing:
if self.select:
return FlexArraySelector
return FlexArraySlicer
return self.take_spec
shape_gl_slicer = ShapeSlicer(axis=1, mapper=group_lens_mapper)
"""Flexible 2-dim shape slicer along the column axis based on group lengths."""
flex_1d_array_gl_slicer = FlexArraySlicer(mapper=group_lens_mapper)
"""Flexible 1-dim array slicer along the column axis based on group lengths."""
flex_array_gl_slicer = FlexArraySlicer(axis=1, mapper=group_lens_mapper)
"""Flexible 2-dim array slicer along the column axis based on group lengths."""
array_gl_slicer = ArraySlicer(axis=1, mapper=group_lens_mapper)
"""2-dim array slicer along the column axis based on group lengths."""
</file>
<file path="base/combining.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Functions for combining arrays.
Combine functions combine two or more NumPy arrays using a custom function. The emphasis here is
on stacking the results into one NumPy array - since vectorbt is all about brute-forcing
large spaces of hyper-parameters, concatenating the results of each hyper-parameter combination into
a single DataFrame is important. All functions are available in both Python and Numba-compiled form."""
import numpy as np
from numba.typed import List
from vectorbtpro import _typing as tp
from vectorbtpro.registries.jit_registry import jit_reg, register_jitted
from vectorbtpro.utils.execution import Task, execute
from vectorbtpro.utils.template import RepFunc
__all__ = []
@register_jitted
def custom_apply_and_concat_none_nb(
indices: tp.Array1d,
apply_func_nb: tp.Callable,
*args,
) -> None:
"""Run `apply_func_nb` that returns nothing for each index.
Meant for in-place outputs."""
for i in indices:
apply_func_nb(i, *args)
@register_jitted
def apply_and_concat_none_nb(
ntimes: int,
apply_func_nb: tp.Callable,
*args,
) -> None:
"""Run `apply_func_nb` that returns nothing number of times.
Uses `custom_apply_and_concat_none_nb`."""
custom_apply_and_concat_none_nb(np.arange(ntimes), apply_func_nb, *args)
@register_jitted
def to_2d_one_nb(a: tp.Array) -> tp.Array2d:
"""Expand the dimensions of the array along the axis 1."""
if a.ndim > 1:
return a
return np.expand_dims(a, axis=1)
@register_jitted
def custom_apply_and_concat_one_nb(
indices: tp.Array1d,
apply_func_nb: tp.Callable,
*args,
) -> tp.Array2d:
"""Run `apply_func_nb` that returns one array for each index."""
output_0 = to_2d_one_nb(apply_func_nb(indices[0], *args))
output = np.empty((output_0.shape[0], len(indices) * output_0.shape[1]), dtype=output_0.dtype)
for i in range(len(indices)):
if i == 0:
outputs_i = output_0
else:
outputs_i = to_2d_one_nb(apply_func_nb(indices[i], *args))
output[:, i * outputs_i.shape[1] : (i + 1) * outputs_i.shape[1]] = outputs_i
return output
@register_jitted
def apply_and_concat_one_nb(
ntimes: int,
apply_func_nb: tp.Callable,
*args,
) -> tp.Array2d:
"""Run `apply_func_nb` that returns one array number of times.
Uses `custom_apply_and_concat_one_nb`."""
return custom_apply_and_concat_one_nb(np.arange(ntimes), apply_func_nb, *args)
@register_jitted
def to_2d_multiple_nb(a: tp.Iterable[tp.Array]) -> tp.List[tp.Array2d]:
"""Expand the dimensions of each array in `a` along axis 1."""
lst = list()
for _a in a:
lst.append(to_2d_one_nb(_a))
return lst
@register_jitted
def custom_apply_and_concat_multiple_nb(
indices: tp.Array1d,
apply_func_nb: tp.Callable,
*args,
) -> tp.List[tp.Array2d]:
"""Run `apply_func_nb` that returns multiple arrays for each index."""
outputs = list()
outputs_0 = to_2d_multiple_nb(apply_func_nb(indices[0], *args))
for j in range(len(outputs_0)):
outputs.append(
np.empty((outputs_0[j].shape[0], len(indices) * outputs_0[j].shape[1]), dtype=outputs_0[j].dtype)
)
for i in range(len(indices)):
if i == 0:
outputs_i = outputs_0
else:
outputs_i = to_2d_multiple_nb(apply_func_nb(indices[i], *args))
for j in range(len(outputs_i)):
outputs[j][:, i * outputs_i[j].shape[1] : (i + 1) * outputs_i[j].shape[1]] = outputs_i[j]
return outputs
@register_jitted
def apply_and_concat_multiple_nb(
ntimes: int,
apply_func_nb: tp.Callable,
*args,
) -> tp.List[tp.Array2d]:
"""Run `apply_func_nb` that returns multiple arrays number of times.
Uses `custom_apply_and_concat_multiple_nb`."""
return custom_apply_and_concat_multiple_nb(np.arange(ntimes), apply_func_nb, *args)
def apply_and_concat_each(
tasks: tp.TasksLike,
n_outputs: tp.Optional[int] = None,
execute_kwargs: tp.KwargsLike = None,
) -> tp.Union[None, tp.Array2d, tp.List[tp.Array2d]]:
"""Apply each function on its own set of positional and keyword arguments.
Executes the function using `vectorbtpro.utils.execution.execute`."""
from vectorbtpro.base.merging import column_stack_arrays
if execute_kwargs is None:
execute_kwargs = {}
out = execute(tasks, **execute_kwargs)
if n_outputs is None:
if out[0] is None:
n_outputs = 0
elif isinstance(out[0], (tuple, list, List)):
n_outputs = len(out[0])
else:
n_outputs = 1
if n_outputs == 0:
return None
if n_outputs == 1:
if isinstance(out[0], (tuple, list, List)) and len(out[0]) == 1:
out = list(map(lambda x: x[0], out))
return column_stack_arrays(out)
return list(map(column_stack_arrays, zip(*out)))
def apply_and_concat(
ntimes: int,
apply_func: tp.Callable,
*args,
n_outputs: tp.Optional[int] = None,
jitted_loop: bool = False,
jitted_warmup: bool = False,
execute_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.Union[None, tp.Array2d, tp.List[tp.Array2d]]:
"""Run `apply_func` function a number of times and concatenate the results depending upon how
many array-like objects it generates.
`apply_func` must accept arguments `i`, `*args`, and `**kwargs`.
Set `jitted_loop` to True to use the JIT-compiled version.
All jitted iteration functions are resolved using `vectorbtpro.registries.jit_registry.JITRegistry.resolve`.
!!! note
`n_outputs` must be set when `jitted_loop` is True.
Numba doesn't support variable keyword arguments."""
if jitted_loop:
if n_outputs is None:
raise ValueError("Jitted iteration requires n_outputs")
if n_outputs == 0:
func = jit_reg.resolve(custom_apply_and_concat_none_nb)
elif n_outputs == 1:
func = jit_reg.resolve(custom_apply_and_concat_one_nb)
else:
func = jit_reg.resolve(custom_apply_and_concat_multiple_nb)
if jitted_warmup:
func(np.array([0]), apply_func, *args, **kwargs)
def _tasks_template(chunk_meta):
tasks = []
for _chunk_meta in chunk_meta:
if _chunk_meta.indices is not None:
chunk_indices = np.asarray(_chunk_meta.indices)
else:
if _chunk_meta.start is None or _chunk_meta.end is None:
raise ValueError("Each chunk must have a start and an end index")
chunk_indices = np.arange(_chunk_meta.start, _chunk_meta.end)
tasks.append(Task(func, chunk_indices, apply_func, *args, **kwargs))
return tasks
tasks = RepFunc(_tasks_template)
else:
tasks = [(apply_func, (i, *args), kwargs) for i in range(ntimes)]
if execute_kwargs is None:
execute_kwargs = {}
execute_kwargs["size"] = ntimes
return apply_and_concat_each(
tasks,
n_outputs=n_outputs,
execute_kwargs=execute_kwargs,
)
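# Illustrative sketch of the non-jitted path; the exact layout is produced by
# `vectorbtpro.base.merging.column_stack_arrays`, so the shown output is indicative.
# >>> def scale_col(i, a):
# ...     return a * (i + 1)
# >>> apply_and_concat(3, scale_col, np.array([[1], [2]]), n_outputs=1)
# array([[1, 2, 3],
#        [2, 4, 6]])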
@register_jitted
def select_and_combine_nb(
i: int,
obj: tp.Any,
others: tp.Sequence,
combine_func_nb: tp.Callable,
*args,
) -> tp.AnyArray:
"""Numba-compiled version of `select_and_combine`."""
return combine_func_nb(obj, others[i], *args)
@register_jitted
def combine_and_concat_nb(
obj: tp.Any,
others: tp.Sequence,
combine_func_nb: tp.Callable,
*args,
) -> tp.Array2d:
"""Numba-compiled version of `combine_and_concat`."""
return apply_and_concat_one_nb(len(others), select_and_combine_nb, obj, others, combine_func_nb, *args)
def select_and_combine(
i: int,
obj: tp.Any,
others: tp.Sequence,
combine_func: tp.Callable,
*args,
**kwargs,
) -> tp.AnyArray:
"""Combine `obj` with an array at position `i` in `others` using `combine_func`."""
return combine_func(obj, others[i], *args, **kwargs)
def combine_and_concat(
obj: tp.Any,
others: tp.Sequence,
combine_func: tp.Callable,
*args,
jitted_loop: bool = False,
**kwargs,
) -> tp.Array2d:
"""Combine `obj` with each in `others` using `combine_func` and concatenate.
`select_and_combine_nb` is resolved using `vectorbtpro.registries.jit_registry.JITRegistry.resolve`."""
if jitted_loop:
apply_func = jit_reg.resolve(select_and_combine_nb)
else:
apply_func = select_and_combine
return apply_and_concat(
len(others),
apply_func,
obj,
others,
combine_func,
*args,
n_outputs=1,
jitted_loop=jitted_loop,
**kwargs,
)
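# Illustrative sketch: each pairwise combination becomes one column of the result
# (shown output is indicative and relies on `column_stack_arrays`).
# >>> combine_and_concat(np.array([1, 2]), [np.array([3, 4]), np.array([5, 6])], np.add)
# array([[4, 6],
#        [6, 8]])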
@register_jitted
def combine_multiple_nb(
objs: tp.Sequence,
combine_func_nb: tp.Callable,
*args,
) -> tp.Any:
"""Numba-compiled version of `combine_multiple`."""
result = objs[0]
for i in range(1, len(objs)):
result = combine_func_nb(result, objs[i], *args)
return result
def combine_multiple(
objs: tp.Sequence,
combine_func: tp.Callable,
*args,
jitted_loop: bool = False,
**kwargs,
) -> tp.Any:
"""Combine `objs` pairwise into a single object.
Set `jitted_loop` to True to use the JIT-compiled version.
`combine_multiple_nb` is resolved using `vectorbtpro.registries.jit_registry.JITRegistry.resolve`.
!!! note
Numba doesn't support variable keyword arguments."""
if jitted_loop:
func = jit_reg.resolve(combine_multiple_nb)
return func(objs, combine_func, *args)
result = objs[0]
for i in range(1, len(objs)):
result = combine_func(result, objs[i], *args, **kwargs)
return result
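# Illustrative sketch of the pairwise reduction performed by `combine_multiple`:
# >>> combine_multiple([np.array([1, 2]), np.array([3, 4]), np.array([5, 6])], np.add)
# array([ 9, 12])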
</file>
<file path="base/decorators.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Class decorators for base classes."""
from functools import cached_property as cachedproperty
from vectorbtpro import _typing as tp
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import Config, HybridConfig, merge_dicts
__all__ = []
def override_arg_config(config: Config, merge_configs: bool = True) -> tp.ClassWrapper:
"""Class decorator to override the argument config of a class subclassing
`vectorbtpro.base.preparing.BasePreparer`.
Instead of overriding `_arg_config` class attribute, you can pass `config` directly to this decorator.
Disable `merge_configs` to not merge, which will effectively disable field inheritance."""
def wrapper(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
checks.assert_subclass_of(cls, "BasePreparer")
if merge_configs:
new_config = merge_dicts(cls.arg_config, config)
else:
new_config = config
if not isinstance(new_config, Config):
new_config = HybridConfig(new_config)
setattr(cls, "_arg_config", new_config)
return cls
return wrapper
def attach_arg_properties(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
"""Class decorator to attach properties for arguments defined in the argument config
of a `vectorbtpro.base.preparing.BasePreparer` subclass."""
checks.assert_subclass_of(cls, "BasePreparer")
for arg_name, settings in cls.arg_config.items():
attach = settings.get("attach", None)
broadcast = settings.get("broadcast", False)
substitute_templates = settings.get("substitute_templates", False)
if (isinstance(attach, bool) and attach) or (attach is None and (broadcast or substitute_templates)):
if broadcast:
return_type = tp.ArrayLike
else:
return_type = object
target_pre_name = "_pre_" + arg_name
if not hasattr(cls, target_pre_name):
def pre_arg_prop(self, _arg_name: str = arg_name) -> return_type:
return self.get_arg(_arg_name)
pre_arg_prop.__name__ = target_pre_name
pre_arg_prop.__module__ = cls.__module__
pre_arg_prop.__qualname__ = f"{cls.__name__}.{pre_arg_prop.__name__}"
if broadcast and substitute_templates:
pre_arg_prop.__doc__ = f"Argument `{arg_name}` before broadcasting and template substitution."
elif broadcast:
pre_arg_prop.__doc__ = f"Argument `{arg_name}` before broadcasting."
else:
pre_arg_prop.__doc__ = f"Argument `{arg_name}` before template substitution."
setattr(cls, pre_arg_prop.__name__, cachedproperty(pre_arg_prop))
getattr(cls, pre_arg_prop.__name__).__set_name__(cls, pre_arg_prop.__name__)
target_post_name = "_post_" + arg_name
if not hasattr(cls, target_post_name):
def post_arg_prop(self, _arg_name: str = arg_name) -> return_type:
return self.prepare_post_arg(_arg_name)
post_arg_prop.__name__ = target_post_name
post_arg_prop.__module__ = cls.__module__
post_arg_prop.__qualname__ = f"{cls.__name__}.{post_arg_prop.__name__}"
if broadcast and substitute_templates:
post_arg_prop.__doc__ = f"Argument `{arg_name}` after broadcasting and template substitution."
elif broadcast:
post_arg_prop.__doc__ = f"Argument `{arg_name}` after broadcasting."
else:
post_arg_prop.__doc__ = f"Argument `{arg_name}` after template substitution."
setattr(cls, post_arg_prop.__name__, cachedproperty(post_arg_prop))
getattr(cls, post_arg_prop.__name__).__set_name__(cls, post_arg_prop.__name__)
target_name = arg_name
if not hasattr(cls, target_name):
def arg_prop(self, _target_post_name: str = target_post_name) -> return_type:
return getattr(self, _target_post_name)
arg_prop.__name__ = target_name
arg_prop.__module__ = cls.__module__
arg_prop.__qualname__ = f"{cls.__name__}.{arg_prop.__name__}"
arg_prop.__doc__ = f"Argument `{arg_name}`."
setattr(cls, arg_prop.__name__, cachedproperty(arg_prop))
getattr(cls, arg_prop.__name__).__set_name__(cls, arg_prop.__name__)
elif (isinstance(attach, bool) and attach) or attach is None:
if not hasattr(cls, arg_name):
def arg_prop(self, _arg_name: str = arg_name) -> tp.Any:
return self.get_arg(_arg_name)
arg_prop.__name__ = arg_name
arg_prop.__module__ = cls.__module__
arg_prop.__qualname__ = f"{cls.__name__}.{arg_prop.__name__}"
arg_prop.__doc__ = f"Argument `{arg_name}`."
setattr(cls, arg_prop.__name__, cachedproperty(arg_prop))
getattr(cls, arg_prop.__name__).__set_name__(cls, arg_prop.__name__)
return cls
</file>
<file path="base/flex_indexing.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Classes and functions for flexible indexing."""
from vectorbtpro import _typing as tp
from vectorbtpro._settings import settings
from vectorbtpro.registries.jit_registry import register_jitted
__all__ = [
"flex_select_1d_nb",
"flex_select_1d_pr_nb",
"flex_select_1d_pc_nb",
"flex_select_nb",
"flex_select_row_nb",
"flex_select_col_nb",
"flex_select_2d_row_nb",
"flex_select_2d_col_nb",
]
_rotate_rows = settings["indexing"]["rotate_rows"]
_rotate_cols = settings["indexing"]["rotate_cols"]
@register_jitted(cache=True)
def flex_choose_i_1d_nb(arr: tp.FlexArray1d, i: int) -> int:
"""Choose a position in an array as if it has been broadcast against rows or columns.
!!! note
Array must be one-dimensional."""
if arr.shape[0] == 1:
flex_i = 0
else:
flex_i = i
return int(flex_i)
@register_jitted(cache=True)
def flex_select_1d_nb(arr: tp.FlexArray1d, i: int) -> tp.Scalar:
"""Select an element of an array as if it has been broadcast against rows or columns.
!!! note
Array must be one-dimensional."""
flex_i = flex_choose_i_1d_nb(arr, i)
return arr[flex_i]
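# Illustrative sketch: a length-1 array behaves like a broadcast constant.
# >>> flex_select_1d_nb(np.array([10.0]), 3)
# 10.0
# >>> flex_select_1d_nb(np.array([1.0, 2.0, 3.0]), 1)
# 2.0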
@register_jitted(cache=True)
def flex_choose_i_pr_1d_nb(arr: tp.FlexArray1d, i: int, rotate_rows: bool = _rotate_rows) -> int:
"""Choose a position in an array as if it has been broadcast against rows.
Can use rotational indexing along rows.
!!! note
Array must be one-dimensional."""
if arr.shape[0] == 1:
flex_i = 0
else:
flex_i = i
if rotate_rows:
return int(flex_i) % arr.shape[0]
return int(flex_i)
@register_jitted(cache=True)
def flex_choose_i_pr_nb(arr: tp.FlexArray2d, i: int, rotate_rows: bool = _rotate_rows) -> int:
"""Choose a position in an array as if it has been broadcast against rows.
Can use rotational indexing along rows.
!!! note
Array must be two-dimensional."""
if arr.shape[0] == 1:
flex_i = 0
else:
flex_i = i
if rotate_rows:
return int(flex_i) % arr.shape[0]
return int(flex_i)
@register_jitted(cache=True)
def flex_select_1d_pr_nb(arr: tp.FlexArray1d, i: int, rotate_rows: bool = _rotate_rows) -> tp.Scalar:
"""Select an element of an array as if it has been broadcast against rows.
Can use rotational indexing along rows.
!!! note
Array must be one-dimensional."""
flex_i = flex_choose_i_pr_1d_nb(arr, i, rotate_rows=rotate_rows)
return arr[flex_i]
@register_jitted(cache=True)
def flex_choose_i_pc_1d_nb(arr: tp.FlexArray1d, col: int, rotate_cols: bool = _rotate_cols) -> int:
"""Choose a position in an array as if it has been broadcast against columns.
Can use rotational indexing along columns.
!!! note
Array must be one-dimensional."""
if arr.shape[0] == 1:
flex_col = 0
else:
flex_col = col
if rotate_cols:
return int(flex_col) % arr.shape[0]
return int(flex_col)
@register_jitted(cache=True)
def flex_choose_i_pc_nb(arr: tp.FlexArray2d, col: int, rotate_cols: bool = _rotate_cols) -> int:
"""Choose a position in an array as if it has been broadcast against columns.
Can use rotational indexing along columns.
!!! note
Array must be two-dimensional."""
if arr.shape[1] == 1:
flex_col = 0
else:
flex_col = col
if rotate_cols:
return int(flex_col) % arr.shape[1]
return int(flex_col)
@register_jitted(cache=True)
def flex_select_1d_pc_nb(arr: tp.FlexArray1d, col: int, rotate_cols: bool = _rotate_cols) -> tp.Scalar:
"""Select an element of an array as if it has been broadcast against columns.
Can use rotational indexing along columns.
!!! note
Array must be one-dimensional."""
flex_col = flex_choose_i_pc_1d_nb(arr, col, rotate_cols=rotate_cols)
return arr[flex_col]
@register_jitted(cache=True)
def flex_choose_i_and_col_nb(
arr: tp.FlexArray2d,
i: int,
col: int,
rotate_rows: bool = _rotate_rows,
rotate_cols: bool = _rotate_cols,
) -> tp.Tuple[int, int]:
"""Choose a position in an array as if it has been broadcast rows and columns.
Can use rotational indexing along rows and columns.
!!! note
Array must be two-dimensional."""
if arr.shape[0] == 1:
flex_i = 0
else:
flex_i = i
if arr.shape[1] == 1:
flex_col = 0
else:
flex_col = col
if rotate_rows and rotate_cols:
return int(flex_i) % arr.shape[0], int(flex_col) % arr.shape[1]
if rotate_rows:
return int(flex_i) % arr.shape[0], int(flex_col)
if rotate_cols:
return int(flex_i), int(flex_col) % arr.shape[1]
return int(flex_i), int(flex_col)
@register_jitted(cache=True)
def flex_select_nb(
arr: tp.FlexArray2d,
i: int,
col: int,
rotate_rows: bool = _rotate_rows,
rotate_cols: bool = _rotate_cols,
) -> tp.Scalar:
"""Select element of an array as if it has been broadcast rows and columns.
Can use rotational indexing along rows and columns.
!!! note
Array must be two-dimensional."""
flex_i, flex_col = flex_choose_i_and_col_nb(
arr,
i,
col,
rotate_rows=rotate_rows,
rotate_cols=rotate_cols,
)
return arr[flex_i, flex_col]
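# Illustrative sketch: a single row or column is treated as broadcast along that axis,
# so the corresponding index is effectively ignored (rotation depends on global settings).
# >>> flex_select_nb(np.array([[1, 2, 3]]), 5, 2)
# 3
# >>> flex_select_nb(np.array([[10], [20], [30]]), 1, 4)
# 20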
@register_jitted(cache=True)
def flex_select_row_nb(arr: tp.FlexArray2d, i: int, rotate_rows: bool = _rotate_rows) -> tp.Array1d:
"""Select a row from a flexible 2-dim array. Returns a 1-dim array.
!!! note
Array must be two-dimensional."""
flex_i = flex_choose_i_pr_nb(arr, i, rotate_rows=rotate_rows)
return arr[flex_i]
@register_jitted(cache=True)
def flex_select_col_nb(arr: tp.FlexArray2d, col: int, rotate_cols: bool = _rotate_cols) -> tp.Array1d:
"""Select a column from a flexible 2-dim array. Returns a 1-dim array.
!!! note
Array must be two-dimensional."""
flex_col = flex_choose_i_pc_nb(arr, col, rotate_cols=rotate_cols)
return arr[:, flex_col]
@register_jitted(cache=True)
def flex_select_2d_row_nb(arr: tp.FlexArray2d, i: int, rotate_rows: bool = _rotate_rows) -> tp.Array2d:
"""Select a row from a flexible 2-dim array. Returns a 2-dim array.
!!! note
Array must be two-dimensional."""
flex_i = flex_choose_i_pr_nb(arr, i, rotate_rows=rotate_rows)
return arr[flex_i : flex_i + 1]
@register_jitted(cache=True)
def flex_select_2d_col_nb(arr: tp.FlexArray2d, col: int, rotate_cols: bool = _rotate_cols) -> tp.Array2d:
"""Select a column from a flexible 2-dim array. Returns a 2-dim array.
!!! note
Array must be two-dimensional."""
flex_col = flex_choose_i_pc_nb(arr, col, rotate_cols=rotate_cols)
return arr[:, flex_col : flex_col + 1]
</file>
<file path="base/indexes.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Functions for working with indexes: index and columns.
They perform operations on index objects, such as stacking, combining, and cleansing MultiIndex levels.
!!! note
"Index" in pandas context is referred to both index and columns."""
from datetime import datetime, timedelta
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.registries.jit_registry import jit_reg, register_jitted
from vectorbtpro.utils import checks
from vectorbtpro.utils.attr_ import DefineMixin, define
from vectorbtpro.utils.base import Base
__all__ = [
"ExceptLevel",
"repeat_index",
"tile_index",
"stack_indexes",
"combine_indexes",
]
@define
class ExceptLevel(DefineMixin):
"""Class for grouping except one or more levels."""
value: tp.MaybeLevelSequence = define.field()
"""One or more level positions or names."""
def to_any_index(index_like: tp.IndexLike) -> tp.Index:
"""Convert any index-like object to an index.
Index objects are kept as-is."""
if checks.is_np_array(index_like) and index_like.ndim == 0:
index_like = index_like[None]
if not checks.is_index(index_like):
return pd.Index(index_like)
return index_like
def get_index(obj: tp.SeriesFrame, axis: int) -> tp.Index:
"""Get index of `obj` by `axis`."""
checks.assert_instance_of(obj, (pd.Series, pd.DataFrame))
checks.assert_in(axis, (0, 1))
if axis == 0:
return obj.index
else:
if checks.is_series(obj):
if obj.name is not None:
return pd.Index([obj.name])
return pd.Index([0]) # same as how pandas does it
else:
return obj.columns
def index_from_values(
values: tp.Sequence,
single_value: bool = False,
name: tp.Optional[tp.Hashable] = None,
) -> tp.Index:
"""Create a new `pd.Index` with `name` by parsing an iterable `values`.
Each element in `values` will correspond to an element in the new index."""
scalar_types = (int, float, complex, str, bool, datetime, timedelta, np.generic)
type_id_number = {}
value_names = []
if len(values) == 1:
single_value = True
for i in range(len(values)):
if i > 0 and single_value:
break
v = values[i]
if v is None or isinstance(v, scalar_types):
value_names.append(v)
elif isinstance(v, np.ndarray):
all_same = False
if np.issubdtype(v.dtype, np.floating):
if np.isclose(v, v.item(0), equal_nan=True).all():
all_same = True
elif v.dtype.names is not None:
all_same = False
else:
if np.equal(v, v.item(0)).all():
all_same = True
if all_same:
value_names.append(v.item(0))
else:
if single_value:
value_names.append("array")
else:
if "array" not in type_id_number:
type_id_number["array"] = {}
if id(v) not in type_id_number["array"]:
type_id_number["array"][id(v)] = len(type_id_number["array"])
value_names.append("array_%d" % (type_id_number["array"][id(v)]))
else:
type_name = str(type(v).__name__)
if single_value:
value_names.append("%s" % type_name)
else:
if type_name not in type_id_number:
type_id_number[type_name] = {}
if id(v) not in type_id_number[type_name]:
type_id_number[type_name][id(v)] = len(type_id_number[type_name])
value_names.append("%s_%d" % (type_name, type_id_number[type_name][id(v)]))
if single_value and len(values) > 1:
value_names *= len(values)
return pd.Index(value_names, name=name)
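# Illustrative sketch: scalars keep their value, while non-scalar objects get type-based names.
# >>> index_from_values([0.5, "a", np.array([1, 2, 3])], name="param")
# Index([0.5, 'a', 'array_0'], dtype='object', name='param')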
def repeat_index(index: tp.IndexLike, n: int, ignore_ranges: tp.Optional[bool] = None) -> tp.Index:
"""Repeat each element in `index` `n` times.
Set `ignore_ranges` to True to ignore indexes of type `pd.RangeIndex`."""
from vectorbtpro._settings import settings
broadcasting_cfg = settings["broadcasting"]
if ignore_ranges is None:
ignore_ranges = broadcasting_cfg["ignore_ranges"]
index = to_any_index(index)
if n == 1:
return index
if checks.is_default_index(index) and ignore_ranges: # ignore simple ranges without name
return pd.RangeIndex(start=0, stop=len(index) * n, step=1)
return index.repeat(n)
def tile_index(index: tp.IndexLike, n: int, ignore_ranges: tp.Optional[bool] = None) -> tp.Index:
"""Tile the whole `index` `n` times.
Set `ignore_ranges` to True to ignore indexes of type `pd.RangeIndex`."""
from vectorbtpro._settings import settings
broadcasting_cfg = settings["broadcasting"]
if ignore_ranges is None:
ignore_ranges = broadcasting_cfg["ignore_ranges"]
index = to_any_index(index)
if n == 1:
return index
if checks.is_default_index(index) and ignore_ranges: # ignore simple ranges without name
return pd.RangeIndex(start=0, stop=len(index) * n, step=1)
if isinstance(index, pd.MultiIndex):
return pd.MultiIndex.from_tuples(np.tile(index, n), names=index.names)
return pd.Index(np.tile(index, n), name=index.name)
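# Illustrative sketch: `repeat_index` repeats each element, `tile_index` repeats the whole index.
# >>> repeat_index(pd.Index(["a", "b"], name="x"), 2)
# Index(['a', 'a', 'b', 'b'], dtype='object', name='x')
# >>> tile_index(pd.Index(["a", "b"], name="x"), 2)
# Index(['a', 'b', 'a', 'b'], dtype='object', name='x')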
def clean_index(
index: tp.IndexLike,
drop_duplicates: tp.Optional[bool] = None,
keep: tp.Optional[str] = None,
drop_redundant: tp.Optional[bool] = None,
) -> tp.Index:
"""Clean index.
Set `drop_duplicates` to True to remove duplicate levels.
For details on `keep`, see `drop_duplicate_levels`.
Set `drop_redundant` to True to use `drop_redundant_levels`."""
from vectorbtpro._settings import settings
broadcasting_cfg = settings["broadcasting"]
if drop_duplicates is None:
drop_duplicates = broadcasting_cfg["drop_duplicates"]
if keep is None:
keep = broadcasting_cfg["keep"]
if drop_redundant is None:
drop_redundant = broadcasting_cfg["drop_redundant"]
index = to_any_index(index)
if drop_duplicates:
index = drop_duplicate_levels(index, keep=keep)
if drop_redundant:
index = drop_redundant_levels(index)
return index
def stack_indexes(*indexes: tp.MaybeTuple[tp.IndexLike], **clean_index_kwargs) -> tp.Index:
"""Stack each index in `indexes` on top of each other, from top to bottom."""
if len(indexes) == 1:
indexes = indexes[0]
indexes = list(indexes)
levels = []
for i in range(len(indexes)):
index = indexes[i]
if not isinstance(index, pd.MultiIndex):
levels.append(to_any_index(index))
else:
for j in range(index.nlevels):
levels.append(index.get_level_values(j))
max_len = max(map(len, levels))
for i in range(len(levels)):
if len(levels[i]) < max_len:
if len(levels[i]) != 1:
raise ValueError(f"Index at level {i} could not be broadcast to shape ({max_len},) ")
levels[i] = repeat_index(levels[i], max_len, ignore_ranges=False)
new_index = pd.MultiIndex.from_arrays(levels)
return clean_index(new_index, **clean_index_kwargs)
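# Illustrative sketch: indexes of length 1 are broadcast to the longest index before stacking.
# >>> stack_indexes(pd.Index([1, 2, 3], name="a"), pd.Index(["x"], name="b"))
# MultiIndex([(1, 'x'), (2, 'x'), (3, 'x')], names=['a', 'b'])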
def combine_indexes(*indexes: tp.MaybeTuple[tp.IndexLike], **kwargs) -> tp.Index:
"""Combine each index in `indexes` using Cartesian product.
Keyword arguments will be passed to `stack_indexes`."""
if len(indexes) == 1:
indexes = indexes[0]
indexes = list(indexes)
new_index = to_any_index(indexes[0])
for i in range(1, len(indexes)):
index1, index2 = new_index, to_any_index(indexes[i])
new_index1 = repeat_index(index1, len(index2), ignore_ranges=False)
new_index2 = tile_index(index2, len(index1), ignore_ranges=False)
new_index = stack_indexes([new_index1, new_index2], **kwargs)
return new_index
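# Illustrative sketch: a Cartesian product of two parameter indexes.
# >>> combine_indexes(pd.Index([1, 2], name="a"), pd.Index(["x", "y"], name="b"))
# MultiIndex([(1, 'x'), (1, 'y'), (2, 'x'), (2, 'y')], names=['a', 'b'])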
def combine_index_with_keys(index: tp.IndexLike, keys: tp.IndexLike, lens: tp.Sequence[int], **kwargs) -> tp.Index:
"""Build keys based on index lengths."""
if not isinstance(index, pd.Index):
index = pd.Index(index)
if not isinstance(keys, pd.Index):
keys = pd.Index(keys)
new_index = None
new_keys = None
start_idx = 0
for i in range(len(keys)):
_index = index[start_idx : start_idx + lens[i]]
if new_index is None:
new_index = _index
else:
new_index = new_index.append(_index)
start_idx += lens[i]
new_key = keys[[i]].repeat(lens[i])
if new_keys is None:
new_keys = new_key
else:
new_keys = new_keys.append(new_key)
return stack_indexes([new_keys, new_index], **kwargs)
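# Illustrative sketch: each key is repeated over its segment length and stacked on top.
# >>> combine_index_with_keys(pd.Index(["a", "b", "c"]), pd.Index(["k1", "k2"], name="key"), [2, 1])
# MultiIndex([('k1', 'a'), ('k1', 'b'), ('k2', 'c')], names=['key', None])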
def concat_indexes(
*indexes: tp.MaybeTuple[tp.IndexLike],
index_concat_method: tp.MaybeTuple[tp.Union[str, tp.Callable]] = "append",
keys: tp.Optional[tp.IndexLike] = None,
clean_index_kwargs: tp.KwargsLike = None,
verify_integrity: bool = True,
axis: int = 1,
) -> tp.Index:
"""Concatenate indexes.
The following index concatenation methods are supported:
* 'append': append one index to another
* 'union': build a union of indexes
* 'pd_concat': convert indexes to Pandas Series or DataFrames and use `pd.concat`
* 'factorize': factorize the concatenated index
* 'factorize_each': factorize each index and concatenate while keeping numbers unique
* 'reset': reset the concatenated index without applying `keys`
* Callable: a custom callable that takes the indexes and returns the concatenated index
Argument `index_concat_method` also accepts a tuple of two options: the second option gets applied
if the first one fails.
Use `keys` as an index with the same number of elements as there are indexes to add
another index level on top of the concatenated indexes.
If `verify_integrity` is True and `keys` is None, performs various checks depending on the axis."""
if len(indexes) == 1:
indexes = indexes[0]
indexes = list(indexes)
if clean_index_kwargs is None:
clean_index_kwargs = {}
if axis == 0:
factorized_name = "row_idx"
elif axis == 1:
factorized_name = "col_idx"
else:
factorized_name = "group_idx"
if keys is None:
all_ranges = True
for index in indexes:
if not checks.is_default_index(index):
all_ranges = False
break
if all_ranges:
return pd.RangeIndex(stop=sum(map(len, indexes)))
if isinstance(index_concat_method, tuple):
try:
return concat_indexes(
*indexes,
index_concat_method=index_concat_method[0],
keys=keys,
clean_index_kwargs=clean_index_kwargs,
verify_integrity=verify_integrity,
axis=axis,
)
except Exception as e:
return concat_indexes(
*indexes,
index_concat_method=index_concat_method[1],
keys=keys,
clean_index_kwargs=clean_index_kwargs,
verify_integrity=verify_integrity,
axis=axis,
)
if not isinstance(index_concat_method, str):
new_index = index_concat_method(indexes)
elif index_concat_method.lower() == "append":
new_index = None
for index in indexes:
if new_index is None:
new_index = index
else:
new_index = new_index.append(index)
elif index_concat_method.lower() == "union":
if keys is not None:
raise ValueError("Cannot apply keys after concatenating indexes through union")
new_index = None
for index in indexes:
if new_index is None:
new_index = index
else:
new_index = new_index.union(index)
elif index_concat_method.lower() == "pd_concat":
new_index = None
for index in indexes:
if isinstance(index, pd.MultiIndex):
index = index.to_frame().reset_index(drop=True)
else:
index = index.to_series().reset_index(drop=True)
if new_index is None:
new_index = index
else:
if isinstance(new_index, pd.DataFrame):
if isinstance(index, pd.Series):
index = index.to_frame()
elif isinstance(index, pd.Series):
if isinstance(new_index, pd.DataFrame):
new_index = new_index.to_frame()
new_index = pd.concat((new_index, index), ignore_index=True)
if isinstance(new_index, pd.Series):
new_index = pd.Index(new_index)
else:
new_index = pd.MultiIndex.from_frame(new_index)
elif index_concat_method.lower() == "factorize":
new_index = concat_indexes(
*indexes,
index_concat_method="append",
clean_index_kwargs=clean_index_kwargs,
verify_integrity=False,
axis=axis,
)
new_index = pd.Index(pd.factorize(new_index)[0], name=factorized_name)
elif index_concat_method.lower() == "factorize_each":
new_index = None
for index in indexes:
index = pd.Index(pd.factorize(index)[0], name=factorized_name)
if new_index is None:
new_index = index
next_min = index.max() + 1
else:
new_index = new_index.append(index + next_min)
next_min = index.max() + 1 + next_min
elif index_concat_method.lower() == "reset":
return pd.RangeIndex(stop=sum(map(len, indexes)))
else:
if axis == 0:
raise ValueError(f"Invalid index concatenation method: '{index_concat_method}'")
elif axis == 1:
raise ValueError(f"Invalid column concatenation method: '{index_concat_method}'")
else:
raise ValueError(f"Invalid group concatenation method: '{index_concat_method}'")
if keys is not None:
if isinstance(keys[0], pd.Index):
keys = concat_indexes(
*keys,
index_concat_method="append",
clean_index_kwargs=clean_index_kwargs,
verify_integrity=False,
axis=axis,
)
new_index = stack_indexes((keys, new_index), **clean_index_kwargs)
keys = None
elif not isinstance(keys, pd.Index):
keys = pd.Index(keys)
if keys is not None:
top_index = None
for i, index in enumerate(indexes):
repeated_index = repeat_index(keys[[i]], len(index))
if top_index is None:
top_index = repeated_index
else:
top_index = top_index.append(repeated_index)
new_index = stack_indexes((top_index, new_index), **clean_index_kwargs)
if verify_integrity:
if keys is None:
if axis == 0:
if not new_index.is_monotonic_increasing:
raise ValueError("Concatenated index is not monotonically increasing")
if "mixed" in new_index.inferred_type:
raise ValueError("Concatenated index is mixed")
if new_index.has_duplicates:
raise ValueError("Concatenated index contains duplicates")
if axis == 1:
if new_index.has_duplicates:
raise ValueError("Concatenated columns contain duplicates")
if axis == 2:
if new_index.has_duplicates:
len_sum = 0
for index in indexes:
if len_sum > 0:
prev_index = new_index[:len_sum]
this_index = new_index[len_sum : len_sum + len(index)]
if len(prev_index.intersection(this_index)) > 0:
raise ValueError("Concatenated groups contain duplicates")
len_sum += len(index)
return new_index
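# Illustrative sketch: appending two column indexes, optionally labeling each source with a key.
# >>> concat_indexes(pd.Index(["a", "b"]), pd.Index(["c", "d"]), axis=1)
# Index(['a', 'b', 'c', 'd'], dtype='object')
# >>> concat_indexes(pd.Index(["a", "b"]), pd.Index(["c", "d"]), keys=pd.Index(["k1", "k2"], name="key"), axis=1)
# MultiIndex([('k1', 'a'), ('k1', 'b'), ('k2', 'c'), ('k2', 'd')], names=['key', None])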
def drop_levels(
index: tp.Index,
levels: tp.Union[ExceptLevel, tp.MaybeLevelSequence],
strict: bool = True,
) -> tp.Index:
"""Drop `levels` in `index` by their name(s)/position(s).
Provide `levels` as an instance of `ExceptLevel` to drop everything apart from the specified levels."""
if not isinstance(index, pd.MultiIndex):
if strict:
raise TypeError("Index must be a multi-index")
return index
if isinstance(levels, ExceptLevel):
levels = levels.value
except_mode = True
else:
except_mode = False
levels_to_drop = set()
if isinstance(levels, str) or not checks.is_sequence(levels):
levels = [levels]
for level in levels:
if level in index.names:
for level_pos in [i for i, x in enumerate(index.names) if x == level]:
levels_to_drop.add(level_pos)
elif checks.is_int(level):
if level < 0:
new_level = index.nlevels + level
if new_level < 0:
raise KeyError(f"Level at position {level} not found")
level = new_level
if 0 <= level < index.nlevels:
levels_to_drop.add(level)
else:
raise KeyError(f"Level at position {level} not found")
elif strict:
raise KeyError(f"Level '{level}' not found")
if except_mode:
levels_to_drop = set(range(index.nlevels)).difference(levels_to_drop)
if len(levels_to_drop) == 0:
if strict:
raise ValueError("No levels to drop")
return index
if len(levels_to_drop) >= index.nlevels:
if strict:
raise ValueError(
f"Cannot remove {len(levels_to_drop)} levels from an index with {index.nlevels} levels: "
"at least one level must be left"
)
return index
return index.droplevel(list(levels_to_drop))
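# Illustrative sketch: dropping a level by name, or keeping only that level via `ExceptLevel`.
# >>> mi = pd.MultiIndex.from_arrays([[1, 2], ["x", "y"]], names=["a", "b"])
# >>> drop_levels(mi, "a")
# Index(['x', 'y'], dtype='object', name='b')
# >>> drop_levels(mi, ExceptLevel("a"))
# Index([1, 2], dtype='int64', name='a')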
def rename_levels(index: tp.Index, mapper: tp.MaybeMappingSequence[tp.Level], strict: bool = True) -> tp.Index:
"""Rename levels in `index` by `mapper`.
Mapper can be a single or multiple levels to rename to, or a dictionary that maps
old level names to new level names."""
if isinstance(index, pd.MultiIndex):
nlevels = index.nlevels
if isinstance(mapper, (int, str)):
mapper = dict(zip(index.names, [mapper]))
elif checks.is_complex_sequence(mapper):
mapper = dict(zip(index.names, mapper))
else:
nlevels = 1
if isinstance(mapper, (int, str)):
mapper = dict(zip([index.name], [mapper]))
elif checks.is_complex_sequence(mapper):
mapper = dict(zip([index.name], mapper))
for k, v in mapper.items():
if k in index.names:
if isinstance(index, pd.MultiIndex):
index = index.rename(v, level=k)
else:
index = index.rename(v)
elif checks.is_int(k):
if k < 0:
new_k = nlevels + k
if new_k < 0:
raise KeyError(f"Level at position {k} not found")
k = new_k
if 0 <= k < nlevels:
if isinstance(index, pd.MultiIndex):
index = index.rename(v, level=k)
else:
index = index.rename(v)
else:
raise KeyError(f"Level at position {k} not found")
elif strict:
raise KeyError(f"Level '{k}' not found")
return index
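# Illustrative sketch: renaming levels by an {old: new} mapping.
# >>> mi = pd.MultiIndex.from_arrays([[1, 2], ["x", "y"]], names=["a", "b"])
# >>> rename_levels(mi, {"a": "A"}).names
# FrozenList(['A', 'b'])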
def select_levels(
index: tp.Index,
levels: tp.Union[ExceptLevel, tp.MaybeLevelSequence],
strict: bool = True,
) -> tp.Index:
"""Build a new index by selecting one or multiple `levels` from `index`.
Provide `levels` as an instance of `ExceptLevel` to select everything apart from the specified levels."""
was_multiindex = True
if not isinstance(index, pd.MultiIndex):
was_multiindex = False
index = pd.MultiIndex.from_arrays([index])
if isinstance(levels, ExceptLevel):
levels = levels.value
except_mode = True
else:
except_mode = False
levels_to_select = list()
if isinstance(levels, str) or not checks.is_sequence(levels):
levels = [levels]
single_mode = True
else:
single_mode = False
for level in levels:
if level in index.names:
for level_pos in [i for i, x in enumerate(index.names) if x == level]:
if level_pos not in levels_to_select:
levels_to_select.append(level_pos)
elif checks.is_int(level):
if level < 0:
new_level = index.nlevels + level
if new_level < 0:
raise KeyError(f"Level at position {level} not found")
level = new_level
if 0 <= level < index.nlevels:
if level not in levels_to_select:
levels_to_select.append(level)
else:
raise KeyError(f"Level at position {level} not found")
elif strict:
raise KeyError(f"Level '{level}' not found")
if except_mode:
levels_to_select = list(set(range(index.nlevels)).difference(levels_to_select))
if len(levels_to_select) == 0:
if strict:
raise ValueError("No levels to select")
if not was_multiindex:
return index.get_level_values(0)
return index
if len(levels_to_select) == 1 and single_mode:
return index.get_level_values(levels_to_select[0])
levels = [index.get_level_values(level) for level in levels_to_select]
return pd.MultiIndex.from_arrays(levels)
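# Illustrative sketch: a single level comes back as a flat index, multiple levels come back
# as a MultiIndex in the requested order.
# >>> mi = pd.MultiIndex.from_arrays([[1, 2], ["x", "y"]], names=["a", "b"])
# >>> select_levels(mi, "b")
# Index(['x', 'y'], dtype='object', name='b')
# >>> select_levels(mi, ["b", "a"])
# MultiIndex([('x', 1), ('y', 2)], names=['b', 'a'])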
def drop_redundant_levels(index: tp.Index) -> tp.Index:
"""Drop levels in `index` that either have a single unnamed value or a range from 0 to n."""
if not isinstance(index, pd.MultiIndex):
return index
levels_to_drop = []
for i in range(index.nlevels):
if len(index.levels[i]) == 1 and index.levels[i].name is None:
levels_to_drop.append(i)
elif checks.is_default_index(index.get_level_values(i)):
levels_to_drop.append(i)
if len(levels_to_drop) < index.nlevels:
return index.droplevel(levels_to_drop)
return index
def drop_duplicate_levels(index: tp.Index, keep: tp.Optional[str] = None) -> tp.Index:
"""Drop levels in `index` with the same name and values.
Set `keep` to 'last' to keep last levels, otherwise 'first'.
Set `keep` to None to use the default."""
from vectorbtpro._settings import settings
broadcasting_cfg = settings["broadcasting"]
if keep is None:
keep = broadcasting_cfg["keep"]
if not isinstance(index, pd.MultiIndex):
return index
checks.assert_in(keep.lower(), ["first", "last"])
levels_to_drop = set()
level_values = [index.get_level_values(i) for i in range(index.nlevels)]
for i in range(index.nlevels):
level1 = level_values[i]
for j in range(i + 1, index.nlevels):
level2 = level_values[j]
if level1.name is None or level2.name is None or level1.name == level2.name:
if checks.is_index_equal(level1, level2, check_names=False):
if level1.name is None and level2.name is not None:
levels_to_drop.add(i)
elif level1.name is not None and level2.name is None:
levels_to_drop.add(j)
else:
if keep.lower() == "first":
levels_to_drop.add(j)
else:
levels_to_drop.add(i)
return index.droplevel(list(levels_to_drop))
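# Illustrative sketch: two levels with the same name and values collapse into one.
# >>> mi = pd.MultiIndex.from_arrays([[1, 2], [1, 2], ["x", "y"]], names=["a", "a", "b"])
# >>> drop_duplicate_levels(mi)
# MultiIndex([(1, 'x'), (2, 'y')], names=['a', 'b'])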
@register_jitted(cache=True)
def align_arr_indices_nb(a: tp.Array1d, b: tp.Array1d) -> tp.Array1d:
"""Return indices required to align `a` to `b`."""
idxs = np.empty(b.shape[0], dtype=int_)
g = 0
for i in range(b.shape[0]):
for j in range(a.shape[0]):
if b[i] == a[j]:
idxs[g] = j
g += 1
break
return idxs
def align_index_to(index1: tp.Index, index2: tp.Index, jitted: tp.JittedOption = None) -> tp.IndexSlice:
"""Align `index1` to have the same shape as `index2` if they have any levels in common.
Returns an index slice for the alignment."""
if not isinstance(index1, pd.MultiIndex):
index1 = pd.MultiIndex.from_arrays([index1])
if not isinstance(index2, pd.MultiIndex):
index2 = pd.MultiIndex.from_arrays([index2])
if checks.is_index_equal(index1, index2):
return pd.IndexSlice[:]
if len(index1) > len(index2):
raise ValueError("Longer index cannot be aligned to shorter index")
mapper = {}
for i in range(index1.nlevels):
name1 = index1.names[i]
for j in range(index2.nlevels):
name2 = index2.names[j]
if name1 is None or name2 is None or name1 == name2:
if set(index2.levels[j]).issubset(set(index1.levels[i])):
if i in mapper:
raise ValueError(f"There are multiple candidate levels with name {name1} in second index")
mapper[i] = j
continue
if name1 == name2 and name1 is not None:
raise ValueError(f"Level {name1} in second index contains values not in first index")
if len(mapper) == 0:
if len(index1) == len(index2):
return pd.IndexSlice[:]
raise ValueError("Cannot find common levels to align indexes")
factorized = []
for k, v in mapper.items():
factorized.append(
pd.factorize(
pd.concat(
(
index1.get_level_values(k).to_series(),
index2.get_level_values(v).to_series(),
)
)
)[0],
)
stacked = np.transpose(np.stack(factorized))
indices1 = stacked[: len(index1)]
indices2 = stacked[len(index1) :]
if len(indices1) < len(indices2):
if len(np.unique(indices1, axis=0)) != len(indices1):
raise ValueError("Cannot align indexes")
if len(index2) % len(index1) == 0:
tile_times = len(index2) // len(index1)
index1_tiled = np.tile(indices1, (tile_times, 1))
if np.array_equal(index1_tiled, indices2):
return pd.IndexSlice[np.tile(np.arange(len(index1)), tile_times)]
unique_indices = np.unique(stacked, axis=0, return_inverse=True)[1]
unique1 = unique_indices[: len(index1)]
unique2 = unique_indices[len(index1) :]
if len(indices1) == len(indices2):
if np.array_equal(unique1, unique2):
return pd.IndexSlice[:]
func = jit_reg.resolve_option(align_arr_indices_nb, jitted)
return pd.IndexSlice[func(unique1, unique2)]
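# Illustrative sketch: the returned indexer repeats the shorter index so that its shared
# level matches the longer index element-wise (shown positions are indicative).
# >>> i1 = pd.Index([1, 2], name="a")
# >>> i2 = pd.MultiIndex.from_arrays([[1, 1, 2, 2], ["x", "y", "x", "y"]], names=["a", "b"])
# >>> align_index_to(i1, i2)
# -> an indexer selecting positions [0, 0, 1, 1] of `i1`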
def align_indexes(
*indexes: tp.MaybeTuple[tp.Index],
return_new_index: bool = False,
**kwargs,
) -> tp.Union[tp.Tuple[tp.IndexSlice, ...], tp.Tuple[tp.Tuple[tp.IndexSlice, ...], tp.Index]]:
"""Align multiple indexes to each other with `align_index_to`."""
if len(indexes) == 1:
indexes = indexes[0]
indexes = list(indexes)
index_items = sorted([(i, indexes[i]) for i in range(len(indexes))], key=lambda x: len(x[1]))
index_slices = []
for i in range(len(index_items)):
index_slice = align_index_to(index_items[i][1], index_items[-1][1], **kwargs)
index_slices.append((index_items[i][0], index_slice))
index_slices = list(map(lambda x: x[1], sorted(index_slices, key=lambda x: x[0])))
if return_new_index:
new_index = stack_indexes(
*[indexes[i][index_slices[i]] for i in range(len(indexes))],
drop_duplicates=True,
)
return tuple(index_slices), new_index
return tuple(index_slices)
@register_jitted(cache=True)
def block_index_product_nb(
block_group_map1: tp.GroupMap,
block_group_map2: tp.GroupMap,
factorized1: tp.Array1d,
factorized2: tp.Array1d,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Return indices required for building a block-wise Cartesian product of two factorized indexes."""
group_idxs1, group_lens1 = block_group_map1
group_idxs2, group_lens2 = block_group_map2
group_start_idxs1 = np.cumsum(group_lens1) - group_lens1
group_start_idxs2 = np.cumsum(group_lens2) - group_lens2
matched1 = np.empty(len(factorized1), dtype=np.bool_)
matched2 = np.empty(len(factorized2), dtype=np.bool_)
indices1 = np.empty(len(factorized1) * len(factorized2), dtype=int_)
indices2 = np.empty(len(factorized1) * len(factorized2), dtype=int_)
k1 = 0
k2 = 0
for g1 in range(len(group_lens1)):
group_len1 = group_lens1[g1]
group_start1 = group_start_idxs1[g1]
for g2 in range(len(group_lens2)):
group_len2 = group_lens2[g2]
group_start2 = group_start_idxs2[g2]
for c1 in range(group_len1):
i = group_idxs1[group_start1 + c1]
for c2 in range(group_len2):
j = group_idxs2[group_start2 + c2]
if factorized1[i] == factorized2[j]:
matched1[i] = True
matched2[j] = True
indices1[k1] = i
indices2[k2] = j
k1 += 1
k2 += 1
if not np.all(matched1) or not np.all(matched2):
raise ValueError("Cannot match some block level values")
return indices1[:k1], indices2[:k2]
def cross_index_with(
index1: tp.Index,
index2: tp.Index,
return_new_index: bool = False,
) -> tp.Union[tp.Tuple[tp.IndexSlice, tp.IndexSlice], tp.Tuple[tp.Tuple[tp.IndexSlice, tp.IndexSlice], tp.Index]]:
"""Build a Cartesian product of one index with another while taking into account levels they have in common.
Returns index slices for the alignment."""
from vectorbtpro.base.grouping.nb import get_group_map_nb
index1_default = checks.is_default_index(index1, check_names=True)
index2_default = checks.is_default_index(index2, check_names=True)
if not isinstance(index1, pd.MultiIndex):
index1 = pd.MultiIndex.from_arrays([index1])
if not isinstance(index2, pd.MultiIndex):
index2 = pd.MultiIndex.from_arrays([index2])
if not index1_default and not index2_default and checks.is_index_equal(index1, index2):
if return_new_index:
new_index = stack_indexes(index1, index2, drop_duplicates=True)
return (pd.IndexSlice[:], pd.IndexSlice[:]), new_index
return pd.IndexSlice[:], pd.IndexSlice[:]
levels1 = []
levels2 = []
for i in range(index1.nlevels):
if checks.is_default_index(index1.get_level_values(i), check_names=True):
continue
for j in range(index2.nlevels):
if checks.is_default_index(index2.get_level_values(j), check_names=True):
continue
name1 = index1.names[i]
name2 = index2.names[j]
if name1 == name2:
if set(index2.levels[j]) == set(index1.levels[i]):
if i in levels1 or j in levels2:
raise ValueError(f"There are multiple candidate block levels with name {name1}")
levels1.append(i)
levels2.append(j)
continue
if name1 is not None:
raise ValueError(f"Candidate block level {name1} in both indexes has different values")
if len(levels1) == 0:
# Regular index product
indices1 = np.repeat(np.arange(len(index1)), len(index2))
indices2 = np.tile(np.arange(len(index2)), len(index1))
else:
# Block index product
index_levels1 = select_levels(index1, levels1)
index_levels2 = select_levels(index2, levels2)
block_levels1 = list(set(range(index1.nlevels)).difference(levels1))
block_levels2 = list(set(range(index2.nlevels)).difference(levels2))
if len(block_levels1) > 0:
index_block_levels1 = select_levels(index1, block_levels1)
else:
index_block_levels1 = pd.Index(np.full(len(index1), 0))
if len(block_levels2) > 0:
index_block_levels2 = select_levels(index2, block_levels2)
else:
index_block_levels2 = pd.Index(np.full(len(index2), 0))
factorized = pd.factorize(pd.concat((index_levels1.to_series(), index_levels2.to_series())))[0]
factorized1 = factorized[: len(index_levels1)]
factorized2 = factorized[len(index_levels1) :]
block_factorized1, block_unique1 = pd.factorize(index_block_levels1)
block_factorized2, block_unique2 = pd.factorize(index_block_levels2)
block_group_map1 = get_group_map_nb(block_factorized1, len(block_unique1))
block_group_map2 = get_group_map_nb(block_factorized2, len(block_unique2))
indices1, indices2 = block_index_product_nb(
block_group_map1,
block_group_map2,
factorized1,
factorized2,
)
if return_new_index:
new_index = stack_indexes(index1[indices1], index2[indices2], drop_duplicates=True)
return (pd.IndexSlice[indices1], pd.IndexSlice[indices2]), new_index
return pd.IndexSlice[indices1], pd.IndexSlice[indices2]
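# Illustrative sketch: with no levels in common, the result is a plain Cartesian product.
# >>> i1 = pd.Index([1, 2], name="a")
# >>> i2 = pd.Index(["x", "y", "z"], name="b")
# >>> cross_index_with(i1, i2)
# -> indexers selecting positions [0, 0, 0, 1, 1, 1] of `i1` and [0, 1, 2, 0, 1, 2] of `i2`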
def cross_indexes(
*indexes: tp.MaybeTuple[tp.Index],
return_new_index: bool = False,
) -> tp.Union[tp.Tuple[tp.IndexSlice, ...], tp.Tuple[tp.Tuple[tp.IndexSlice, ...], tp.Index]]:
"""Cross multiple indexes with `cross_index_with`."""
if len(indexes) == 1:
indexes = indexes[0]
indexes = list(indexes)
if len(indexes) == 2:
return cross_index_with(indexes[0], indexes[1], return_new_index=return_new_index)
index = None
index_slices = []
for i in range(len(indexes) - 2, -1, -1):
index1 = indexes[i]
if i == len(indexes) - 2:
index2 = indexes[i + 1]
else:
index2 = index
(index_slice1, index_slice2), index = cross_index_with(index1, index2, return_new_index=True)
if i == len(indexes) - 2:
index_slices.append(index_slice2)
else:
for j in range(len(index_slices)):
if isinstance(index_slices[j], slice):
index_slices[j] = np.arange(len(index2))[index_slices[j]]
index_slices[j] = index_slices[j][index_slice2]
index_slices.append(index_slice1)
if return_new_index:
return tuple(index_slices[::-1]), index
return tuple(index_slices[::-1])
OptionalLevelSequence = tp.Optional[tp.Sequence[tp.Union[None, tp.Level]]]
def pick_levels(
index: tp.Index,
required_levels: OptionalLevelSequence = None,
optional_levels: OptionalLevelSequence = None,
) -> tp.Tuple[tp.List[int], tp.List[int]]:
"""Pick optional and required levels and return their indices.
Raises an exception if the index has fewer or more levels than expected."""
if required_levels is None:
required_levels = []
if optional_levels is None:
optional_levels = []
checks.assert_instance_of(index, pd.MultiIndex)
n_opt_set = len(list(filter(lambda x: x is not None, optional_levels)))
n_req_set = len(list(filter(lambda x: x is not None, required_levels)))
n_levels_left = index.nlevels - n_opt_set
if n_req_set < len(required_levels):
if n_levels_left != len(required_levels):
n_expected = len(required_levels) + n_opt_set
raise ValueError(f"Expected {n_expected} levels, found {index.nlevels}")
levels_left = list(range(index.nlevels))
_optional_levels = []
for level in optional_levels:
level_pos = None
if level is not None:
checks.assert_instance_of(level, (int, str))
if isinstance(level, str):
level_pos = index.names.index(level)
else:
level_pos = level
if level_pos < 0:
level_pos = index.nlevels + level_pos
levels_left.remove(level_pos)
_optional_levels.append(level_pos)
_required_levels = []
for level in required_levels:
level_pos = None
if level is not None:
checks.assert_instance_of(level, (int, str))
if isinstance(level, str):
level_pos = index.names.index(level)
else:
level_pos = level
if level_pos < 0:
level_pos = index.nlevels + level_pos
levels_left.remove(level_pos)
_required_levels.append(level_pos)
for i, level in enumerate(_required_levels):
if level is None:
_required_levels[i] = levels_left.pop(0)
return _required_levels, _optional_levels
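# Illustrative sketch: unset required levels are filled with whatever levels remain.
# >>> mi = pd.MultiIndex.from_arrays([[1], ["x"], [True]], names=["a", "b", "c"])
# >>> pick_levels(mi, required_levels=[None, None], optional_levels=["c"])
# ([0, 1], [2])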
def find_first_occurrence(index_value: tp.Any, index: tp.Index) -> int:
"""Return index of the first occurrence in `index`."""
loc = index.get_loc(index_value)
if isinstance(loc, slice):
return loc.start
elif isinstance(loc, list):
return loc[0]
elif isinstance(loc, np.ndarray):
return np.flatnonzero(loc)[0]
return loc
IndexApplierT = tp.TypeVar("IndexApplierT", bound="IndexApplier")
class IndexApplier(Base):
"""Abstract class that can apply a function on an index."""
def apply_to_index(self: IndexApplierT, apply_func: tp.Callable, *args, **kwargs) -> IndexApplierT:
"""Apply function `apply_func` on the index of the instance and return a new instance."""
raise NotImplementedError
def add_levels(
self: IndexApplierT,
*indexes: tp.Index,
on_top: bool = True,
drop_duplicates: tp.Optional[bool] = None,
keep: tp.Optional[str] = None,
drop_redundant: tp.Optional[bool] = None,
**kwargs,
) -> IndexApplierT:
"""Append or prepend levels using `stack_indexes`.
Set `on_top` to False to stack at bottom.
See `IndexApplier.apply_to_index` for other keyword arguments."""
def _apply_func(index):
if on_top:
return stack_indexes(
[*indexes, index],
drop_duplicates=drop_duplicates,
keep=keep,
drop_redundant=drop_redundant,
)
return stack_indexes(
[index, *indexes],
drop_duplicates=drop_duplicates,
keep=keep,
drop_redundant=drop_redundant,
)
return self.apply_to_index(_apply_func, **kwargs)
def drop_levels(
self: IndexApplierT,
levels: tp.Union[ExceptLevel, tp.MaybeLevelSequence],
strict: bool = True,
**kwargs,
) -> IndexApplierT:
"""Drop levels using `drop_levels`.
See `IndexApplier.apply_to_index` for other keyword arguments."""
def _apply_func(index):
return drop_levels(index, levels, strict=strict)
return self.apply_to_index(_apply_func, **kwargs)
def rename_levels(
self: IndexApplierT,
mapper: tp.MaybeMappingSequence[tp.Level],
strict: bool = True,
**kwargs,
) -> IndexApplierT:
"""Rename levels using `rename_levels`.
See `IndexApplier.apply_to_index` for other keyword arguments."""
def _apply_func(index):
return rename_levels(index, mapper, strict=strict)
return self.apply_to_index(_apply_func, **kwargs)
def select_levels(
self: IndexApplierT,
level_names: tp.Union[ExceptLevel, tp.MaybeLevelSequence],
strict: bool = True,
**kwargs,
) -> IndexApplierT:
"""Select levels using `select_levels`.
See `IndexApplier.apply_to_index` for other keyword arguments."""
def _apply_func(index):
return select_levels(index, level_names, strict=strict)
return self.apply_to_index(_apply_func, **kwargs)
def drop_redundant_levels(self: IndexApplierT, **kwargs) -> IndexApplierT:
"""Drop any redundant levels using `drop_redundant_levels`.
See `IndexApplier.apply_to_index` for other keyword arguments."""
def _apply_func(index):
return drop_redundant_levels(index)
return self.apply_to_index(_apply_func, **kwargs)
def drop_duplicate_levels(self: IndexApplierT, keep: tp.Optional[str] = None, **kwargs) -> IndexApplierT:
"""Drop any duplicate levels using `drop_duplicate_levels`.
See `IndexApplier.apply_to_index` for other keyword arguments."""
def _apply_func(index):
return drop_duplicate_levels(index, keep=keep)
return self.apply_to_index(_apply_func, **kwargs)
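# Minimal sketch of a concrete subclass (hypothetical, for illustration only): the single
# required piece is `apply_to_index`, which rebuilds the instance with a transformed index;
# helpers such as `add_levels`, `drop_levels`, and `rename_levels` then work out of the box.
# >>> class SeriesApplier(IndexApplier):
# ...     def __init__(self, sr):
# ...         self.sr = sr
# ...     def apply_to_index(self, apply_func, **kwargs):
# ...         return SeriesApplier(self.sr.set_axis(apply_func(self.sr.index)))
# >>> obj = SeriesApplier(pd.Series([1, 2], index=pd.Index(["x", "y"], name="col")))
# >>> obj.add_levels(pd.Index(["a", "b"], name="group")).sr.index.names
# FrozenList(['group', 'col'])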
</file>
<file path="base/indexing.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Classes and functions for indexing."""
import functools
from datetime import time
from functools import partial
import numpy as np
import pandas as pd
from pandas.tseries.offsets import BaseOffset
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils import checks, datetime_ as dt, datetime_nb as dt_nb
from vectorbtpro.utils.attr_ import DefineMixin, define, MISSING
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.config import hdict, merge_dicts
from vectorbtpro.utils.mapping import to_field_mapping
from vectorbtpro.utils.pickling import pdict
from vectorbtpro.utils.selection import PosSel, LabelSel
from vectorbtpro.utils.template import CustomTemplate
__all__ = [
"PandasIndexer",
"ExtPandasIndexer",
"hslice",
"get_index_points",
"get_index_ranges",
"get_idxs",
"index_dict",
"IdxSetter",
"IdxSetterFactory",
"IdxDict",
"IdxSeries",
"IdxFrame",
"IdxRecords",
"posidx",
"maskidx",
"lbidx",
"dtidx",
"dtcidx",
"pointidx",
"rangeidx",
"autoidx",
"rowidx",
"colidx",
"idx",
]
__pdoc__ = {}
class IndexingError(Exception):
"""Exception raised when an indexing error has occurred."""
IndexingBaseT = tp.TypeVar("IndexingBaseT", bound="IndexingBase")
class IndexingBase(Base):
"""Class that supports indexing through `IndexingBase.indexing_func`."""
def indexing_func(self: IndexingBaseT, pd_indexing_func: tp.Callable, **kwargs) -> IndexingBaseT:
"""Apply `pd_indexing_func` on all pandas objects in question and return a new instance of the class.
Should be overridden."""
raise NotImplementedError
def indexing_setter_func(self, pd_indexing_setter_func: tp.Callable, **kwargs) -> None:
"""Apply `pd_indexing_setter_func` on all pandas objects in question.
Should be overridden."""
raise NotImplementedError
class LocBase(Base):
"""Class that implements location-based indexing."""
def __init__(
self,
indexing_func: tp.Callable,
indexing_setter_func: tp.Optional[tp.Callable] = None,
**kwargs,
) -> None:
self._indexing_func = indexing_func
self._indexing_setter_func = indexing_setter_func
self._indexing_kwargs = kwargs
@property
def indexing_func(self) -> tp.Callable:
"""Indexing function."""
return self._indexing_func
@property
def indexing_setter_func(self) -> tp.Optional[tp.Callable]:
"""Indexing setter function."""
return self._indexing_setter_func
@property
def indexing_kwargs(self) -> dict:
"""Keyword arguments passed to `LocBase.indexing_func`."""
return self._indexing_kwargs
def __getitem__(self, key: tp.Any) -> tp.Any:
raise NotImplementedError
def __setitem__(self, key: tp.Any, value: tp.Any) -> None:
raise NotImplementedError
def __iter__(self):
raise TypeError(f"'{type(self).__name__}' object is not iterable")
class pdLoc(LocBase):
"""Forwards a Pandas-like indexing operation to each Series/DataFrame and returns a new class instance."""
@classmethod
def pd_indexing_func(cls, obj: tp.SeriesFrame, key: tp.Any) -> tp.MaybeSeriesFrame:
"""Pandas-like indexing operation."""
raise NotImplementedError
@classmethod
def pd_indexing_setter_func(cls, obj: tp.SeriesFrame, key: tp.Any, value: tp.Any) -> None:
"""Pandas-like indexing setter operation."""
raise NotImplementedError
def __getitem__(self, key: tp.Any) -> tp.Any:
return self.indexing_func(partial(self.pd_indexing_func, key=key), **self.indexing_kwargs)
def __setitem__(self, key: tp.Any, value: tp.Any) -> None:
self.indexing_setter_func(partial(self.pd_indexing_setter_func, key=key, value=value), **self.indexing_kwargs)
class iLoc(pdLoc):
"""Forwards `pd.Series.iloc`/`pd.DataFrame.iloc` operation to each
Series/DataFrame and returns a new class instance."""
@classmethod
def pd_indexing_func(cls, obj: tp.SeriesFrame, key: tp.Any) -> tp.MaybeSeriesFrame:
return obj.iloc.__getitem__(key)
@classmethod
def pd_indexing_setter_func(cls, obj: tp.SeriesFrame, key: tp.Any, value: tp.Any) -> None:
obj.iloc.__setitem__(key, value)
class Loc(pdLoc):
"""Forwards `pd.Series.loc`/`pd.DataFrame.loc` operation to each
Series/DataFrame and returns a new class instance."""
@classmethod
def pd_indexing_func(cls, obj: tp.SeriesFrame, key: tp.Any) -> tp.MaybeSeriesFrame:
return obj.loc.__getitem__(key)
@classmethod
def pd_indexing_setter_func(cls, obj: tp.SeriesFrame, key: tp.Any, value: tp.Any) -> None:
obj.loc.__setitem__(key, value)
PandasIndexerT = tp.TypeVar("PandasIndexerT", bound="PandasIndexer")
class PandasIndexer(IndexingBase):
"""Implements indexing using `iloc`, `loc`, `xs` and `__getitem__`.
Usage:
```pycon
>>> from vectorbtpro import *
>>> from vectorbtpro.base.indexing import PandasIndexer
>>> class C(PandasIndexer):
... def __init__(self, df1, df2):
... self.df1 = df1
... self.df2 = df2
... super().__init__()
...
... def indexing_func(self, pd_indexing_func):
... return type(self)(
... pd_indexing_func(self.df1),
... pd_indexing_func(self.df2)
... )
>>> df1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
>>> df2 = pd.DataFrame({'a': [5, 6], 'b': [7, 8]})
>>> c = C(df1, df2)
>>> c.iloc[:, 0]
<__main__.C object at 0x1a1cacbbe0>
>>> c.iloc[:, 0].df1
0 1
1 2
Name: a, dtype: int64
>>> c.iloc[:, 0].df2
0 5
1 6
Name: a, dtype: int64
```
"""
def __init__(self, **kwargs) -> None:
self._iloc = iLoc(self.indexing_func, indexing_setter_func=self.indexing_setter_func, **kwargs)
self._loc = Loc(self.indexing_func, indexing_setter_func=self.indexing_setter_func, **kwargs)
self._indexing_kwargs = kwargs
@property
def indexing_kwargs(self) -> dict:
"""Indexing keyword arguments."""
return self._indexing_kwargs
@property
def iloc(self) -> iLoc:
"""Purely integer-location based indexing for selection by position."""
return self._iloc
iloc.__doc__ = iLoc.__doc__
@property
def loc(self) -> Loc:
"""Purely label-location based indexer for selection by label."""
return self._loc
loc.__doc__ = Loc.__doc__
def xs(self: PandasIndexerT, *args, **kwargs) -> PandasIndexerT:
"""Forwards `pd.Series.xs`/`pd.DataFrame.xs`
operation to each Series/DataFrame and returns a new class instance."""
return self.indexing_func(lambda x: x.xs(*args, **kwargs), **self.indexing_kwargs)
def __getitem__(self: PandasIndexerT, key: tp.Any) -> PandasIndexerT:
def __getitem__func(x, _key=key):
return x.__getitem__(_key)
return self.indexing_func(__getitem__func, **self.indexing_kwargs)
def __setitem__(self, key: tp.Any, value: tp.Any) -> None:
def __setitem__func(x, _key=key, _value=value):
return x.__setitem__(_key, _value)
self.indexing_setter_func(__setitem__func, **self.indexing_kwargs)
def __iter__(self):
raise TypeError(f"'{type(self).__name__}' object is not iterable")
class xLoc(iLoc):
"""Subclass of `iLoc` that transforms an `Idxr`-based operation with
`get_idxs` to an `iLoc` operation."""
@classmethod
def pd_indexing_func(cls, obj: tp.SeriesFrame, key: tp.Any) -> tp.MaybeSeriesFrame:
from vectorbtpro.base.indexes import get_index
if isinstance(key, tuple):
key = Idxr(*key)
index = get_index(obj, 0)
columns = get_index(obj, 1)
freq = dt.infer_index_freq(index)
row_idxs, col_idxs = get_idxs(key, index=index, columns=columns, freq=freq)
if isinstance(row_idxs, np.ndarray) and row_idxs.ndim == 2:
row_idxs = normalize_idxs(row_idxs, target_len=len(index))
if isinstance(col_idxs, np.ndarray) and col_idxs.ndim == 2:
col_idxs = normalize_idxs(col_idxs, target_len=len(columns))
if isinstance(obj, pd.Series):
if not isinstance(col_idxs, (slice, hslice)) or (
col_idxs.start is not None or col_idxs.stop is not None or col_idxs.step is not None
):
raise IndexingError("Too many indexers")
return obj.iloc.__getitem__(row_idxs)
return obj.iloc.__getitem__((row_idxs, col_idxs))
@classmethod
def pd_indexing_setter_func(cls, obj: tp.SeriesFrame, key: tp.Any, value: tp.Any) -> None:
IdxSetter([(key, value)]).set_pd(obj)
class ExtPandasIndexer(PandasIndexer):
"""Extension of `PandasIndexer` that also implements indexing using `xLoc`."""
def __init__(self, **kwargs) -> None:
self._xloc = xLoc(self.indexing_func, indexing_setter_func=self.indexing_setter_func, **kwargs)
PandasIndexer.__init__(self, **kwargs)
@property
def xloc(self) -> xLoc:
"""`Idxr`-based indexing."""
return self._xloc
xloc.__doc__ = xLoc.__doc__
class ParamLoc(LocBase):
"""Access a group of columns by parameter using `pd.Series.loc`.
    Uses `mapper` to establish a link between columns and parameter values."""
@classmethod
def encode_key(cls, key: tp.Any):
"""Encode key."""
if isinstance(key, tuple):
return str(tuple(map(lambda k: k.item() if isinstance(k, np.generic) else k, key)))
key_str = str(key)
return str(key.item()) if isinstance(key, np.generic) else key_str
def __init__(
self,
mapper: tp.Series,
indexing_func: tp.Callable,
indexing_setter_func: tp.Optional[tp.Callable] = None,
level_name: tp.Level = None,
**kwargs,
) -> None:
checks.assert_instance_of(mapper, pd.Series)
if mapper.dtype == "O":
if isinstance(mapper.iloc[0], tuple):
mapper = mapper.apply(self.encode_key)
else:
mapper = mapper.astype(str)
self._mapper = mapper
self._level_name = level_name
LocBase.__init__(self, indexing_func, indexing_setter_func=indexing_setter_func, **kwargs)
@property
def mapper(self) -> tp.Series:
"""Mapper."""
return self._mapper
@property
def level_name(self) -> tp.Level:
"""Level name."""
return self._level_name
def get_idxs(self, key: tp.Any) -> tp.Array1d:
"""Get array of indices affected by this key."""
if self.mapper.dtype == "O":
if isinstance(key, (slice, hslice)):
start = self.encode_key(key.start) if key.start is not None else None
stop = self.encode_key(key.stop) if key.stop is not None else None
key = slice(start, stop, key.step)
elif isinstance(key, (list, np.ndarray)):
key = list(map(self.encode_key, key))
else:
key = self.encode_key(key)
mapper = pd.Series(np.arange(len(self.mapper.index)), index=self.mapper.values)
idxs = mapper.loc.__getitem__(key)
if isinstance(idxs, pd.Series):
idxs = idxs.values
return idxs
def __getitem__(self, key: tp.Any) -> tp.Any:
idxs = self.get_idxs(key)
is_multiple = isinstance(key, (slice, hslice, list, np.ndarray))
def pd_indexing_func(obj: tp.SeriesFrame) -> tp.MaybeSeriesFrame:
from vectorbtpro.base.indexes import drop_levels
new_obj = obj.iloc[:, idxs]
if not is_multiple:
if self.level_name is not None:
if checks.is_frame(new_obj):
if isinstance(new_obj.columns, pd.MultiIndex):
new_obj.columns = drop_levels(new_obj.columns, self.level_name)
return new_obj
return self.indexing_func(pd_indexing_func, **self.indexing_kwargs)
def __setitem__(self, key: tp.Any, value: tp.Any) -> None:
idxs = self.get_idxs(key)
def pd_indexing_setter_func(obj: tp.SeriesFrame) -> None:
obj.iloc[:, idxs] = value
return self.indexing_setter_func(pd_indexing_setter_func, **self.indexing_kwargs)
def indexing_on_mapper(
mapper: tp.Series,
ref_obj: tp.SeriesFrame,
pd_indexing_func: tp.Callable,
) -> tp.Optional[tp.Series]:
"""Broadcast `mapper` Series to `ref_obj` and perform pandas indexing using `pd_indexing_func`."""
from vectorbtpro.base.reshaping import broadcast_to
checks.assert_instance_of(mapper, pd.Series)
checks.assert_instance_of(ref_obj, (pd.Series, pd.DataFrame))
if isinstance(ref_obj, pd.Series):
range_mapper = broadcast_to(0, ref_obj)
else:
range_mapper = broadcast_to(np.arange(len(mapper.index))[None], ref_obj)
loced_range_mapper = pd_indexing_func(range_mapper)
new_mapper = mapper.iloc[loced_range_mapper.values[0]]
if checks.is_frame(loced_range_mapper):
return pd.Series(new_mapper.values, index=loced_range_mapper.columns, name=mapper.name)
elif checks.is_series(loced_range_mapper):
return pd.Series([new_mapper], index=[loced_range_mapper.name], name=mapper.name)
return None
def build_param_indexer(
param_names: tp.Sequence[str],
class_name: str = "ParamIndexer",
module_name: tp.Optional[str] = None,
) -> tp.Type[IndexingBase]:
"""A factory to create a class with parameter indexing.
Parameter indexer enables accessing a group of rows and columns by a parameter array (similar to `loc`).
This way, one can query index/columns by another Series called a parameter mapper, which is just a
`pd.Series` that maps columns (its index) to params (its values).
    Parameter indexing is important because querying by column/index labels alone is not always the best option.
    For example, `pandas` doesn't let you query by a list of values at a specific index/column level.
Args:
param_names (list of str): Names of the parameters.
class_name (str): Name of the generated class.
module_name (str): Name of the module to which the class should be bound.
Usage:
```pycon
>>> from vectorbtpro import *
>>> from vectorbtpro.base.indexing import build_param_indexer, indexing_on_mapper
>>> MyParamIndexer = build_param_indexer(['my_param'])
>>> class C(MyParamIndexer):
... def __init__(self, df, param_mapper):
... self.df = df
... self._my_param_mapper = param_mapper
... super().__init__([param_mapper])
...
... def indexing_func(self, pd_indexing_func):
... return type(self)(
... pd_indexing_func(self.df),
... indexing_on_mapper(self._my_param_mapper, self.df, pd_indexing_func)
... )
>>> df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
>>> param_mapper = pd.Series(['First', 'Second'], index=['a', 'b'])
>>> c = C(df, param_mapper)
>>> c.my_param_loc['First'].df
0 1
1 2
Name: a, dtype: int64
>>> c.my_param_loc['Second'].df
0 3
1 4
Name: b, dtype: int64
>>> c.my_param_loc[['First', 'First', 'Second', 'Second']].df
a b
0 1 1 3 3
1 2 2 4 4
```
"""
class ParamIndexer(IndexingBase):
"""Class with parameter indexing."""
def __init__(
self,
param_mappers: tp.Sequence[tp.Series],
level_names: tp.Optional[tp.LevelSequence] = None,
**kwargs,
) -> None:
checks.assert_len_equal(param_names, param_mappers)
for i, param_name in enumerate(param_names):
level_name = level_names[i] if level_names is not None else None
_param_loc = ParamLoc(param_mappers[i], self.indexing_func, level_name=level_name, **kwargs)
setattr(self, f"_{param_name}_loc", _param_loc)
for i, param_name in enumerate(param_names):
def param_loc(self, _param_name=param_name) -> ParamLoc:
return getattr(self, f"_{_param_name}_loc")
param_loc.__doc__ = f"""Access a group of columns by parameter `{param_name}` using `pd.Series.loc`.
Forwards this operation to each Series/DataFrame and returns a new class instance.
"""
setattr(ParamIndexer, param_name + "_loc", property(param_loc))
ParamIndexer.__name__ = class_name
ParamIndexer.__qualname__ = ParamIndexer.__name__
if module_name is not None:
ParamIndexer.__module__ = module_name
return ParamIndexer
hsliceT = tp.TypeVar("hsliceT", bound="hslice")
@define
class hslice(DefineMixin):
"""Hashable slice."""
start: object = define.field()
"""Start."""
stop: object = define.field()
"""Stop."""
step: object = define.field()
"""Step."""
def __init__(self, start: object = MISSING, stop: object = MISSING, step: object = MISSING) -> None:
if start is not MISSING and stop is MISSING and step is MISSING:
stop = start
start, step = None, None
else:
if start is MISSING:
start = None
if stop is MISSING:
stop = None
if step is MISSING:
step = None
DefineMixin.__init__(self, start=start, stop=stop, step=step)
@classmethod
def from_slice(cls: tp.Type[hsliceT], slice_: slice) -> hsliceT:
"""Construct from a slice."""
return cls(slice_.start, slice_.stop, slice_.step)
def to_slice(self) -> slice:
"""Convert to a slice."""
return slice(self.start, self.stop, self.step)
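# Example sketch: unlike the built-in `slice`, `hslice` is hashable and can be used as a
# dict key or inside other hashable containers; outputs below are illustrative.
# >>> hslice(1, 10, 2).to_slice()
# slice(1, 10, 2)
# >>> hslice(10).to_slice()  # a single argument is treated as `stop`, like `slice(10)`
# slice(None, 10, None)
# >>> hash(hslice(1, 10, 2)) == hash(hslice(1, 10, 2))
# True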
class IdxrBase(Base):
"""Abstract class for resolving indices."""
def get(self, *args, **kwargs) -> tp.Any:
"""Get indices."""
raise NotImplementedError
@classmethod
def slice_indexer(
cls,
index: tp.Index,
slice_: tp.Slice,
closed_start: bool = True,
closed_end: bool = False,
) -> slice:
"""Compute the slice indexer for input labels and step."""
start = slice_.start
end = slice_.stop
if start is not None:
left_start = index.get_slice_bound(start, side="left")
right_start = index.get_slice_bound(start, side="right")
if left_start == right_start or not closed_start:
start = right_start
else:
start = left_start
if end is not None:
left_end = index.get_slice_bound(end, side="left")
right_end = index.get_slice_bound(end, side="right")
if left_end == right_end or closed_end:
end = right_end
else:
end = left_end
return slice(start, end, slice_.step)
def check_idxs(self, idxs: tp.MaybeIndexArray, check_minus_one: bool = False) -> None:
"""Check indices after resolving them."""
if isinstance(idxs, slice):
if idxs.start is not None and not checks.is_int(idxs.start):
raise TypeError("Start of a returned index slice must be an integer or None")
if idxs.stop is not None and not checks.is_int(idxs.stop):
raise TypeError("Stop of a returned index slice must be an integer or None")
if idxs.step is not None and not checks.is_int(idxs.step):
raise TypeError("Step of a returned index slice must be an integer or None")
if check_minus_one and idxs.start == -1:
raise ValueError("Range start index couldn't be matched")
elif check_minus_one and idxs.stop == -1:
raise ValueError("Range end index couldn't be matched")
elif checks.is_int(idxs):
if check_minus_one and idxs == -1:
raise ValueError("Index couldn't be matched")
elif checks.is_sequence(idxs) and not np.isscalar(idxs):
if len(idxs) == 0:
raise ValueError("No indices could be matched")
if not isinstance(idxs, np.ndarray):
raise ValueError(f"Indices must be a NumPy array, not {type(idxs)}")
if not np.issubdtype(idxs.dtype, np.integer) or np.issubdtype(idxs.dtype, np.bool_):
raise ValueError(f"Indices must be of integer data type, not {idxs.dtype}")
if check_minus_one and -1 in idxs:
raise ValueError("Some indices couldn't be matched")
if idxs.ndim not in (1, 2):
raise ValueError("Indices array must have either 1 or 2 dimensions")
if idxs.ndim == 2 and idxs.shape[1] != 2:
raise ValueError("Indices array provided as ranges must have exactly two columns")
else:
raise TypeError(
f"Indices must be an integer, a slice, a NumPy array, or a tuple of two NumPy arrays, not {type(idxs)}"
)
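# Example sketch of `IdxrBase.slice_indexer` showing how `closed_start` and `closed_end`
# affect label slicing (outputs illustrative):
# >>> index = pd.Index(["a", "b", "c", "d"])
# >>> IdxrBase.slice_indexer(index, slice("b", "c"))  # end exclusive by default
# slice(1, 2, None)
# >>> IdxrBase.slice_indexer(index, slice("b", "c"), closed_end=True)
# slice(1, 3, None)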
def normalize_idxs(idxs: tp.MaybeIndexArray, target_len: int) -> tp.Array1d:
"""Normalize indexes into a 1-dim integer array."""
if isinstance(idxs, hslice):
idxs = idxs.to_slice()
if isinstance(idxs, slice):
idxs = np.arange(target_len)[idxs]
if checks.is_int(idxs):
idxs = np.array([idxs])
if idxs.ndim == 2:
from vectorbtpro.base.merging import concat_arrays
idxs = concat_arrays(tuple(map(lambda x: np.arange(x[0], x[1]), idxs)))
if (idxs < 0).any():
idxs = np.where(idxs >= 0, idxs, target_len + idxs)
return idxs
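# Example sketch (outputs illustrative): slices, scalars, 2-d range arrays, and negative
# positions are all normalized into flat, non-negative positions against `target_len`.
# >>> normalize_idxs(slice(1, 4), target_len=5)
# array([1, 2, 3])
# >>> normalize_idxs(np.array([[0, 2], [3, 5]]), target_len=5)  # ranges -> flat positions
# array([0, 1, 3, 4])
# >>> normalize_idxs(np.array([-1]), target_len=5)
# array([4])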
class UniIdxr(IdxrBase):
"""Abstract class for resolving indices based on a single index."""
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
raise NotImplementedError
def __invert__(self):
def _op_func(x, index=None, freq=None):
if index is None:
raise ValueError("Index is required")
x = normalize_idxs(x, len(index))
idxs = np.setdiff1d(np.arange(len(index)), x)
self.check_idxs(idxs)
return idxs
return UniIdxrOp(_op_func, self)
def __and__(self, other):
def _op_func(x, y, index=None, freq=None):
if index is None:
raise ValueError("Index is required")
x = normalize_idxs(x, len(index))
y = normalize_idxs(y, len(index))
idxs = np.intersect1d(x, y)
self.check_idxs(idxs)
return idxs
return UniIdxrOp(_op_func, self, other)
def __or__(self, other):
def _op_func(x, y, index=None, freq=None):
if index is None:
raise ValueError("Index is required")
x = normalize_idxs(x, len(index))
y = normalize_idxs(y, len(index))
idxs = np.union1d(x, y)
self.check_idxs(idxs)
return idxs
return UniIdxrOp(_op_func, self, other)
def __sub__(self, other):
def _op_func(x, y, index=None, freq=None):
if index is None:
raise ValueError("Index is required")
x = normalize_idxs(x, len(index))
y = normalize_idxs(y, len(index))
idxs = np.setdiff1d(x, y)
self.check_idxs(idxs)
return idxs
return UniIdxrOp(_op_func, self, other)
def __xor__(self, other):
def _op_func(x, y, index=None, freq=None):
if index is None:
raise ValueError("Index is required")
x = normalize_idxs(x, len(index))
y = normalize_idxs(y, len(index))
idxs = np.setxor1d(x, y)
self.check_idxs(idxs)
return idxs
return UniIdxrOp(_op_func, self, other)
def __lshift__(self, other):
def _op_func(x, y, index=None, freq=None):
if not checks.is_int(y):
raise TypeError("Second operand in __lshift__ must be an integer")
if index is None:
raise ValueError("Index is required")
x = normalize_idxs(x, len(index))
shifted = x - y
idxs = shifted[shifted >= 0]
self.check_idxs(idxs)
return idxs
return UniIdxrOp(_op_func, self, other)
def __rshift__(self, other):
def _op_func(x, y, index=None, freq=None):
if not checks.is_int(y):
raise TypeError("Second operand in __rshift__ must be an integer")
if index is None:
raise ValueError("Index is required")
x = normalize_idxs(x, len(index))
shifted = x + y
idxs = shifted[shifted >= 0]
self.check_idxs(idxs)
return idxs
return UniIdxrOp(_op_func, self, other)
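# Example sketch: single-index indexers can be composed with set-like operators, producing
# a `UniIdxrOp` that resolves to positions (uses `PosIdxr`, defined below; outputs illustrative).
# >>> index = pd.RangeIndex(5)
# >>> (PosIdxr([0, 1, 2]) & PosIdxr([1, 2, 3])).get(index=index)
# array([1, 2])
# >>> (PosIdxr([0, 1]) | PosIdxr([3])).get(index=index)
# array([0, 1, 3])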
@define
class UniIdxrOp(UniIdxr, DefineMixin):
"""Class for applying an operation to one or more indexers.
Produces a single set of indices."""
op_func: tp.Callable = define.field()
"""Operation function that takes the indices of each indexer (as `*args`), `index` (keyword argument),
and `freq` (keyword argument), and returns new indices."""
idxrs: tp.Tuple[object, ...] = define.field()
"""A tuple of one or more indexers."""
def __init__(self, op_func: tp.Callable, *idxrs) -> None:
if len(idxrs) == 1 and checks.is_iterable(idxrs[0]):
idxrs = idxrs[0]
DefineMixin.__init__(self, op_func=op_func, idxrs=idxrs)
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
idxr_indices = []
for idxr in self.idxrs:
if isinstance(idxr, IdxrBase):
checks.assert_instance_of(idxr, UniIdxr)
idxr_indices.append(idxr.get(index=index, freq=freq))
else:
idxr_indices.append(idxr)
return self.op_func(*idxr_indices, index=index, freq=freq)
@define
class PosIdxr(UniIdxr, DefineMixin):
"""Class for resolving indices provided as integer positions."""
value: tp.Union[None, tp.MaybeSequence[tp.MaybeSequence[int]], tp.Slice] = define.field()
"""One or more integer positions."""
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
if self.value is None:
return slice(None, None, None)
idxs = self.value
if checks.is_sequence(idxs) and not np.isscalar(idxs):
idxs = np.asarray(idxs)
if isinstance(idxs, hslice):
idxs = idxs.to_slice()
self.check_idxs(idxs)
return idxs
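# Example sketch (outputs illustrative):
# >>> PosIdxr([0, 2]).get()
# array([0, 2])
# >>> PosIdxr(None).get()  # None selects everything
# slice(None, None, None)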
@define
class MaskIdxr(UniIdxr, DefineMixin):
"""Class for resolving indices provided as a mask."""
value: tp.Union[None, tp.Sequence[bool]] = define.field()
"""Mask."""
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
if self.value is None:
return slice(None, None, None)
idxs = np.flatnonzero(self.value)
self.check_idxs(idxs)
return idxs
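# Example sketch (output illustrative): a boolean mask resolves to the positions of True values.
# >>> MaskIdxr([False, True, False, True]).get()
# array([1, 3])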
@define
class LabelIdxr(UniIdxr, DefineMixin):
"""Class for resolving indices provided as labels."""
value: tp.Union[None, tp.MaybeSequence[tp.Label], tp.Slice] = define.field()
"""One or more labels."""
closed_start: bool = define.field(default=True)
"""Whether slice start should be inclusive."""
closed_end: bool = define.field(default=True)
"""Whether slice end should be inclusive."""
level: tp.MaybeLevelSequence = define.field(default=None)
"""One or more levels."""
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
if self.value is None:
return slice(None, None, None)
if index is None:
raise ValueError("Index is required")
if self.level is not None:
from vectorbtpro.base.indexes import select_levels
index = select_levels(index, self.level)
if isinstance(self.value, (slice, hslice)):
idxs = self.slice_indexer(
index,
self.value,
closed_start=self.closed_start,
closed_end=self.closed_end,
)
elif (checks.is_sequence(self.value) and not np.isscalar(self.value)) and (
not isinstance(index, pd.MultiIndex)
or (isinstance(index, pd.MultiIndex) and isinstance(self.value[0], tuple))
):
idxs = index.get_indexer_for(self.value)
else:
idxs = index.get_loc(self.value)
if isinstance(idxs, np.ndarray) and np.issubdtype(idxs.dtype, np.bool_):
idxs = np.flatnonzero(idxs)
self.check_idxs(idxs, check_minus_one=True)
return idxs
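# Example sketch (outputs illustrative): labels resolve via `pd.Index.get_indexer_for`/`get_loc`,
# and label slices are inclusive on both ends by default.
# >>> index = pd.Index(["a", "b", "c", "d"])
# >>> LabelIdxr(["b", "d"]).get(index=index)
# array([1, 3])
# >>> LabelIdxr(hslice("b", "c")).get(index=index)
# slice(1, 3, None)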
@define
class DatetimeIdxr(UniIdxr, DefineMixin):
"""Class for resolving indices provided as datetime-like objects."""
value: tp.Union[None, tp.MaybeSequence[tp.DatetimeLike], tp.Slice] = define.field()
"""One or more datetime-like objects."""
closed_start: bool = define.field(default=True)
"""Whether slice start should be inclusive."""
closed_end: bool = define.field(default=False)
"""Whether slice end should be inclusive."""
indexer_method: tp.Optional[str] = define.field(default="bfill")
"""Method for `pd.Index.get_indexer`.
Allows two additional values: "before" and "after"."""
below_to_zero: bool = define.field(default=False)
"""Whether to place 0 instead of -1 if `DatetimeIdxr.value` is below the first index."""
above_to_len: bool = define.field(default=False)
"""Whether to place `len(index)` instead of -1 if `DatetimeIdxr.value` is above the last index."""
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
if self.value is None:
return slice(None, None, None)
if index is None:
raise ValueError("Index is required")
index = dt.prepare_dt_index(index)
checks.assert_instance_of(index, pd.DatetimeIndex)
if not index.is_unique:
raise ValueError("Datetime index must be unique")
if not index.is_monotonic_increasing:
raise ValueError("Datetime index must be monotonically increasing")
if isinstance(self.value, (slice, hslice)):
start = dt.try_align_dt_to_index(self.value.start, index)
stop = dt.try_align_dt_to_index(self.value.stop, index)
new_value = slice(start, stop, self.value.step)
idxs = self.slice_indexer(index, new_value, closed_start=self.closed_start, closed_end=self.closed_end)
elif checks.is_sequence(self.value) and not np.isscalar(self.value):
new_value = dt.try_align_to_dt_index(self.value, index)
idxs = index.get_indexer(new_value, method=self.indexer_method)
if self.below_to_zero:
idxs = np.where(new_value < index[0], 0, idxs)
if self.above_to_len:
idxs = np.where(new_value > index[-1], len(index), idxs)
else:
new_value = dt.try_align_dt_to_index(self.value, index)
if new_value < index[0] and self.below_to_zero:
idxs = 0
elif new_value > index[-1] and self.above_to_len:
idxs = len(index)
else:
if self.indexer_method is None or new_value in index:
idxs = index.get_loc(new_value)
if isinstance(idxs, np.ndarray) and np.issubdtype(idxs.dtype, np.bool_):
idxs = np.flatnonzero(idxs)
else:
indexer_method = self.indexer_method
if indexer_method is not None:
indexer_method = indexer_method.lower()
if indexer_method == "before":
new_value = new_value - pd.Timedelta(1, "ns")
indexer_method = "ffill"
elif indexer_method == "after":
new_value = new_value + pd.Timedelta(1, "ns")
indexer_method = "bfill"
idxs = index.get_indexer([new_value], method=indexer_method)[0]
self.check_idxs(idxs, check_minus_one=True)
return idxs
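# Example sketch (outputs illustrative): values are aligned to the datetime index first;
# unlike label slices, datetime slices exclude the end by default.
# >>> index = pd.date_range("2020-01-01", "2020-01-10")
# >>> DatetimeIdxr("2020-01-05").get(index=index)
# 4
# >>> DatetimeIdxr(hslice("2020-01-03", "2020-01-06")).get(index=index)
# slice(2, 5, None)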
@define
class DTCIdxr(UniIdxr, DefineMixin):
"""Class for resolving indices provided as datetime-like components."""
value: tp.Union[None, tp.MaybeSequence[tp.DTCLike], tp.Slice] = define.field()
"""One or more datetime-like components."""
parse_kwargs: tp.KwargsLike = define.field(default=None)
"""Keyword arguments passed to `vectorbtpro.utils.datetime_.DTC.parse`."""
closed_start: bool = define.field(default=True)
"""Whether slice start should be inclusive."""
closed_end: bool = define.field(default=False)
"""Whether slice end should be inclusive."""
jitted: tp.JittedOption = define.field(default=None)
"""Jitting option passed to `vectorbtpro.utils.datetime_nb.index_matches_dtc_nb`
and `vectorbtpro.utils.datetime_nb.index_within_dtc_range_nb`."""
@staticmethod
def get_dtc_namedtuple(value: tp.Optional[tp.DTCLike] = None, **parse_kwargs) -> dt.DTCNT:
"""Convert a value to a `vectorbtpro.utils.datetime_.DTCNT` instance."""
if value is None:
return dt.DTC().to_namedtuple()
if isinstance(value, dt.DTC):
return value.to_namedtuple()
if isinstance(value, dt.DTCNT):
return value
return dt.DTC.parse(value, **parse_kwargs).to_namedtuple()
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
if self.value is None:
return slice(None, None, None)
parse_kwargs = self.parse_kwargs
if parse_kwargs is None:
parse_kwargs = {}
if index is None:
raise ValueError("Index is required")
index = dt.prepare_dt_index(index)
ns_index = dt.to_ns(index)
checks.assert_instance_of(index, pd.DatetimeIndex)
if not index.is_unique:
raise ValueError("Datetime index must be unique")
if not index.is_monotonic_increasing:
raise ValueError("Datetime index must be monotonically increasing")
if isinstance(self.value, (slice, hslice)):
if self.value.step is not None:
raise ValueError("Step must be None")
if self.value.start is None and self.value.stop is None:
return slice(None, None, None)
start_dtc = self.get_dtc_namedtuple(self.value.start, **parse_kwargs)
end_dtc = self.get_dtc_namedtuple(self.value.stop, **parse_kwargs)
func = jit_reg.resolve_option(dt_nb.index_within_dtc_range_nb, self.jitted)
mask = func(ns_index, start_dtc, end_dtc, closed_start=self.closed_start, closed_end=self.closed_end)
elif checks.is_sequence(self.value) and not np.isscalar(self.value):
func = jit_reg.resolve_option(dt_nb.index_matches_dtc_nb, self.jitted)
dtcs = map(lambda x: self.get_dtc_namedtuple(x, **parse_kwargs), self.value)
masks = map(lambda x: func(ns_index, x), dtcs)
mask = functools.reduce(np.logical_or, masks)
else:
dtc = self.get_dtc_namedtuple(self.value, **parse_kwargs)
func = jit_reg.resolve_option(dt_nb.index_matches_dtc_nb, self.jitted)
mask = func(ns_index, dtc)
return MaskIdxr(mask).get(index=index, freq=freq)
@define
class PointIdxr(UniIdxr, DefineMixin):
"""Class for resolving index points."""
every: tp.Optional[tp.FrequencyLike] = define.field(default=None)
"""Frequency either as an integer or timedelta.
Gets translated into `on` array by creating a range. If integer, an index sequence from `start` to `end`
(exclusive) is created and 'indices' as `kind` is used. If timedelta-like, a date sequence from
`start` to `end` (inclusive) is created and 'labels' as `kind` is used.
If `at_time` is not None and `every` and `on` are None, `every` defaults to one day."""
normalize_every: bool = define.field(default=False)
"""Normalize start/end dates to midnight before generating date range."""
at_time: tp.Optional[tp.TimeLike] = define.field(default=None)
"""Time of the day either as a (human-readable) string or `datetime.time`.
Every datetime in `on` gets floored to the daily frequency, while `at_time` gets converted into
a timedelta using `vectorbtpro.utils.datetime_.time_to_timedelta` and added to `add_delta`.
Index must be datetime-like."""
start: tp.Optional[tp.Union[int, tp.DatetimeLike]] = define.field(default=None)
"""Start index/date.
If (human-readable) string, gets converted into a datetime.
If `every` is None, gets used to filter the final index array."""
end: tp.Optional[tp.Union[int, tp.DatetimeLike]] = define.field(default=None)
"""End index/date.
If (human-readable) string, gets converted into a datetime.
If `every` is None, gets used to filter the final index array."""
exact_start: bool = define.field(default=False)
"""Whether the first index should be exactly `start`.
Depending on `every`, the first index picked by `pd.date_range` may happen after `start`.
In such a case, `start` gets injected before the first index generated by `pd.date_range`."""
on: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = define.field(default=None)
"""Index/label or a sequence of such.
Gets converted into datetime format whenever possible."""
add_delta: tp.Optional[tp.FrequencyLike] = define.field(default=None)
"""Offset to be added to each in `on`.
Gets converted to a proper offset/timedelta using `vectorbtpro.utils.datetime_.to_freq`."""
kind: tp.Optional[str] = define.field(default=None)
"""Kind of data in `on`: indices or labels.
If None, gets assigned to `indices` if `on` contains integer data, otherwise to `labels`.
If `kind` is 'labels', `on` gets converted into indices using `pd.Index.get_indexer`.
    Prior to this, `on` gets its timezone aligned to the timezone of the index. If `kind` is 'indices',
`on` gets wrapped with NumPy."""
indexer_method: str = define.field(default="bfill")
"""Method for `pd.Index.get_indexer`.
Allows two additional values: "before" and "after"."""
indexer_tolerance: tp.Optional[tp.Union[int, tp.TimedeltaLike, tp.IndexLike]] = define.field(default=None)
"""Tolerance for `pd.Index.get_indexer`.
If `at_time` is set and `indexer_method` is neither exact nor nearest, `indexer_tolerance`
becomes such that the next element must be within the current day."""
skip_not_found: bool = define.field(default=True)
"""Whether to drop indices that are -1 (not found)."""
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
if index is None:
raise ValueError("Index is required")
idxs = get_index_points(index, **self.asdict())
self.check_idxs(idxs, check_minus_one=True)
return idxs
point_idxr_defaults = {a.name: a.default for a in PointIdxr.fields}
def get_index_points(
index: tp.Index,
every: tp.Optional[tp.FrequencyLike] = point_idxr_defaults["every"],
normalize_every: bool = point_idxr_defaults["normalize_every"],
at_time: tp.Optional[tp.TimeLike] = point_idxr_defaults["at_time"],
start: tp.Optional[tp.Union[int, tp.DatetimeLike]] = point_idxr_defaults["start"],
end: tp.Optional[tp.Union[int, tp.DatetimeLike]] = point_idxr_defaults["end"],
exact_start: bool = point_idxr_defaults["exact_start"],
on: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = point_idxr_defaults["on"],
add_delta: tp.Optional[tp.FrequencyLike] = point_idxr_defaults["add_delta"],
kind: tp.Optional[str] = point_idxr_defaults["kind"],
indexer_method: str = point_idxr_defaults["indexer_method"],
    indexer_tolerance: tp.Optional[tp.Union[int, tp.TimedeltaLike, tp.IndexLike]] = point_idxr_defaults["indexer_tolerance"],
skip_not_found: bool = point_idxr_defaults["skip_not_found"],
) -> tp.Array1d:
"""Translate indices or labels into index points.
See `PointIdxr` for arguments.
Usage:
* Provide nothing to generate at the beginning:
```pycon
>>> from vectorbtpro import *
>>> index = pd.date_range("2020-01", "2020-02", freq="1d")
>>> vbt.get_index_points(index)
array([0])
```
* Provide `every` as an integer frequency to generate index points using NumPy:
```pycon
>>> # Generate a point every five rows
>>> vbt.get_index_points(index, every=5)
array([ 0, 5, 10, 15, 20, 25, 30])
>>> # Generate a point every five rows starting at 6th row
>>> vbt.get_index_points(index, every=5, start=5)
array([ 5, 10, 15, 20, 25, 30])
>>> # Generate a point every five rows from 6th to 16th row
>>> vbt.get_index_points(index, every=5, start=5, end=15)
array([ 5, 10])
```
* Provide `every` as a time delta frequency to generate index points using Pandas:
```pycon
>>> # Generate a point every week
>>> vbt.get_index_points(index, every="W")
array([ 4, 11, 18, 25])
>>> # Generate a point every second day of the week
>>> vbt.get_index_points(index, every="W", add_delta="2d")
array([ 6, 13, 20, 27])
>>> # Generate a point every week, starting at 11th row
>>> vbt.get_index_points(index, every="W", start=10)
array([11, 18, 25])
>>> # Generate a point every week, starting exactly at 11th row
>>> vbt.get_index_points(index, every="W", start=10, exact_start=True)
array([10, 11, 18, 25])
>>> # Generate a point every week, starting at 2020-01-10
>>> vbt.get_index_points(index, every="W", start="2020-01-10")
array([11, 18, 25])
```
* Instead of using `every`, provide indices explicitly:
```pycon
>>> # Generate one point
>>> vbt.get_index_points(index, on="2020-01-07")
array([6])
>>> # Generate multiple points
>>> vbt.get_index_points(index, on=["2020-01-07", "2020-01-14"])
array([ 6, 13])
```
"""
index = dt.prepare_dt_index(index)
if on is not None and isinstance(on, str):
on = dt.try_align_dt_to_index(on, index)
if start is not None and isinstance(start, str):
start = dt.try_align_dt_to_index(start, index)
if end is not None and isinstance(end, str):
end = dt.try_align_dt_to_index(end, index)
if every is not None and not checks.is_int(every):
every = dt.to_freq(every)
start_used = False
end_used = False
if at_time is not None and every is None and on is None:
every = pd.Timedelta(days=1)
if every is not None:
start_used = True
end_used = True
if checks.is_int(every):
if start is None:
start = 0
if end is None:
end = len(index)
on = np.arange(start, end, every)
kind = "indices"
else:
if start is None:
start = 0
if checks.is_int(start):
start_date = index[start]
else:
start_date = start
if end is None:
end = len(index) - 1
if checks.is_int(end):
end_date = index[end]
else:
end_date = end
on = dt.date_range(
start_date,
end_date,
freq=every,
tz=index.tz,
normalize=normalize_every,
inclusive="both",
)
if exact_start and on[0] > start_date:
on = on.insert(0, start_date)
kind = "labels"
if kind is None:
if on is None:
if start is not None:
if checks.is_int(start):
kind = "indices"
else:
kind = "labels"
else:
kind = "indices"
else:
on = dt.prepare_dt_index(on)
if pd.api.types.is_integer_dtype(on):
kind = "indices"
else:
kind = "labels"
checks.assert_in(kind, ("indices", "labels"))
if on is None:
if start is not None:
on = start
start_used = True
else:
if kind.lower() in ("labels",):
on = index
else:
on = np.arange(len(index))
on = dt.prepare_dt_index(on)
if at_time is not None:
checks.assert_instance_of(on, pd.DatetimeIndex)
on = on.floor("D")
add_time_delta = dt.time_to_timedelta(at_time)
if indexer_tolerance is None:
indexer_method = indexer_method.lower()
if indexer_method in ("pad", "ffill"):
indexer_tolerance = add_time_delta
elif indexer_method in ("backfill", "bfill"):
indexer_tolerance = pd.Timedelta(days=1) - pd.Timedelta(1, "ns") - add_time_delta
if add_delta is None:
add_delta = add_time_delta
else:
add_delta += add_time_delta
if add_delta is not None:
on += dt.to_freq(add_delta)
if kind.lower() == "labels":
on = dt.try_align_to_dt_index(on, index)
if indexer_method is not None:
indexer_method = indexer_method.lower()
if indexer_method == "before":
on = on - pd.Timedelta(1, "ns")
indexer_method = "ffill"
elif indexer_method == "after":
on = on + pd.Timedelta(1, "ns")
indexer_method = "bfill"
index_points = index.get_indexer(on, method=indexer_method, tolerance=indexer_tolerance)
else:
index_points = np.asarray(on)
if start is not None and not start_used:
if not checks.is_int(start):
start = index.get_indexer([start], method="bfill").item(0)
index_points = index_points[index_points >= start]
if end is not None and not end_used:
if not checks.is_int(end):
end = index.get_indexer([end], method="ffill").item(0)
index_points = index_points[index_points <= end]
else:
index_points = index_points[index_points < end]
if skip_not_found:
index_points = index_points[index_points != -1]
return index_points
@define
class RangeIdxr(UniIdxr, DefineMixin):
"""Class for resolving index ranges."""
every: tp.Optional[tp.FrequencyLike] = define.field(default=None)
"""Frequency either as an integer or timedelta.
Gets translated into `start` and `end` arrays by creating a range. If integer, an index sequence from `start`
to `end` (exclusive) is created and 'indices' as `kind` is used. If timedelta-like, a date sequence
from `start` to `end` (inclusive) is created and 'bounds' as `kind` is used.
If `start_time` and `end_time` are not None and `every`, `start`, and `end` are None,
`every` defaults to one day."""
normalize_every: bool = define.field(default=False)
"""Normalize start/end dates to midnight before generating date range."""
split_every: bool = define.field(default=True)
"""Whether to split the sequence generated using `every` into `start` and `end` arrays.
After creation, and if `split_every` is True, an index range is created from each pair of elements in
the generated sequence. Otherwise, the entire sequence is assigned to `start` and `end`, and only time
and delta instructions can be used to further differentiate between them.
Forced to False if `every`, `start_time`, and `end_time` are not None and `fixed_start` is False."""
start_time: tp.Optional[tp.TimeLike] = define.field(default=None)
"""Start time of the day either as a (human-readable) string or `datetime.time`.
Every datetime in `start` gets floored to the daily frequency, while `start_time` gets converted into
a timedelta using `vectorbtpro.utils.datetime_.time_to_timedelta` and added to `add_start_delta`.
Index must be datetime-like."""
end_time: tp.Optional[tp.TimeLike] = define.field(default=None)
"""End time of the day either as a (human-readable) string or `datetime.time`.
Every datetime in `end` gets floored to the daily frequency, while `end_time` gets converted into
a timedelta using `vectorbtpro.utils.datetime_.time_to_timedelta` and added to `add_end_delta`.
Index must be datetime-like."""
lookback_period: tp.Optional[tp.FrequencyLike] = define.field(default=None)
"""Lookback period either as an integer or offset.
If `lookback_period` is set, `start` becomes `end-lookback_period`. If `every` is not None,
the sequence is generated from `start+lookback_period` to `end` and then assigned to `end`.
If string, gets converted to a proper offset/timedelta using `vectorbtpro.utils.datetime_.to_freq`.
If integer, gets multiplied by the frequency of the index if the index is not integer."""
start: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = define.field(default=None)
"""Start index/label or a sequence of such.
Gets converted into datetime format whenever possible.
Gets broadcasted together with `end`."""
end: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = define.field(default=None)
"""End index/label or a sequence of such.
Gets converted into datetime format whenever possible.
Gets broadcasted together with `start`."""
exact_start: bool = define.field(default=False)
"""Whether the first index in the `start` array should be exactly `start`.
Depending on `every`, the first index picked by `pd.date_range` may happen after `start`.
In such a case, `start` gets injected before the first index generated by `pd.date_range`.
Cannot be used together with `lookback_period`."""
fixed_start: bool = define.field(default=False)
"""Whether all indices in the `start` array should be exactly `start`.
Works only together with `every`.
Cannot be used together with `lookback_period`."""
closed_start: bool = define.field(default=True)
"""Whether `start` should be inclusive."""
closed_end: bool = define.field(default=False)
"""Whether `end` should be inclusive."""
add_start_delta: tp.Optional[tp.FrequencyLike] = define.field(default=None)
"""Offset to be added to each in `start`.
If string, gets converted to a proper offset/timedelta using `vectorbtpro.utils.datetime_.to_freq`."""
add_end_delta: tp.Optional[tp.FrequencyLike] = define.field(default=None)
"""Offset to be added to each in `end`.
If string, gets converted to a proper offset/timedelta using `vectorbtpro.utils.datetime_.to_freq`."""
kind: tp.Optional[str] = define.field(default=None)
"""Kind of data in `on`: indices, labels or bounds.
If None, gets assigned to `indices` if `start` and `end` contain integer data, to `bounds`
if `start`, `end`, and index are datetime-like, otherwise to `labels`.
If `kind` is 'labels', `start` and `end` get converted into indices using `pd.Index.get_indexer`.
    Prior to this, `start` and `end` get their timezone aligned to the timezone of the index.
    If `kind` is 'indices', `start` and `end` get wrapped with NumPy. If `kind` is 'bounds',
`vectorbtpro.base.resampling.base.Resampler.map_bounds_to_source_ranges` is used."""
skip_not_found: bool = define.field(default=True)
"""Whether to drop indices that are -1 (not found)."""
jitted: tp.JittedOption = define.field(default=None)
"""Jitting option passed to `vectorbtpro.base.resampling.base.Resampler.map_bounds_to_source_ranges`."""
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
if index is None:
raise ValueError("Index is required")
from vectorbtpro.base.merging import column_stack_arrays
start_idxs, end_idxs = get_index_ranges(index, index_freq=freq, **self.asdict())
idxs = column_stack_arrays((start_idxs, end_idxs))
self.check_idxs(idxs, check_minus_one=True)
return idxs
range_idxr_defaults = {a.name: a.default for a in RangeIdxr.fields}
def get_index_ranges(
index: tp.Index,
index_freq: tp.Optional[tp.FrequencyLike] = None,
every: tp.Optional[tp.FrequencyLike] = range_idxr_defaults["every"],
normalize_every: bool = range_idxr_defaults["normalize_every"],
split_every: bool = range_idxr_defaults["split_every"],
start_time: tp.Optional[tp.TimeLike] = range_idxr_defaults["start_time"],
end_time: tp.Optional[tp.TimeLike] = range_idxr_defaults["end_time"],
lookback_period: tp.Optional[tp.FrequencyLike] = range_idxr_defaults["lookback_period"],
start: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = range_idxr_defaults["start"],
end: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = range_idxr_defaults["end"],
exact_start: bool = range_idxr_defaults["exact_start"],
fixed_start: bool = range_idxr_defaults["fixed_start"],
closed_start: bool = range_idxr_defaults["closed_start"],
closed_end: bool = range_idxr_defaults["closed_end"],
add_start_delta: tp.Optional[tp.FrequencyLike] = range_idxr_defaults["add_start_delta"],
add_end_delta: tp.Optional[tp.FrequencyLike] = range_idxr_defaults["add_end_delta"],
kind: tp.Optional[str] = range_idxr_defaults["kind"],
skip_not_found: bool = range_idxr_defaults["skip_not_found"],
jitted: tp.JittedOption = range_idxr_defaults["jitted"],
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Translate indices, labels, or bounds into index ranges.
See `RangeIdxr` for arguments.
Usage:
* Provide nothing to generate one largest index range:
```pycon
>>> from vectorbtpro import *
>>> index = pd.date_range("2020-01", "2020-02", freq="1d")
>>> np.column_stack(vbt.get_index_ranges(index))
array([[ 0, 32]])
```
* Provide `every` as an integer frequency to generate index ranges using NumPy:
```pycon
>>> # Generate a range every five rows
>>> np.column_stack(vbt.get_index_ranges(index, every=5))
array([[ 0, 5],
[ 5, 10],
[10, 15],
[15, 20],
[20, 25],
[25, 30]])
>>> # Generate a range every five rows, starting at 6th row
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every=5,
... start=5
... ))
array([[ 5, 10],
[10, 15],
[15, 20],
[20, 25],
[25, 30]])
>>> # Generate a range every five rows from 6th to 16th row
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every=5,
... start=5,
... end=15
... ))
array([[ 5, 10],
[10, 15]])
```
* Provide `every` as a time delta frequency to generate index ranges using Pandas:
```pycon
>>> # Generate a range every week
>>> np.column_stack(vbt.get_index_ranges(index, every="W"))
array([[ 4, 11],
[11, 18],
[18, 25]])
>>> # Generate a range every second day of the week
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every="W",
... add_start_delta="2d"
... ))
array([[ 6, 11],
[13, 18],
[20, 25]])
>>> # Generate a range every week, starting at 11th row
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every="W",
... start=10
... ))
array([[11, 18],
[18, 25]])
>>> # Generate a range every week, starting exactly at 11th row
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every="W",
... start=10,
... exact_start=True
... ))
array([[10, 11],
[11, 18],
[18, 25]])
>>> # Generate a range every week, starting at 2020-01-10
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every="W",
... start="2020-01-10"
... ))
array([[11, 18],
[18, 25]])
>>> # Generate a range every week, each starting at 2020-01-10
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every="W",
... start="2020-01-10",
... fixed_start=True
... ))
array([[11, 18],
[11, 25]])
>>> # Generate an expanding range that increments by week
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every="W",
... start=0,
... exact_start=True,
... fixed_start=True
... ))
array([[ 0, 4],
[ 0, 11],
[ 0, 18],
[ 0, 25]])
```
* Use a look-back period (instead of an end index):
```pycon
>>> # Generate a range every week, looking 5 days back
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every="W",
... lookback_period=5
... ))
array([[ 6, 11],
[13, 18],
[20, 25]])
>>> # Generate a range every week, looking 2 weeks back
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every="W",
... lookback_period="2W"
... ))
array([[ 0, 11],
[ 4, 18],
[11, 25]])
```
* Instead of using `every`, provide start and end indices explicitly:
```pycon
>>> # Generate one range
>>> np.column_stack(vbt.get_index_ranges(
... index,
... start="2020-01-01",
... end="2020-01-07"
... ))
array([[0, 6]])
>>> # Generate ranges between multiple dates
>>> np.column_stack(vbt.get_index_ranges(
... index,
... start=["2020-01-01", "2020-01-07"],
... end=["2020-01-07", "2020-01-14"]
... ))
array([[ 0, 6],
[ 6, 13]])
>>> # Generate ranges with a fixed start
>>> np.column_stack(vbt.get_index_ranges(
... index,
... start="2020-01-01",
... end=["2020-01-07", "2020-01-14"]
... ))
array([[ 0, 6],
[ 0, 13]])
```
* Use `closed_start` and `closed_end` to exclude any of the bounds:
```pycon
>>> # Generate ranges between multiple dates
>>> # by excluding the start date and including the end date
>>> np.column_stack(vbt.get_index_ranges(
... index,
... start=["2020-01-01", "2020-01-07"],
... end=["2020-01-07", "2020-01-14"],
... closed_start=False,
... closed_end=True
... ))
array([[ 1, 7],
[ 7, 14]])
```
"""
from vectorbtpro.base.indexes import repeat_index
from vectorbtpro.base.resampling.base import Resampler
index = dt.prepare_dt_index(index)
if isinstance(index, pd.DatetimeIndex):
if start is not None:
start = dt.try_align_to_dt_index(start, index)
if isinstance(start, pd.DatetimeIndex):
start = start.tz_localize(None)
if end is not None:
end = dt.try_align_to_dt_index(end, index)
if isinstance(end, pd.DatetimeIndex):
end = end.tz_localize(None)
naive_index = index.tz_localize(None)
else:
if start is not None:
if not isinstance(start, pd.Index):
try:
start = pd.Index(start)
except Exception as e:
start = pd.Index([start])
if end is not None:
if not isinstance(end, pd.Index):
try:
end = pd.Index(end)
except Exception as e:
end = pd.Index([end])
naive_index = index
if every is not None and not checks.is_int(every):
every = dt.to_freq(every)
if lookback_period is not None and not checks.is_int(lookback_period):
lookback_period = dt.to_freq(lookback_period)
if fixed_start and lookback_period is not None:
raise ValueError("Cannot use fixed_start and lookback_period together")
if exact_start and lookback_period is not None:
raise ValueError("Cannot use exact_start and lookback_period together")
if start_time is not None or end_time is not None:
if every is None and start is None and end is None:
every = pd.Timedelta(days=1)
if every is not None:
if not fixed_start:
if start_time is None and end_time is not None:
start_time = time(0, 0, 0, 0)
closed_start = True
if start_time is not None and end_time is None:
end_time = time(0, 0, 0, 0)
closed_end = False
if start_time is not None and end_time is not None and not fixed_start:
split_every = False
if checks.is_int(every):
if start is None:
start = 0
else:
start = start[0]
if end is None:
end = len(naive_index)
else:
end = end[-1]
if closed_end:
end -= 1
if lookback_period is None:
new_index = np.arange(start, end + 1, every)
if not split_every:
start = end = new_index
else:
if fixed_start:
start = np.full(len(new_index) - 1, new_index[0])
else:
start = new_index[:-1]
end = new_index[1:]
else:
end = np.arange(start + lookback_period, end + 1, every)
start = end - lookback_period
kind = "indices"
lookback_period = None
else:
if start is None:
start = 0
else:
start = start[0]
if checks.is_int(start):
start_date = naive_index[start]
else:
start_date = start
if end is None:
end = len(naive_index) - 1
else:
end = end[-1]
if checks.is_int(end):
end_date = naive_index[end]
else:
end_date = end
if lookback_period is None:
new_index = dt.date_range(
start_date,
end_date,
freq=every,
normalize=normalize_every,
inclusive="both",
)
if exact_start and new_index[0] > start_date:
new_index = new_index.insert(0, start_date)
if not split_every:
start = end = new_index
else:
if fixed_start:
start = repeat_index(new_index[[0]], len(new_index) - 1)
else:
start = new_index[:-1]
end = new_index[1:]
else:
if checks.is_int(lookback_period):
lookback_period *= dt.infer_index_freq(naive_index, freq=index_freq)
if isinstance(lookback_period, BaseOffset):
end = dt.date_range(
start_date,
end_date,
freq=every,
normalize=normalize_every,
inclusive="both",
)
start = end - lookback_period
start_mask = start >= start_date
start = start[start_mask]
end = end[start_mask]
else:
end = dt.date_range(
start_date + lookback_period,
end_date,
freq=every,
normalize=normalize_every,
inclusive="both",
)
start = end - lookback_period
kind = "bounds"
lookback_period = None
if kind is None:
if start is None and end is None:
kind = "indices"
else:
if start is not None:
ref_index = start
if end is not None:
ref_index = end
if pd.api.types.is_integer_dtype(ref_index):
kind = "indices"
elif isinstance(ref_index, pd.DatetimeIndex) and isinstance(naive_index, pd.DatetimeIndex):
kind = "bounds"
else:
kind = "labels"
checks.assert_in(kind, ("indices", "labels", "bounds"))
if end is None:
if kind.lower() in ("labels", "bounds"):
end = pd.Index([naive_index[-1]])
else:
end = pd.Index([len(naive_index)])
if start is not None and lookback_period is not None:
raise ValueError("Cannot use start and lookback_period together")
if start is None:
if lookback_period is None:
if kind.lower() in ("labels", "bounds"):
start = pd.Index([naive_index[0]])
else:
start = pd.Index([0])
else:
if checks.is_int(lookback_period) and not pd.api.types.is_integer_dtype(end):
lookback_period *= dt.infer_index_freq(naive_index, freq=index_freq)
start = end - lookback_period
if len(start) == 1 and len(end) > 1:
start = repeat_index(start, len(end))
elif len(start) > 1 and len(end) == 1:
end = repeat_index(end, len(start))
checks.assert_len_equal(start, end)
if start_time is not None:
checks.assert_instance_of(start, pd.DatetimeIndex)
start = start.floor("D")
add_start_time_delta = dt.time_to_timedelta(start_time)
if add_start_delta is None:
add_start_delta = add_start_time_delta
else:
add_start_delta += add_start_time_delta
else:
add_start_time_delta = None
if end_time is not None:
checks.assert_instance_of(end, pd.DatetimeIndex)
end = end.floor("D")
add_end_time_delta = dt.time_to_timedelta(end_time)
if add_start_time_delta is not None:
if add_end_time_delta < add_start_delta:
add_end_time_delta += pd.Timedelta(days=1)
if add_end_delta is None:
add_end_delta = add_end_time_delta
else:
add_end_delta += add_end_time_delta
if add_start_delta is not None:
start += dt.to_freq(add_start_delta)
if add_end_delta is not None:
end += dt.to_freq(add_end_delta)
if kind.lower() == "bounds":
range_starts, range_ends = Resampler.map_bounds_to_source_ranges(
source_index=naive_index.values,
target_lbound_index=start.values,
target_rbound_index=end.values,
closed_lbound=closed_start,
closed_rbound=closed_end,
skip_not_found=skip_not_found,
jitted=jitted,
)
else:
if kind.lower() == "labels":
range_starts = np.empty(len(start), dtype=int_)
range_ends = np.empty(len(end), dtype=int_)
range_index = pd.Series(np.arange(len(naive_index)), index=naive_index)
for i in range(len(range_starts)):
selected_range = range_index[start[i] : end[i]]
if len(selected_range) > 0 and not closed_start and selected_range.index[0] == start[i]:
selected_range = selected_range.iloc[1:]
if len(selected_range) > 0 and not closed_end and selected_range.index[-1] == end[i]:
selected_range = selected_range.iloc[:-1]
if len(selected_range) > 0:
range_starts[i] = selected_range.iloc[0]
range_ends[i] = selected_range.iloc[-1]
else:
range_starts[i] = -1
range_ends[i] = -1
else:
if not closed_start:
start = start + 1
if closed_end:
end = end + 1
range_starts = np.asarray(start)
range_ends = np.asarray(end)
if skip_not_found:
valid_mask = (range_starts != -1) & (range_ends != -1)
range_starts = range_starts[valid_mask]
range_ends = range_ends[valid_mask]
if np.any(range_starts >= range_ends):
raise ValueError("Some start indices are equal to or higher than end indices")
return range_starts, range_ends
@define
class AutoIdxr(UniIdxr, DefineMixin):
"""Class for resolving indices, datetime-like objects, frequency-like objects, and labels for one axis."""
value: tp.Union[
None,
tp.PosSel,
tp.LabelSel,
tp.MaybeSequence[tp.MaybeSequence[int]],
tp.MaybeSequence[tp.Label],
tp.MaybeSequence[tp.DatetimeLike],
tp.MaybeSequence[tp.DTCLike],
tp.FrequencyLike,
tp.Slice,
] = define.field()
"""One or more integer indices, datetime-like objects, frequency-like objects, or labels.
Can also be an instance of `vectorbtpro.utils.selection.PosSel` holding position(s)
and `vectorbtpro.utils.selection.LabelSel` holding label(s)."""
closed_start: bool = define.optional_field()
"""Whether slice start should be inclusive."""
closed_end: bool = define.optional_field()
"""Whether slice end should be inclusive."""
indexer_method: tp.Optional[str] = define.optional_field()
"""Method for `pd.Index.get_indexer`."""
below_to_zero: bool = define.optional_field()
"""Whether to place 0 instead of -1 if `AutoIdxr.value` is below the first index."""
above_to_len: bool = define.optional_field()
"""Whether to place `len(index)` instead of -1 if `AutoIdxr.value` is above the last index."""
level: tp.MaybeLevelSequence = define.field(default=None)
"""One or more levels.
If `level` is not None and `kind` is None, `kind` becomes "labels"."""
kind: tp.Optional[str] = define.field(default=None)
"""Kind of value.
Allowed are
* "position(s)" for `PosIdxr`
* "mask" for `MaskIdxr`
* "label(s)" for `LabelIdxr`
* "datetime" for `DatetimeIdxr`
* "dtc": for `DTCIdxr`
* "frequency" for `PointIdxr`
If None, will (try to) determine automatically based on the type of indices."""
idxr_kwargs: tp.KwargsLike = define.field(default=None)
"""Keyword arguments passed to the selected indexer."""
def __init__(self, *args, **kwargs) -> None:
idxr_kwargs = kwargs.pop("idxr_kwargs", None)
if idxr_kwargs is None:
idxr_kwargs = {}
else:
idxr_kwargs = dict(idxr_kwargs)
builtin_keys = {a.name for a in self.fields}
for k in list(kwargs.keys()):
if k not in builtin_keys:
idxr_kwargs[k] = kwargs.pop(k)
DefineMixin.__init__(self, *args, idxr_kwargs=idxr_kwargs, **kwargs)
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
if self.value is None:
return slice(None, None, None)
value = self.value
kind = self.kind
if self.level is not None:
from vectorbtpro.base.indexes import select_levels
if index is None:
raise ValueError("Index is required")
index = select_levels(index, self.level)
if kind is None:
kind = "labels"
if self.idxr_kwargs is not None:
idxr_kwargs = self.idxr_kwargs
else:
idxr_kwargs = None
if idxr_kwargs is None:
idxr_kwargs = {}
def _dtc_check_func(dtc):
return (
not dtc.has_full_datetime()
and self.indexer_method in (MISSING, None)
and self.below_to_zero is MISSING
and self.above_to_len is MISSING
)
if kind is None:
if isinstance(value, PosSel):
kind = "positions"
value = value.value
elif isinstance(value, LabelSel):
kind = "labels"
value = value.value
elif isinstance(value, (slice, hslice)):
if checks.is_int(value.start) or checks.is_int(value.stop):
kind = "positions"
elif value.start is None and value.stop is None:
kind = "positions"
else:
if index is None:
raise ValueError("Index is required")
if isinstance(index, pd.DatetimeIndex):
if dt.DTC.is_parsable(value.start, check_func=_dtc_check_func) or dt.DTC.is_parsable(
value.stop, check_func=_dtc_check_func
):
kind = "dtc"
else:
kind = "datetime"
else:
kind = "labels"
elif (checks.is_sequence(value) and not np.isscalar(value)) and (
index is None
or (
not isinstance(index, pd.MultiIndex)
or (isinstance(index, pd.MultiIndex) and isinstance(value[0], tuple))
)
):
if checks.is_bool(value[0]):
kind = "mask"
elif checks.is_int(value[0]):
kind = "positions"
elif (
(index is None or not isinstance(index, pd.MultiIndex) or not isinstance(value[0], tuple))
and checks.is_sequence(value[0])
and len(value[0]) == 2
and checks.is_int(value[0][0])
and checks.is_int(value[0][1])
):
kind = "positions"
else:
if index is None:
raise ValueError("Index is required")
elif isinstance(index, pd.DatetimeIndex):
if dt.DTC.is_parsable(value[0], check_func=_dtc_check_func):
kind = "dtc"
else:
kind = "datetime"
else:
kind = "labels"
else:
if checks.is_bool(value):
kind = "mask"
elif checks.is_int(value):
kind = "positions"
else:
if index is None:
raise ValueError("Index is required")
if isinstance(index, pd.DatetimeIndex):
if dt.DTC.is_parsable(value, check_func=_dtc_check_func):
kind = "dtc"
elif isinstance(value, str):
try:
if not value.isupper() and not value.islower():
raise Exception # "2020" shouldn't be a frequency
_ = dt.to_freq(value)
kind = "frequency"
except Exception as e:
try:
_ = dt.to_timestamp(value)
kind = "datetime"
except Exception as e:
raise ValueError(f"'{value}' is neither a frequency nor a datetime")
elif checks.is_frequency(value):
kind = "frequency"
else:
kind = "datetime"
else:
kind = "labels"
def _expand_target_kwargs(target_cls, **target_kwargs):
source_arg_names = {a.name for a in self.fields if a.default is MISSING}
target_arg_names = {a.name for a in target_cls.fields}
for arg_name in source_arg_names:
if arg_name in target_arg_names:
arg_value = getattr(self, arg_name)
if arg_value is not MISSING:
target_kwargs[arg_name] = arg_value
return target_kwargs
if kind.lower() in ("position", "positions"):
idx = PosIdxr(value, **_expand_target_kwargs(PosIdxr, **idxr_kwargs))
elif kind.lower() == "mask":
idx = MaskIdxr(value, **_expand_target_kwargs(MaskIdxr, **idxr_kwargs))
elif kind.lower() in ("label", "labels"):
idx = LabelIdxr(value, **_expand_target_kwargs(LabelIdxr, **idxr_kwargs))
elif kind.lower() == "datetime":
idx = DatetimeIdxr(value, **_expand_target_kwargs(DatetimeIdxr, **idxr_kwargs))
elif kind.lower() == "dtc":
idx = DTCIdxr(value, **_expand_target_kwargs(DTCIdxr, **idxr_kwargs))
elif kind.lower() == "frequency":
idx = PointIdxr(every=value, **_expand_target_kwargs(PointIdxr, **idxr_kwargs))
else:
raise ValueError(f"Invalid kind: '{kind}'")
return idx.get(index=index, freq=freq)
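# Illustrative example (not part of the original source): a minimal sketch of how
# `AutoIdxr` infers the kind of its value before dispatching to a concrete indexer.
# The index and the expected dispatch targets below are assumptions for illustration;
# the `kind` argument can also be set explicitly to skip the inference.
#
# >>> import pandas as pd
# >>> index = pd.Index(["a", "b", "c", "d"])
# >>> AutoIdxr(2).get(index=index)                           # integer -> "positions" (PosIdxr)
# >>> AutoIdxr([True, False, True, False]).get(index=index)  # booleans -> "mask" (MaskIdxr)
# >>> AutoIdxr("c").get(index=index)                         # non-datetime index -> "labels" (LabelIdxr)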
@define
class RowIdxr(IdxrBase, DefineMixin):
"""Class for resolving row indices."""
idxr: object = define.field()
"""Indexer.
Can be an instance of `UniIdxr`, a custom template, or a value to be wrapped with `AutoIdxr`."""
idxr_kwargs: tp.KwargsLike = define.field()
"""Keyword arguments passed to `AutoIdxr`."""
def __init__(self, idxr: object, **idxr_kwargs) -> None:
DefineMixin.__init__(self, idxr=idxr, idxr_kwargs=hdict(idxr_kwargs))
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
template_context: tp.KwargsLike = None,
) -> tp.MaybeIndexArray:
idxr = self.idxr
if isinstance(idxr, CustomTemplate):
_template_context = merge_dicts(dict(index=index, freq=freq), template_context)
idxr = idxr.substitute(_template_context, eval_id="idxr")
if not isinstance(idxr, UniIdxr):
if isinstance(idxr, IdxrBase):
raise TypeError(f"Indexer of {type(self)} must be an instance of UniIdxr")
idxr = AutoIdxr(idxr, **self.idxr_kwargs)
return idxr.get(index=index, freq=freq)
@define
class ColIdxr(IdxrBase, DefineMixin):
"""Class for resolving column indices."""
idxr: object = define.field()
"""Indexer.
Can be an instance of `UniIdxr`, a custom template, or a value to be wrapped with `AutoIdxr`."""
idxr_kwargs: tp.KwargsLike = define.field()
"""Keyword arguments passed to `AutoIdxr`."""
def __init__(self, idxr: object, **idxr_kwargs) -> None:
DefineMixin.__init__(self, idxr=idxr, idxr_kwargs=hdict(idxr_kwargs))
def get(
self,
columns: tp.Optional[tp.Index] = None,
template_context: tp.KwargsLike = None,
) -> tp.MaybeIndexArray:
idxr = self.idxr
if isinstance(idxr, CustomTemplate):
_template_context = merge_dicts(dict(columns=columns), template_context)
idxr = idxr.substitute(_template_context, eval_id="idxr")
if not isinstance(idxr, UniIdxr):
if isinstance(idxr, IdxrBase):
raise TypeError(f"Indexer of {type(self)} must be an instance of UniIdxr")
idxr = AutoIdxr(idxr, **self.idxr_kwargs)
return idxr.get(index=columns)
@define
class Idxr(IdxrBase, DefineMixin):
"""Class for resolving indices."""
idxrs: tp.Tuple[object, ...] = define.field()
"""A tuple of one or more indexers.
If one indexer is provided, can be an instance of `RowIdxr` or `ColIdxr`,
a custom template, or a value to be wrapped with `RowIdxr`.
If two indexers are provided, can be an instance of `RowIdxr` and `ColIdxr` respectively,
or a value to be wrapped with `RowIdxr` and `ColIdxr` respectively."""
idxr_kwargs: tp.KwargsLike = define.field()
"""Keyword arguments passed to `RowIdxr` and `ColIdxr`."""
def __init__(self, *idxrs: object, **idxr_kwargs) -> None:
DefineMixin.__init__(self, idxrs=idxrs, idxr_kwargs=hdict(idxr_kwargs))
def get(
self,
index: tp.Optional[tp.Index] = None,
columns: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
template_context: tp.KwargsLike = None,
) -> tp.Tuple[tp.MaybeIndexArray, tp.MaybeIndexArray]:
if len(self.idxrs) == 0:
raise ValueError("Must provide at least one indexer")
elif len(self.idxrs) == 1:
idxr = self.idxrs[0]
if isinstance(idxr, CustomTemplate):
_template_context = merge_dicts(dict(index=index, columns=columns, freq=freq), template_context)
idxr = idxr.substitute(_template_context, eval_id="idxr")
if isinstance(idxr, tuple):
return type(self)(*idxr).get(
index=index,
columns=columns,
freq=freq,
template_context=template_context,
)
return type(self)(idxr).get(
index=index,
columns=columns,
freq=freq,
template_context=template_context,
)
if isinstance(idxr, ColIdxr):
row_idxr = None
col_idxr = idxr
else:
row_idxr = idxr
col_idxr = None
elif len(self.idxrs) == 2:
row_idxr = self.idxrs[0]
col_idxr = self.idxrs[1]
else:
raise ValueError("Must provide at most two indexers")
if not isinstance(row_idxr, RowIdxr):
if isinstance(row_idxr, (ColIdxr, Idxr)):
raise TypeError(f"Indexer {type(row_idxr)} not supported as a row indexer")
row_idxr = RowIdxr(row_idxr, **self.idxr_kwargs)
row_idxs = row_idxr.get(index=index, freq=freq, template_context=template_context)
if not isinstance(col_idxr, ColIdxr):
if isinstance(col_idxr, (RowIdxr, Idxr)):
raise TypeError(f"Indexer {type(col_idxr)} not supported as a column indexer")
col_idxr = ColIdxr(col_idxr, **self.idxr_kwargs)
col_idxs = col_idxr.get(columns=columns, template_context=template_context)
return row_idxs, col_idxs
def get_idxs(
idxr: object,
index: tp.Optional[tp.Index] = None,
columns: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.Tuple[tp.MaybeIndexArray, tp.MaybeIndexArray]:
"""Translate indexer to row and column indices.
If `idxr` is not an indexer class, wraps it with `Idxr`.
Keyword arguments are passed when constructing a new `Idxr`."""
if not isinstance(idxr, Idxr):
idxr = Idxr(idxr, **kwargs)
return idxr.get(index=index, columns=columns, freq=freq, template_context=template_context)
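# Illustrative example (not part of the original source): translating plain values
# into row and column indices. The index and columns below are assumptions for illustration.
#
# >>> import pandas as pd
# >>> index = pd.Index(["r0", "r1", "r2"])
# >>> columns = pd.Index(["c0", "c1"])
# >>> get_idxs(1, index=index, columns=columns)
# expected: (1, slice(None, None, None)) -- a single value acts as the row indexer
# >>> get_idxs(Idxr(1, "c1"), index=index, columns=columns)
# expected: (1, 1) -- explicit row and column indexers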
class index_dict(pdict):
"""Dict that contains indexer objects as keys and values to be set as values.
Each indexer object must be hashable. To make a slice hashable, use `hslice`.
To make an array hashable, convert it into a tuple.
To set a default value, use the `_def` key (case-sensitive!)."""
pass
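# Illustrative example (not part of the original source): a sketch of an `index_dict`
# mapping hashable indexers to the values that should be set at those locations.
#
# >>> index_dict({
# ...     hslice(None, 2): 2.0,  # rows before position 2
# ...     3: 4.0,                # row at position 3
# ...     "_def": 0.0,           # default value for everything else
# ... })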
IdxSetterT = tp.TypeVar("IdxSetterT", bound="IdxSetter")
@define
class IdxSetter(DefineMixin):
"""Class for setting values based on indexing."""
idx_items: tp.List[tp.Tuple[object, tp.ArrayLike]] = define.field()
"""Items where the first element is an indexer and the second element is a value to be set."""
@classmethod
def set_row_idxs(cls, arr: tp.Array, idxs: tp.MaybeIndexArray, v: tp.Any) -> None:
"""Set row indices in an array."""
from vectorbtpro.base.reshaping import broadcast_array_to
if not isinstance(v, np.ndarray):
v = np.asarray(v)
single_v = v.size == 1 or (v.ndim == 2 and v.shape[0] == 1)
if arr.ndim == 2:
single_row = not isinstance(idxs, slice) and (np.isscalar(idxs) or idxs.size == 1)
if not single_row:
if v.ndim == 1 and v.size > 1:
v = v[:, None]
if isinstance(idxs, np.ndarray) and idxs.ndim == 2:
if not single_v:
if arr.ndim == 2:
v = broadcast_array_to(v, (len(idxs), arr.shape[1]))
else:
v = broadcast_array_to(v, (len(idxs),))
for i in range(len(idxs)):
_slice = slice(idxs[i, 0], idxs[i, 1])
if not single_v:
cls.set_row_idxs(arr, _slice, v[[i]])
else:
cls.set_row_idxs(arr, _slice, v)
else:
arr[idxs] = v
@classmethod
def set_col_idxs(cls, arr: tp.Array, idxs: tp.MaybeIndexArray, v: tp.Any) -> None:
"""Set column indices in an array."""
from vectorbtpro.base.reshaping import broadcast_array_to
if not isinstance(v, np.ndarray):
v = np.asarray(v)
single_v = v.size == 1 or (v.ndim == 2 and v.shape[1] == 1)
if isinstance(idxs, np.ndarray) and idxs.ndim == 2:
if not single_v:
v = broadcast_array_to(v, (arr.shape[0], len(idxs)))
for j in range(len(idxs)):
_slice = slice(idxs[j, 0], idxs[j, 1])
if not single_v:
cls.set_col_idxs(arr, _slice, v[:, [j]])
else:
cls.set_col_idxs(arr, _slice, v)
else:
arr[:, idxs] = v
@classmethod
def set_row_and_col_idxs(
cls,
arr: tp.Array,
row_idxs: tp.MaybeIndexArray,
col_idxs: tp.MaybeIndexArray,
v: tp.Any,
) -> None:
"""Set row and column indices in an array."""
from vectorbtpro.base.reshaping import broadcast_array_to
if not isinstance(v, np.ndarray):
v = np.asarray(v)
single_v = v.size == 1
if (
isinstance(row_idxs, np.ndarray)
and row_idxs.ndim == 2
and isinstance(col_idxs, np.ndarray)
and col_idxs.ndim == 2
):
if not single_v:
v = broadcast_array_to(v, (len(row_idxs), len(col_idxs)))
for i in range(len(row_idxs)):
for j in range(len(col_idxs)):
row_slice = slice(row_idxs[i, 0], row_idxs[i, 1])
col_slice = slice(col_idxs[j, 0], col_idxs[j, 1])
if not single_v:
cls.set_row_and_col_idxs(arr, row_slice, col_slice, v[i, j])
else:
cls.set_row_and_col_idxs(arr, row_slice, col_slice, v)
elif isinstance(row_idxs, np.ndarray) and row_idxs.ndim == 2:
if not single_v:
if isinstance(col_idxs, slice):
col_idxs = np.arange(arr.shape[1])[col_idxs]
v = broadcast_array_to(v, (len(row_idxs), len(col_idxs)))
for i in range(len(row_idxs)):
row_slice = slice(row_idxs[i, 0], row_idxs[i, 1])
if not single_v:
cls.set_row_and_col_idxs(arr, row_slice, col_idxs, v[[i]])
else:
cls.set_row_and_col_idxs(arr, row_slice, col_idxs, v)
elif isinstance(col_idxs, np.ndarray) and col_idxs.ndim == 2:
if not single_v:
if isinstance(row_idxs, slice):
row_idxs = np.arange(arr.shape[0])[row_idxs]
v = broadcast_array_to(v, (len(row_idxs), len(col_idxs)))
for j in range(len(col_idxs)):
col_slice = slice(col_idxs[j, 0], col_idxs[j, 1])
if not single_v:
cls.set_row_and_col_idxs(arr, row_idxs, col_slice, v[:, [j]])
else:
cls.set_row_and_col_idxs(arr, row_idxs, col_slice, v)
else:
if np.isscalar(row_idxs) or np.isscalar(col_idxs):
arr[row_idxs, col_idxs] = v
elif np.isscalar(v) and (isinstance(row_idxs, slice) or isinstance(col_idxs, slice)):
arr[row_idxs, col_idxs] = v
elif np.isscalar(v):
arr[np.ix_(row_idxs, col_idxs)] = v
else:
if isinstance(row_idxs, slice):
row_idxs = np.arange(arr.shape[0])[row_idxs]
if isinstance(col_idxs, slice):
col_idxs = np.arange(arr.shape[1])[col_idxs]
v = broadcast_array_to(v, (len(row_idxs), len(col_idxs)))
arr[np.ix_(row_idxs, col_idxs)] = v
def get_set_meta(
self,
shape: tp.ShapeLike,
index: tp.Optional[tp.Index] = None,
columns: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
template_context: tp.KwargsLike = None,
) -> tp.Kwargs:
"""Get meta of setting operations in `IdxSetter.idx_items`."""
from vectorbtpro.base.reshaping import to_tuple_shape
shape = to_tuple_shape(shape)
rows_changed = False
cols_changed = False
set_funcs = []
default = None
for idxr, v in self.idx_items:
if isinstance(idxr, str) and idxr == "_def":
if default is None:
default = v
continue
row_idxs, col_idxs = get_idxs(
idxr,
index=index,
columns=columns,
freq=freq,
template_context=template_context,
)
if isinstance(v, CustomTemplate):
_template_context = merge_dicts(
dict(
idxr=idxr,
row_idxs=row_idxs,
col_idxs=col_idxs,
),
template_context,
)
v = v.substitute(_template_context, eval_id="set")
if not isinstance(v, np.ndarray):
v = np.asarray(v)
def _check_use_idxs(idxs):
use_idxs = True
if isinstance(idxs, slice):
if idxs.start is None and idxs.stop is None and idxs.step is None:
use_idxs = False
if isinstance(idxs, np.ndarray):
if idxs.size == 0:
use_idxs = False
return use_idxs
use_row_idxs = _check_use_idxs(row_idxs)
use_col_idxs = _check_use_idxs(col_idxs)
if use_row_idxs and use_col_idxs:
set_funcs.append(partial(self.set_row_and_col_idxs, row_idxs=row_idxs, col_idxs=col_idxs, v=v))
rows_changed = True
cols_changed = True
elif use_col_idxs:
set_funcs.append(partial(self.set_col_idxs, idxs=col_idxs, v=v))
if checks.is_int(col_idxs):
if v.size > 1:
rows_changed = True
else:
if v.ndim == 2:
if v.shape[0] > 1:
rows_changed = True
cols_changed = True
else:
set_funcs.append(partial(self.set_row_idxs, idxs=row_idxs, v=v))
if use_row_idxs:
rows_changed = True
if len(shape) == 2:
if checks.is_int(row_idxs):
if v.size > 1:
cols_changed = True
else:
if v.ndim == 2:
if v.shape[1] > 1:
cols_changed = True
return dict(
default=default,
set_funcs=set_funcs,
rows_changed=rows_changed,
cols_changed=cols_changed,
)
def set(self, arr: tp.Array, set_funcs: tp.Optional[tp.Sequence[tp.Callable]] = None, **kwargs) -> None:
"""Set values of a NumPy array based on `IdxSetter.get_set_meta`."""
if set_funcs is None:
set_meta = self.get_set_meta(arr.shape, **kwargs)
set_funcs = set_meta["set_funcs"]
for set_op in set_funcs:
set_op(arr)
def set_pd(self, pd_arr: tp.SeriesFrame, **kwargs) -> None:
"""Set values of a Pandas array based on `IdxSetter.get_set_meta`."""
from vectorbtpro.base.indexes import get_index
index = get_index(pd_arr, 0)
columns = get_index(pd_arr, 1)
freq = dt.infer_index_freq(index)
self.set(pd_arr.values, index=index, columns=columns, freq=freq, **kwargs)
def fill_and_set(
self,
shape: tp.ShapeLike,
keep_flex: bool = False,
fill_value: tp.Scalar = np.nan,
**kwargs,
) -> tp.Array:
"""Fill a new array and set its values based on `IdxSetter.get_set_meta`.
If `keep_flex` is True, will return the most memory-efficient array representation
capable of flexible indexing.
If the `_def` key is present in `IdxSetter.idx_items`, its value takes precedence
over `fill_value`, which defaults to NaN."""
set_meta = self.get_set_meta(shape, **kwargs)
if set_meta["default"] is not None:
fill_value = set_meta["default"]
if isinstance(fill_value, str):
dtype = object
else:
dtype = None
if keep_flex and not set_meta["cols_changed"] and not set_meta["rows_changed"]:
arr = np.full((1,) if len(shape) == 1 else (1, 1), fill_value, dtype=dtype)
elif keep_flex and not set_meta["cols_changed"]:
arr = np.full(shape if len(shape) == 1 else (shape[0], 1), fill_value, dtype=dtype)
elif keep_flex and not set_meta["rows_changed"]:
arr = np.full((1, shape[1]), fill_value, dtype=dtype)
else:
arr = np.full(shape, fill_value, dtype=dtype)
self.set(arr, set_funcs=set_meta["set_funcs"])
return arr
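# Illustrative example (not part of the original source): filling a new array and
# setting two positions. Shapes and expected outputs are assumptions for illustration.
#
# >>> import numpy as np
# >>> setter = IdxSetter([(0, 2.0), (2, 3.0)])
# >>> setter.fill_and_set((4,))
# expected: array([2., nan, 3., nan])
# >>> setter.fill_and_set((4,), fill_value=0.0)
# expected: array([2., 0., 3., 0.])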
class IdxSetterFactory(Base):
"""Class for building index setters."""
def get(self) -> tp.Union[IdxSetter, tp.Dict[tp.Label, IdxSetter]]:
"""Get an instance of `IdxSetter` or a dict of such instances - one per array name."""
raise NotImplementedError
@define
class IdxDict(IdxSetterFactory, DefineMixin):
"""Class for building an index setter from a dict."""
index_dct: dict = define.field()
"""Dict that contains indexer objects as keys and values to be set as values."""
def get(self) -> tp.Union[IdxSetter, tp.Dict[tp.Label, IdxSetter]]:
return IdxSetter(list(self.index_dct.items()))
@define
class IdxSeries(IdxSetterFactory, DefineMixin):
"""Class for building an index setter from a Series."""
sr: tp.AnyArray1d = define.field()
"""Series or any array-like object to create the Series from."""
split: bool = define.field(default=False)
"""Whether to split the setting operation.
If False, will set all values using a single operation.
Otherwise, will do one operation per element."""
idx_kwargs: tp.KwargsLike = define.field(default=None)
"""Keyword arguments passed to `idx` if the indexer isn't an instance of `Idxr`."""
def get(self) -> tp.Union[IdxSetter, tp.Dict[tp.Label, IdxSetter]]:
sr = self.sr
split = self.split
idx_kwargs = self.idx_kwargs
if idx_kwargs is None:
idx_kwargs = {}
if not isinstance(sr, pd.Series):
sr = pd.Series(sr)
if split:
idx_items = list(sr.items())
else:
idx_items = [(sr.index, sr.values)]
new_idx_items = []
for idxr, v in idx_items:
if idxr is None:
raise ValueError("Indexer cannot be None")
if not isinstance(idxr, Idxr):
idxr = idx(idxr, **idx_kwargs)
new_idx_items.append((idxr, v))
return IdxSetter(new_idx_items)
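# Illustrative example (not part of the original source): building an `IdxSetter`
# from a Series whose index labels act as row indexers and whose values get set.
#
# >>> import pandas as pd
# >>> sr = pd.Series([10.0, 20.0], index=[0, 3])
# >>> IdxSeries(sr).get()               # one setting operation for the whole Series
# >>> IdxSeries(sr, split=True).get()   # one setting operation per element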
@define
class IdxFrame(IdxSetterFactory, DefineMixin):
"""Class for building an index setter from a DataFrame."""
df: tp.AnyArray2d = define.field()
"""DataFrame or any array-like object to create the DataFrame from."""
split: tp.Union[bool, str] = define.field(default=False)
"""Whether to split the setting operation.
If False, will set all values using a single operation.
Otherwise, the following options are supported:
* 'columns': one operation per column
* 'rows': one operation per row
* True or 'elements': one operation per element"""
rowidx_kwargs: tp.KwargsLike = define.field(default=None)
"""Keyword arguments passed to `rowidx` if the indexer isn't an instance of `RowIdxr`."""
colidx_kwargs: tp.KwargsLike = define.field(default=None)
"""Keyword arguments passed to `colidx` if the indexer isn't an instance of `ColIdxr`."""
def get(self) -> tp.Union[IdxSetter, tp.Dict[tp.Label, IdxSetter]]:
df = self.df
split = self.split
rowidx_kwargs = self.rowidx_kwargs
colidx_kwargs = self.colidx_kwargs
if rowidx_kwargs is None:
rowidx_kwargs = {}
if colidx_kwargs is None:
colidx_kwargs = {}
if not isinstance(df, pd.DataFrame):
df = pd.DataFrame(df)
if isinstance(split, bool):
if split:
split = "elements"
else:
split = None
if split is not None:
if split.lower() == "columns":
idx_items = []
for col, sr in df.items():
idx_items.append((sr.index, col, sr.values))
elif split.lower() == "rows":
idx_items = []
for row, sr in df.iterrows():
idx_items.append((row, df.columns, sr.values))
elif split.lower() == "elements":
idx_items = []
for col, sr in df.items():
for row, v in sr.items():
idx_items.append((row, col, v))
else:
raise ValueError(f"Invalid split: '{split}'")
else:
idx_items = [(df.index, df.columns, df.values)]
new_idx_items = []
for row_idxr, col_idxr, v in idx_items:
if row_idxr is None:
raise ValueError("Row indexer cannot be None")
if col_idxr is None:
raise ValueError("Column indexer cannot be None")
if row_idxr is not None and not isinstance(row_idxr, RowIdxr):
row_idxr = rowidx(row_idxr, **rowidx_kwargs)
if col_idxr is not None and not isinstance(col_idxr, ColIdxr):
col_idxr = colidx(col_idxr, **colidx_kwargs)
new_idx_items.append((idx(row_idxr, col_idxr), v))
return IdxSetter(new_idx_items)
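# Illustrative example (not part of the original source): building an `IdxSetter`
# from a DataFrame whose index maps to rows and whose columns map to columns.
#
# >>> import pandas as pd
# >>> df = pd.DataFrame({"a": [1.0, 2.0], "b": [3.0, 4.0]}, index=[0, 5])
# >>> IdxFrame(df).get()                    # a single setting operation
# >>> IdxFrame(df, split="columns").get()   # one setting operation per column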
@define
class IdxRecords(IdxSetterFactory, DefineMixin):
"""Class for building index setters from records - one per field."""
records: tp.RecordsLike = define.field()
"""Series, DataFrame, or any sequence of mapping-like objects.
If a Series or DataFrame and the index is not a default range, the index will become a row field.
If a custom row field is provided, the index will be ignored."""
row_field: tp.Union[None, bool, tp.Label] = define.field(default=None)
"""Row field.
If None or True, will search for "row", "index", "open time", and "date" (case-insensitive).
If `IdxRecords.records` is a Series or DataFrame, will also include the index name
if the index is not a default range.
If a record doesn't have a row field, all rows will be set.
If there's no row and column field, the field value will become the default of the entire array."""
col_field: tp.Union[None, bool, tp.Label] = define.field(default=None)
"""Column field.
If None or True, will search for "col", "column", and "symbol" (case-insensitive).
If a record doesn't have a column field, all columns will be set.
If there's no row and column field, the field value will become the default of the entire array."""
rowidx_kwargs: tp.KwargsLike = define.field(default=None)
"""Keyword arguments passed to `rowidx` if the indexer isn't an instance of `RowIdxr`."""
colidx_kwargs: tp.KwargsLike = define.field(default=None)
"""Keyword arguments passed to `colidx` if the indexer isn't an instance of `ColIdxr`."""
def get(self) -> tp.Union[IdxSetter, tp.Dict[tp.Label, IdxSetter]]:
records = self.records
row_field = self.row_field
col_field = self.col_field
rowidx_kwargs = self.rowidx_kwargs
colidx_kwargs = self.colidx_kwargs
if rowidx_kwargs is None:
rowidx_kwargs = {}
if colidx_kwargs is None:
colidx_kwargs = {}
default_index = False
index_field = None
if isinstance(records, pd.Series):
records = records.to_frame()
if isinstance(records, pd.DataFrame):
records = records
if checks.is_default_index(records.index):
default_index = True
records = records.reset_index(drop=default_index)
if not default_index:
index_field = records.columns[0]
records = records.itertuples(index=False)
def _resolve_field_meta(fields):
_row_field = row_field
_row_kind = None
_col_field = col_field
_col_kind = None
row_fields = set()
col_fields = set()
for field in fields:
if isinstance(field, str) and index_field is not None and field == index_field:
row_fields.add((field, None))
if isinstance(field, str) and field.lower() in ("row", "index"):
row_fields.add((field, None))
if isinstance(field, str) and field.lower() in ("open time", "date", "datetime"):
if (field, None) in row_fields:
row_fields.remove((field, None))
row_fields.add((field, "datetime"))
if isinstance(field, str) and field.lower() in ("col", "column"):
col_fields.add((field, None))
if isinstance(field, str) and field.lower() == "symbol":
if (field, None) in col_fields:
col_fields.remove((field, None))
col_fields.add((field, "labels"))
if _row_field in (None, True):
if len(row_fields) == 0:
if _row_field is True:
raise ValueError("Cannot find row field")
_row_field = None
elif len(row_fields) == 1:
_row_field, _row_kind = row_fields.pop()
else:
raise ValueError("Multiple row field candidates")
elif _row_field is False:
_row_field = None
if _col_field in (None, True):
if len(col_fields) == 0:
if _col_field is True:
raise ValueError("Cannot find column field")
_col_field = None
elif len(col_fields) == 1:
_col_field, _col_kind = col_fields.pop()
else:
raise ValueError("Multiple column field candidates")
elif _col_field is False:
_col_field = None
field_meta = dict()
field_meta["row_field"] = _row_field
field_meta["row_kind"] = _row_kind
field_meta["col_field"] = _col_field
field_meta["col_kind"] = _col_kind
return field_meta
idx_items = dict()
for r in records:
r = to_field_mapping(r)
field_meta = _resolve_field_meta(r.keys())
if field_meta["row_field"] is None:
row_idxr = None
else:
row_idxr = r.get(field_meta["row_field"], None)
if row_idxr == "_def":
row_idxr = None
if row_idxr is not None and not isinstance(row_idxr, RowIdxr):
_rowidx_kwargs = dict(rowidx_kwargs)
if field_meta["row_kind"] is not None and "kind" not in _rowidx_kwargs:
_rowidx_kwargs["kind"] = field_meta["row_kind"]
row_idxr = rowidx(row_idxr, **_rowidx_kwargs)
if field_meta["col_field"] is None:
col_idxr = None
else:
col_idxr = r.get(field_meta["col_field"], None)
if isinstance(col_idxr, str) and col_idxr == "_def":
col_idxr = None
if col_idxr is not None and not isinstance(col_idxr, ColIdxr):
_colidx_kwargs = dict(colidx_kwargs)
if field_meta["col_kind"] is not None and "kind" not in _colidx_kwargs:
_colidx_kwargs["kind"] = field_meta["col_kind"]
col_idxr = colidx(col_idxr, **_colidx_kwargs)
item_produced = False
for k, v in r.items():
if index_field is not None and k == index_field:
continue
if field_meta["row_field"] is not None and k == field_meta["row_field"]:
continue
if field_meta["col_field"] is not None and k == field_meta["col_field"]:
continue
if k not in idx_items:
idx_items[k] = []
if row_idxr is None and col_idxr is None:
idx_items[k].append(("_def", v))
else:
idx_items[k].append((idx(row_idxr, col_idxr), v))
item_produced = True
if not item_produced:
raise ValueError(f"Record {r} has no fields to set")
idx_setters = dict()
for k, v in idx_items.items():
idx_setters[k] = IdxSetter(v)
return idx_setters
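# Illustrative example (not part of the original source): records are translated into
# one `IdxSetter` per value field. Field names follow the search rules documented above.
#
# >>> records = [
# ...     dict(row=0, col="a", size=1.0),
# ...     dict(row=3, col="b", size=2.0, price=100.0),
# ... ]
# >>> IdxRecords(records).get()
# expected: {"size": IdxSetter(...), "price": IdxSetter(...)}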
posidx = PosIdxr
"""Shortcut for `PosIdxr`."""
__pdoc__["posidx"] = False
maskidx = MaskIdxr
"""Shortcut for `MaskIdxr`."""
__pdoc__["maskidx"] = False
lbidx = LabelIdxr
"""Shortcut for `LabelIdxr`."""
__pdoc__["lbidx"] = False
dtidx = DatetimeIdxr
"""Shortcut for `DatetimeIdxr`."""
__pdoc__["dtidx"] = False
dtcidx = DTCIdxr
"""Shortcut for `DTCIdxr`."""
__pdoc__["dtcidx"] = False
pointidx = PointIdxr
"""Shortcut for `PointIdxr`."""
__pdoc__["pointidx"] = False
rangeidx = RangeIdxr
"""Shortcut for `RangeIdxr`."""
__pdoc__["rangeidx"] = False
autoidx = AutoIdxr
"""Shortcut for `AutoIdxr`."""
__pdoc__["autoidx"] = False
rowidx = RowIdxr
"""Shortcut for `RowIdxr`."""
__pdoc__["rowidx"] = False
colidx = ColIdxr
"""Shortcut for `ColIdxr`."""
__pdoc__["colidx"] = False
idx = Idxr
"""Shortcut for `Idxr`."""
__pdoc__["idx"] = False
</file>
<file path="base/merging.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Functions for merging arrays."""
from functools import partial
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.indexes import stack_indexes, concat_indexes, clean_index
from vectorbtpro.base.reshaping import to_1d_array, to_2d_array
from vectorbtpro.base.wrapping import ArrayWrapper, Wrapping
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import resolve_dict, merge_dicts, HybridConfig
from vectorbtpro.utils.execution import NoResult, NoResultsException, filter_out_no_results
from vectorbtpro.utils.merging import MergeFunc
__all__ = [
"concat_arrays",
"row_stack_arrays",
"column_stack_arrays",
"concat_merge",
"row_stack_merge",
"column_stack_merge",
"imageio_merge",
"mixed_merge",
]
__pdoc__ = {}
def concat_arrays(*arrs: tp.MaybeSequence[tp.AnyArray]) -> tp.Array1d:
"""Concatenate arrays."""
if len(arrs) == 1:
arrs = arrs[0]
arrs = list(arrs)
arrs = list(map(to_1d_array, arrs))
return np.concatenate(arrs)
def row_stack_arrays(*arrs: tp.MaybeSequence[tp.AnyArray], expand_axis: int = 1) -> tp.Array2d:
"""Stack arrays along rows."""
if len(arrs) == 1:
arrs = arrs[0]
arrs = list(arrs)
arrs = list(map(partial(to_2d_array, expand_axis=expand_axis), arrs))
return np.concatenate(arrs, axis=0)
def column_stack_arrays(*arrs: tp.MaybeSequence[tp.AnyArray], expand_axis: int = 1) -> tp.Array2d:
"""Stack arrays along columns."""
if len(arrs) == 1:
arrs = arrs[0]
arrs = list(arrs)
arrs = list(map(partial(to_2d_array, expand_axis=expand_axis), arrs))
common_shape = None
can_concatenate = True
for arr in arrs:
if common_shape is None:
common_shape = arr.shape
if arr.shape != common_shape:
can_concatenate = False
continue
if not (arr.ndim == 1 or (arr.ndim == 2 and arr.shape[1] == 1)):
can_concatenate = False
continue
if can_concatenate:
return np.concatenate(arrs, axis=0).reshape((len(arrs), common_shape[0])).T
return np.concatenate(arrs, axis=1)
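# Illustrative example (not part of the original source): the three low-level stacking
# helpers, assuming that 1-dim inputs are expanded into single columns by `to_2d_array`.
#
# >>> import numpy as np
# >>> concat_arrays([1, 2], [3, 4])
# expected: array([1, 2, 3, 4])
# >>> row_stack_arrays([1, 2], [3, 4])
# expected: array([[1], [2], [3], [4]])
# >>> column_stack_arrays([1, 2], [3, 4])
# expected: array([[1, 3], [2, 4]])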
def concat_merge(
*objs,
keys: tp.Optional[tp.Index] = None,
filter_results: bool = True,
raise_no_results: bool = True,
wrap: tp.Optional[bool] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLikeSequence = None,
clean_index_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeTuple[tp.AnyArray]:
"""Merge multiple array-like objects through concatenation.
Supports a sequence of tuples.
If `wrap` is None, it will become True if `wrapper`, `keys`, or `wrap_kwargs` are not None.
If `wrap` is True, each array will be wrapped with Pandas Series and merged using `pd.concat`.
Otherwise, arrays will be kept as-is and merged using `concat_arrays`.
`wrap_kwargs` can be a dictionary or a list of dictionaries.
If `wrapper` is provided, will use `vectorbtpro.base.wrapping.ArrayWrapper.wrap_reduced`.
Keyword arguments `**kwargs` are passed to `pd.concat` only.
!!! note
All arrays are assumed to have the same type and dimensionality."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
if len(objs) == 0:
raise ValueError("No objects to be merged")
if isinstance(objs[0], tuple):
if len(objs[0]) == 1:
out_tuple = (
concat_merge(
list(map(lambda x: x[0], objs)),
keys=keys,
wrap=wrap,
wrapper=wrapper,
wrap_kwargs=wrap_kwargs,
**kwargs,
),
)
else:
out_tuple = tuple(
map(
lambda x: concat_merge(
x,
keys=keys,
wrap=wrap,
wrapper=wrapper,
wrap_kwargs=wrap_kwargs,
**kwargs,
),
zip(*objs),
)
)
if checks.is_namedtuple(objs[0]):
return type(objs[0])(*out_tuple)
return type(objs[0])(out_tuple)
if filter_results:
try:
objs, keys = filter_out_no_results(objs, keys=keys)
except NoResultsException as e:
if raise_no_results:
raise e
return NoResult
if isinstance(objs[0], Wrapping):
raise TypeError("Concatenating Wrapping instances is not supported")
if wrap_kwargs is None:
wrap_kwargs = {}
if wrap is None:
wrap = isinstance(objs[0], pd.Series) or wrapper is not None or keys is not None or len(wrap_kwargs) > 0
if not checks.is_complex_iterable(objs[0]):
if wrap:
if keys is not None and isinstance(keys[0], pd.Index):
if len(keys) == 1:
keys = keys[0]
else:
keys = concat_indexes(
*keys,
index_concat_method="append",
clean_index_kwargs=clean_index_kwargs,
verify_integrity=False,
axis=0,
)
wrap_kwargs = merge_dicts(dict(index=keys), wrap_kwargs)
return pd.Series(objs, **wrap_kwargs)
return np.asarray(objs)
if isinstance(objs[0], pd.Index):
objs = list(map(lambda x: x.to_series(), objs))
default_index = True
if not isinstance(objs[0], pd.Series):
if isinstance(objs[0], pd.DataFrame):
raise ValueError("Use row stacking for concatenating DataFrames")
if wrap:
new_objs = []
for i, obj in enumerate(objs):
_wrap_kwargs = resolve_dict(wrap_kwargs, i)
if wrapper is not None:
if "force_1d" not in _wrap_kwargs:
_wrap_kwargs["force_1d"] = True
new_objs.append(wrapper.wrap_reduced(obj, **_wrap_kwargs))
else:
new_objs.append(pd.Series(obj, **_wrap_kwargs))
if default_index and not checks.is_default_index(new_objs[-1].index, check_names=True):
default_index = False
objs = new_objs
if not wrap:
return concat_arrays(objs)
if keys is not None and isinstance(keys[0], pd.Index):
new_obj = pd.concat(objs, axis=0, **kwargs)
if len(keys) == 1:
keys = keys[0]
else:
keys = concat_indexes(
*keys,
index_concat_method="append",
verify_integrity=False,
axis=0,
)
if default_index:
new_obj.index = keys
else:
new_obj.index = stack_indexes((keys, new_obj.index))
else:
new_obj = pd.concat(objs, axis=0, keys=keys, **kwargs)
if clean_index_kwargs is None:
clean_index_kwargs = {}
new_obj.index = clean_index(new_obj.index, **clean_index_kwargs)
return new_obj
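# Illustrative example (not part of the original source): concatenating chunked results
# with and without wrapping. The key names below are assumptions for illustration.
#
# >>> import numpy as np, pandas as pd
# >>> concat_merge([np.array([1, 2]), np.array([3, 4])])
# expected: array([1, 2, 3, 4]) -- no wrapping requested
# >>> concat_merge([np.array([1, 2]), np.array([3, 4])], keys=pd.Index(["x", "y"], name="chunk"))
# expected: a Series whose index pairs each "chunk" key with the position within its array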
def row_stack_merge(
*objs,
keys: tp.Optional[tp.Index] = None,
filter_results: bool = True,
raise_no_results: bool = True,
wrap: tp.Union[None, str, bool] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLikeSequence = None,
clean_index_kwargs: tp.KwargsLikeSequence = None,
**kwargs,
) -> tp.MaybeTuple[tp.AnyArray]:
"""Merge multiple array-like or `vectorbtpro.base.wrapping.Wrapping` objects through row stacking.
Supports a sequence of tuples.
Argument `wrap` supports the following options:
* None: will become True if `wrapper`, `keys`, or `wrap_kwargs` are not None
* True: each array will be wrapped with Pandas Series/DataFrame (depending on dimensions)
* 'sr', 'series': each array will be wrapped with Pandas Series
* 'df', 'frame', 'dataframe': each array will be wrapped with Pandas DataFrame
Without wrapping, arrays will be kept as-is and merged using `row_stack_arrays`.
Argument `wrap_kwargs` can be a dictionary or a list of dictionaries.
If `wrapper` is provided, will use `vectorbtpro.base.wrapping.ArrayWrapper.wrap`.
Keyword arguments `**kwargs` are passed to `pd.concat` and
`vectorbtpro.base.wrapping.Wrapping.row_stack` only.
!!! note
All arrays are assumed to have the same type and dimensionality."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
if len(objs) == 0:
raise ValueError("No objects to be merged")
if isinstance(objs[0], tuple):
if len(objs[0]) == 1:
out_tuple = (
row_stack_merge(
list(map(lambda x: x[0], objs)),
keys=keys,
wrap=wrap,
wrapper=wrapper,
wrap_kwargs=wrap_kwargs,
**kwargs,
),
)
else:
out_tuple = tuple(
map(
lambda x: row_stack_merge(
x,
keys=keys,
wrap=wrap,
wrapper=wrapper,
wrap_kwargs=wrap_kwargs,
**kwargs,
),
zip(*objs),
)
)
if checks.is_namedtuple(objs[0]):
return type(objs[0])(*out_tuple)
return type(objs[0])(out_tuple)
if filter_results:
try:
objs, keys = filter_out_no_results(objs, keys=keys)
except NoResultsException as e:
if raise_no_results:
raise e
return NoResult
if isinstance(objs[0], Wrapping):
kwargs = merge_dicts(dict(wrapper_kwargs=dict(keys=keys)), kwargs)
return type(objs[0]).row_stack(objs, **kwargs)
if wrap_kwargs is None:
wrap_kwargs = {}
if wrap is None:
wrap = (
isinstance(objs[0], (pd.Series, pd.DataFrame))
or wrapper is not None
or keys is not None
or len(wrap_kwargs) > 0
)
if isinstance(objs[0], pd.Index):
objs = list(map(lambda x: x.to_series(), objs))
default_index = True
if not isinstance(objs[0], (pd.Series, pd.DataFrame)):
if isinstance(wrap, str) or wrap:
new_objs = []
for i, obj in enumerate(objs):
_wrap_kwargs = resolve_dict(wrap_kwargs, i)
if wrapper is not None:
new_objs.append(wrapper.wrap(obj, **_wrap_kwargs))
else:
if not isinstance(wrap, str):
if isinstance(obj, np.ndarray):
ndim = obj.ndim
else:
ndim = np.asarray(obj).ndim
if ndim == 1:
wrap = "series"
else:
wrap = "frame"
if isinstance(wrap, str):
if wrap.lower() in ("sr", "series"):
new_objs.append(pd.Series(obj, **_wrap_kwargs))
elif wrap.lower() in ("df", "frame", "dataframe"):
new_objs.append(pd.DataFrame(obj, **_wrap_kwargs))
else:
raise ValueError(f"Invalid wrapping option: '{wrap}'")
if default_index and not checks.is_default_index(new_objs[-1].index, check_names=True):
default_index = False
objs = new_objs
if not wrap:
return row_stack_arrays(objs)
if keys is not None and isinstance(keys[0], pd.Index):
new_obj = pd.concat(objs, axis=0, **kwargs)
if len(keys) == 1:
keys = keys[0]
else:
keys = concat_indexes(
*keys,
index_concat_method="append",
verify_integrity=False,
axis=0,
)
if default_index:
new_obj.index = keys
else:
new_obj.index = stack_indexes((keys, new_obj.index))
else:
new_obj = pd.concat(objs, axis=0, keys=keys, **kwargs)
if clean_index_kwargs is None:
clean_index_kwargs = {}
new_obj.index = clean_index(new_obj.index, **clean_index_kwargs)
return new_obj
def column_stack_merge(
*objs,
reset_index: tp.Union[None, bool, str] = None,
fill_value: tp.Scalar = np.nan,
keys: tp.Optional[tp.Index] = None,
filter_results: bool = True,
raise_no_results: bool = True,
wrap: tp.Union[None, str, bool] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLikeSequence = None,
clean_index_kwargs: tp.KwargsLikeSequence = None,
**kwargs,
) -> tp.MaybeTuple[tp.AnyArray]:
"""Merge multiple array-like or `vectorbtpro.base.wrapping.Wrapping` objects through column stacking.
Supports a sequence of tuples.
Argument `wrap` supports the following options:
* None: will become True if `wrapper`, `keys`, or `wrap_kwargs` are not None
* True: each array will be wrapped with Pandas Series/DataFrame (depending on dimensions)
* 'sr', 'series': each array will be wrapped with Pandas Series
* 'df', 'frame', 'dataframe': each array will be wrapped with Pandas DataFrame
Without wrapping, arrays will be kept as-is and merged using `column_stack_arrays`.
Argument `wrap_kwargs` can be a dictionary or a list of dictionaries.
If `wrapper` is provided, will use `vectorbtpro.base.wrapping.ArrayWrapper.wrap`.
Keyword arguments `**kwargs` are passed to `pd.concat` and
`vectorbtpro.base.wrapping.Wrapping.column_stack` only.
Argument `reset_index` supports the following options:
* False or None: Keep original index of each object
* True or 'from_start': Reset index of each object and align them at start
* 'from_end': Reset index of each object and align them at end
Options above work on Pandas, NumPy, and `vectorbtpro.base.wrapping.Wrapping` instances.
!!! note
All arrays are assumed to have the same type and dimensionality."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
if len(objs) == 0:
raise ValueError("No objects to be merged")
if isinstance(reset_index, bool):
if reset_index:
reset_index = "from_start"
else:
reset_index = None
if isinstance(objs[0], tuple):
if len(objs[0]) == 1:
out_tuple = (
column_stack_merge(
list(map(lambda x: x[0], objs)),
reset_index=reset_index,
keys=keys,
wrap=wrap,
wrapper=wrapper,
wrap_kwargs=wrap_kwargs,
**kwargs,
),
)
else:
out_tuple = tuple(
map(
lambda x: column_stack_merge(
x,
reset_index=reset_index,
keys=keys,
wrap=wrap,
wrapper=wrapper,
wrap_kwargs=wrap_kwargs,
**kwargs,
),
zip(*objs),
)
)
if checks.is_namedtuple(objs[0]):
return type(objs[0])(*out_tuple)
return type(objs[0])(out_tuple)
if filter_results:
try:
objs, keys = filter_out_no_results(objs, keys=keys)
except NoResultsException as e:
if raise_no_results:
raise e
return NoResult
if isinstance(objs[0], Wrapping):
if reset_index is not None:
max_length = max(map(lambda x: x.wrapper.shape[0], objs))
new_objs = []
for obj in objs:
if isinstance(reset_index, str) and reset_index.lower() == "from_start":
new_index = pd.RangeIndex(stop=obj.wrapper.shape[0])
new_obj = obj.replace(wrapper=obj.wrapper.replace(index=new_index))
elif isinstance(reset_index, str) and reset_index.lower() == "from_end":
new_index = pd.RangeIndex(start=max_length - obj.wrapper.shape[0], stop=max_length)
new_obj = obj.replace(wrapper=obj.wrapper.replace(index=new_index))
else:
raise ValueError(f"Invalid index resetting option: '{reset_index}'")
new_objs.append(new_obj)
objs = new_objs
kwargs = merge_dicts(dict(wrapper_kwargs=dict(keys=keys)), kwargs)
return type(objs[0]).column_stack(objs, **kwargs)
if wrap_kwargs is None:
wrap_kwargs = {}
if wrap is None:
wrap = (
isinstance(objs[0], (pd.Series, pd.DataFrame))
or wrapper is not None
or keys is not None
or len(wrap_kwargs) > 0
)
if isinstance(objs[0], pd.Index):
objs = list(map(lambda x: x.to_series(), objs))
default_columns = True
if not isinstance(objs[0], (pd.Series, pd.DataFrame)):
if isinstance(wrap, str) or wrap:
new_objs = []
for i, obj in enumerate(objs):
_wrap_kwargs = resolve_dict(wrap_kwargs, i)
if wrapper is not None:
new_objs.append(wrapper.wrap(obj, **_wrap_kwargs))
else:
if not isinstance(wrap, str):
if isinstance(obj, np.ndarray):
ndim = obj.ndim
else:
ndim = np.asarray(obj).ndim
if ndim == 1:
wrap = "series"
else:
wrap = "frame"
if isinstance(wrap, str):
if wrap.lower() in ("sr", "series"):
new_objs.append(pd.Series(obj, **_wrap_kwargs))
elif wrap.lower() in ("df", "frame", "dataframe"):
new_objs.append(pd.DataFrame(obj, **_wrap_kwargs))
else:
raise ValueError(f"Invalid wrapping option: '{wrap}'")
if (
default_columns
and isinstance(new_objs[-1], pd.DataFrame)
and not checks.is_default_index(new_objs[-1].columns, check_names=True)
):
default_columns = False
objs = new_objs
if not wrap:
if reset_index is not None:
min_n_rows = None
max_n_rows = None
n_cols = 0
new_objs = []
for obj in objs:
new_obj = to_2d_array(obj)
new_objs.append(new_obj)
if min_n_rows is None or new_obj.shape[0] < min_n_rows:
min_n_rows = new_obj.shape[0]
if max_n_rows is None or new_obj.shape[0] > max_n_rows:
max_n_rows = new_obj.shape[0]
n_cols += new_obj.shape[1]
if min_n_rows == max_n_rows:
return column_stack_arrays(new_objs)
new_obj = np.full((max_n_rows, n_cols), fill_value)
start_col = 0
for obj in new_objs:
end_col = start_col + obj.shape[1]
if isinstance(reset_index, str) and reset_index.lower() == "from_start":
new_obj[: len(obj), start_col:end_col] = obj
elif isinstance(reset_index, str) and reset_index.lower() == "from_end":
new_obj[-len(obj) :, start_col:end_col] = obj
else:
raise ValueError(f"Invalid index resetting option: '{reset_index}'")
start_col = end_col
return new_obj
return column_stack_arrays(objs)
if reset_index is not None:
max_length = max(map(len, objs))
new_objs = []
for obj in objs:
new_obj = obj.copy(deep=False)
if isinstance(reset_index, str) and reset_index.lower() == "from_start":
new_obj.index = pd.RangeIndex(stop=len(new_obj))
elif isinstance(reset_index, str) and reset_index.lower() == "from_end":
new_obj.index = pd.RangeIndex(start=max_length - len(new_obj), stop=max_length)
else:
raise ValueError(f"Invalid index resetting option: '{reset_index}'")
new_objs.append(new_obj)
objs = new_objs
kwargs = merge_dicts(dict(sort=True), kwargs)
if keys is not None and isinstance(keys[0], pd.Index):
new_obj = pd.concat(objs, axis=1, **kwargs)
if len(keys) == 1:
keys = keys[0]
else:
keys = concat_indexes(
*keys,
index_concat_method="append",
verify_integrity=False,
axis=1,
)
if default_columns:
new_obj.columns = keys
else:
new_obj.columns = stack_indexes((keys, new_obj.columns))
else:
new_obj = pd.concat(objs, axis=1, keys=keys, **kwargs)
if clean_index_kwargs is None:
clean_index_kwargs = {}
new_obj.columns = clean_index(new_obj.columns, **clean_index_kwargs)
return new_obj
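# Illustrative example (not part of the original source): column-stacking two Series of
# unequal length. Without `reset_index`, their indexes are joined; with "from_end",
# the shorter object is assumed to be re-indexed so that it aligns at the end.
#
# >>> import pandas as pd
# >>> sr1 = pd.Series([1.0, 2.0, 3.0])
# >>> sr2 = pd.Series([4.0, 5.0])
# >>> column_stack_merge([sr1, sr2], keys=pd.Index(["a", "b"]))
# expected: column "b" has NaN in the last row
# >>> column_stack_merge([sr1, sr2], reset_index="from_end", keys=pd.Index(["a", "b"]))
# expected: column "b" has NaN in the first row instead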
def imageio_merge(
*objs,
keys: tp.Optional[tp.Index] = None,
filter_results: bool = True,
raise_no_results: bool = True,
to_image_kwargs: tp.KwargsLike = None,
imread_kwargs: tp.KwargsLike = None,
**imwrite_kwargs,
) -> tp.MaybeTuple[tp.Union[None, bytes]]:
"""Merge multiple figure-like objects by writing them with `imageio`.
Keyword arguments `to_image_kwargs` are passed to `plotly.graph_objects.Figure.to_image`.
Keyword arguments `imread_kwargs` and `**imwrite_kwargs` are passed to
`imageio.imread` and `imageio.imwrite` respectively.
Keys are not used in any way."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
import plotly.graph_objects as go
import imageio.v3 as iio
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
if len(objs) == 0:
raise ValueError("No objects to be merged")
if isinstance(objs[0], tuple):
if len(objs[0]) == 1:
out_tuple = (
imageio_merge(
list(map(lambda x: x[0], objs)),
keys=keys,
imread_kwargs=imread_kwargs,
to_image_kwargs=to_image_kwargs,
**imwrite_kwargs,
),
)
else:
out_tuple = tuple(
map(
lambda x: imageio_merge(
x,
keys=keys,
imread_kwargs=imread_kwargs,
to_image_kwargs=to_image_kwargs,
**imwrite_kwargs,
),
zip(*objs),
)
)
if checks.is_namedtuple(objs[0]):
return type(objs[0])(*out_tuple)
return type(objs[0])(out_tuple)
if filter_results:
try:
objs, keys = filter_out_no_results(objs, keys=keys)
except NoResultsException as e:
if raise_no_results:
raise e
return NoResult
if imread_kwargs is None:
imread_kwargs = {}
if to_image_kwargs is None:
to_image_kwargs = {}
frames = []
for obj in objs:
if obj is None:
continue
if isinstance(obj, (go.Figure, go.FigureWidget)):
obj = obj.to_image(**to_image_kwargs)
if not isinstance(obj, np.ndarray):
obj = iio.imread(obj, **imread_kwargs)
frames.append(obj)
return iio.imwrite(image=frames, **imwrite_kwargs)
def mixed_merge(
*objs,
merge_funcs: tp.Optional[tp.MergeFuncLike] = None,
mixed_kwargs: tp.Optional[tp.Sequence[tp.KwargsLike]] = None,
**kwargs,
) -> tp.MaybeTuple[tp.AnyArray]:
"""Merge objects of mixed types."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
if len(objs) == 0:
raise ValueError("No objects to be merged")
if merge_funcs is None:
raise ValueError("Merging functions or their names are required")
if not isinstance(objs[0], tuple):
raise ValueError("Mixed merging must be applied on tuples")
outputs = []
for i, output_objs in enumerate(zip(*objs)):
output_objs = list(output_objs)
merge_func = resolve_merge_func(merge_funcs[i])
if merge_func is None:
outputs.append(output_objs)
else:
if mixed_kwargs is None:
_kwargs = kwargs
else:
_kwargs = merge_dicts(kwargs, mixed_kwargs[i])
output = merge_func(output_objs, **_kwargs)
outputs.append(output)
return tuple(outputs)
merge_func_config = HybridConfig(
dict(
concat=concat_merge,
row_stack=row_stack_merge,
column_stack=column_stack_merge,
reset_column_stack=partial(column_stack_merge, reset_index=True),
from_start_column_stack=partial(column_stack_merge, reset_index="from_start"),
from_end_column_stack=partial(column_stack_merge, reset_index="from_end"),
imageio=imageio_merge,
)
)
"""_"""
__pdoc__[
"merge_func_config"
] = f"""Config for merging functions.
```python
{merge_func_config.prettify()}
```
"""
def resolve_merge_func(merge_func: tp.MergeFuncLike) -> tp.Optional[tp.Callable]:
"""Resolve a merging function into a callable.
If a string, looks up into `merge_func_config`. If a sequence, uses `mixed_merge` with
`merge_funcs=merge_func`. If an instance of `vectorbtpro.utils.merging.MergeFunc`, calls
`vectorbtpro.utils.merging.MergeFunc.resolve_merge_func` to get the actual callable."""
if merge_func is None:
return None
if isinstance(merge_func, str):
if merge_func.lower() not in merge_func_config:
raise ValueError(f"Invalid merging function name: '{merge_func}'")
return merge_func_config[merge_func.lower()]
if checks.is_sequence(merge_func):
return partial(mixed_merge, merge_funcs=merge_func)
if isinstance(merge_func, MergeFunc):
return merge_func.resolve_merge_func()
return merge_func
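# Illustrative example (not part of the original source): resolving merging functions.
#
# >>> resolve_merge_func("concat") is concat_merge
# expected: True
# >>> resolve_merge_func(("concat", "column_stack"))
# expected: functools.partial(mixed_merge, merge_funcs=("concat", "column_stack"))
# >>> resolve_merge_func(None) is None
# expected: True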
def is_merge_func_from_config(merge_func: tp.MergeFuncLike) -> bool:
"""Return whether the merging function can be found in `merge_func_config`."""
if merge_func is None:
return False
if isinstance(merge_func, str):
return merge_func.lower() in merge_func_config
if checks.is_sequence(merge_func):
return all(map(is_merge_func_from_config, merge_func))
if isinstance(merge_func, MergeFunc):
return is_merge_func_from_config(merge_func.merge_func)
return False
</file>
<file path="base/preparing.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Classes for preparing arguments."""
import inspect
import string
from collections import defaultdict
from datetime import timedelta, time
from functools import cached_property as cachedproperty
from pathlib import Path
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.decorators import override_arg_config, attach_arg_properties
from vectorbtpro.base.indexes import repeat_index
from vectorbtpro.base.indexing import index_dict, IdxSetter, IdxSetterFactory, IdxRecords
from vectorbtpro.base.merging import concat_arrays, column_stack_arrays
from vectorbtpro.base.resampling.base import Resampler
from vectorbtpro.base.reshaping import BCO, Default, Ref, broadcast
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.attr_ import get_dict_attr
from vectorbtpro.utils.config import Configured
from vectorbtpro.utils.config import merge_dicts, Config, ReadonlyConfig, HybridConfig
from vectorbtpro.utils.cutting import suggest_module_path, cut_and_save_func
from vectorbtpro.utils.enum_ import map_enum_fields
from vectorbtpro.utils.module_ import import_module_from_path
from vectorbtpro.utils.params import Param
from vectorbtpro.utils.parsing import get_func_arg_names
from vectorbtpro.utils.path_ import remove_dir
from vectorbtpro.utils.random_ import set_seed
from vectorbtpro.utils.template import CustomTemplate, RepFunc, substitute_templates
__all__ = [
"BasePreparer",
]
__pdoc__ = {}
base_arg_config = ReadonlyConfig(
dict(
broadcast_named_args=dict(is_dict=True),
broadcast_kwargs=dict(is_dict=True),
template_context=dict(is_dict=True),
seed=dict(),
jitted=dict(),
chunked=dict(),
staticized=dict(),
records=dict(),
)
)
"""_"""
__pdoc__[
"base_arg_config"
] = f"""Argument config for `BasePreparer`.
```python
{base_arg_config.prettify()}
```
"""
class MetaBasePreparer(type(Configured)):
"""Metaclass for `BasePreparer`."""
@property
def arg_config(cls) -> Config:
"""Argument config."""
return cls._arg_config
@attach_arg_properties
@override_arg_config(base_arg_config)
class BasePreparer(Configured, metaclass=MetaBasePreparer):
"""Base class for preparing target functions and arguments.
!!! warning
Most properties are force-cached - create a new instance to override any attribute."""
_expected_keys_mode: tp.ExpectedKeysMode = "disable"
_writeable_attrs: tp.WriteableAttrs = {"_arg_config"}
_settings_path: tp.SettingsPath = None
def __init__(self, arg_config: tp.KwargsLike = None, **kwargs) -> None:
Configured.__init__(self, arg_config=arg_config, **kwargs)
# Copy writeable attrs
self._arg_config = type(self)._arg_config.copy()
if arg_config is not None:
self._arg_config = merge_dicts(self._arg_config, arg_config)
_arg_config: tp.ClassVar[Config] = HybridConfig()
@property
def arg_config(self) -> Config:
"""Argument config of `${cls_name}`.
```python
${arg_config}
```
"""
return self._arg_config
@classmethod
def map_enum_value(cls, value: tp.ArrayLike, look_for_type: tp.Optional[type] = None, **kwargs) -> tp.ArrayLike:
"""Map enumerated value(s)."""
if look_for_type is not None:
if isinstance(value, look_for_type):
return map_enum_fields(value, **kwargs)
return value
if isinstance(value, (CustomTemplate, Ref)):
return value
if isinstance(value, (Param, BCO, Default)):
attr_dct = value.asdict()
if isinstance(value, Param) and attr_dct["map_template"] is None:
attr_dct["map_template"] = RepFunc(lambda values: cls.map_enum_value(values, **kwargs))
elif not isinstance(value, Param):
attr_dct["value"] = cls.map_enum_value(attr_dct["value"], **kwargs)
return type(value)(**attr_dct)
if isinstance(value, index_dict):
return index_dict({k: cls.map_enum_value(v, **kwargs) for k, v in value.items()})
if isinstance(value, IdxSetterFactory):
value = value.get()
if not isinstance(value, IdxSetter):
raise ValueError("Index setter factory must return exactly one index setter")
if isinstance(value, IdxSetter):
return IdxSetter([(k, cls.map_enum_value(v, **kwargs)) for k, v in value.idx_items])
return map_enum_fields(value, **kwargs)
@classmethod
def prepare_td_obj(cls, td_obj: object, old_as_keys: bool = True) -> object:
"""Prepare a timedelta object for broadcasting."""
if isinstance(td_obj, Param):
return td_obj.map_value(cls.prepare_td_obj, old_as_keys=old_as_keys)
if isinstance(td_obj, (str, timedelta, pd.DateOffset, pd.Timedelta)):
td_obj = dt.to_timedelta64(td_obj)
elif isinstance(td_obj, pd.Index):
td_obj = td_obj.values
return td_obj
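# Illustrative example (not part of the original source): a sketch of how timedelta-like
# inputs are normalized for broadcasting. The exact return values are assumptions.
#
# >>> import pandas as pd
# >>> BasePreparer.prepare_td_obj("1h")
# expected: a np.timedelta64 covering one hour
# >>> BasePreparer.prepare_td_obj(pd.timedelta_range("1h", periods=2))
# expected: the underlying np.timedelta64 values of the index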
@classmethod
def prepare_dt_obj(
cls,
dt_obj: object,
old_as_keys: bool = True,
last_before: tp.Optional[bool] = None,
) -> object:
"""Prepare a datetime object for broadcasting."""
if isinstance(dt_obj, Param):
return dt_obj.map_value(cls.prepare_dt_obj, old_as_keys=old_as_keys)
if isinstance(dt_obj, (str, time, timedelta, pd.DateOffset, pd.Timedelta)):
def _apply_last_before(source_index, target_index, source_freq):
resampler = Resampler(source_index, target_index, source_freq=source_freq)
last_indices = resampler.last_before_target_index(incl_source=False)
source_rbound_ns = resampler.source_rbound_index.vbt.to_ns()
return np.where(last_indices != -1, source_rbound_ns[last_indices], -1)
def _to_dt(wrapper, _dt_obj=dt_obj, _last_before=last_before):
if _last_before is None:
_last_before = False
_dt_obj = dt.try_align_dt_to_index(_dt_obj, wrapper.index)
source_index = wrapper.index[wrapper.index < _dt_obj]
target_index = repeat_index(pd.Index([_dt_obj]), len(source_index))
if _last_before:
target_ns = _apply_last_before(source_index, target_index, wrapper.freq)
else:
target_ns = target_index.vbt.to_ns()
if len(target_ns) < len(wrapper.index):
target_ns = concat_arrays((target_ns, np.full(len(wrapper.index) - len(target_ns), -1)))
return target_ns
def _to_td(wrapper, _dt_obj=dt_obj, _last_before=last_before):
if _last_before is None:
_last_before = True
target_index = wrapper.index.vbt.to_period_ts(dt.to_freq(_dt_obj), shift=True)
if _last_before:
return _apply_last_before(wrapper.index, target_index, wrapper.freq)
return target_index.vbt.to_ns()
def _to_time(wrapper, _dt_obj=dt_obj, _last_before=last_before):
if _last_before is None:
_last_before = False
floor_index = wrapper.index.floor("1d") + dt.time_to_timedelta(_dt_obj)
target_index = floor_index.where(wrapper.index < floor_index, floor_index + pd.Timedelta(days=1))
if _last_before:
return _apply_last_before(wrapper.index, target_index, wrapper.freq)
return target_index.vbt.to_ns()
dt_obj_dt_template = RepFunc(_to_dt)
dt_obj_td_template = RepFunc(_to_td)
dt_obj_time_template = RepFunc(_to_time)
if isinstance(dt_obj, str):
try:
time.fromisoformat(dt_obj)
dt_obj = dt_obj_time_template
except Exception as e:
try:
dt.to_freq(dt_obj)
dt_obj = dt_obj_td_template
except Exception as e:
dt_obj = dt_obj_dt_template
elif isinstance(dt_obj, time):
dt_obj = dt_obj_time_template
else:
dt_obj = dt_obj_td_template
elif isinstance(dt_obj, pd.Index):
dt_obj = dt_obj.values
return dt_obj
def get_raw_arg_default(self, arg_name: str, is_dict: bool = False) -> tp.Any:
"""Get raw argument default."""
if self._settings_path is None:
if is_dict:
return {}
return None
value = self.get_setting(arg_name)
if is_dict and value is None:
return {}
return value
def get_raw_arg(self, arg_name: str, is_dict: bool = False, has_default: bool = True) -> tp.Any:
"""Get raw argument."""
value = self.config.get(arg_name, None)
if is_dict:
if has_default:
return merge_dicts(self.get_raw_arg_default(arg_name), value)
if value is None:
return {}
return value
if value is None and has_default:
return self.get_raw_arg_default(arg_name)
return value
@cachedproperty
def idx_setters(self) -> tp.Optional[tp.Dict[tp.Label, IdxSetter]]:
"""Index setters from resolving the argument `records`."""
arg_config = self.arg_config["records"]
records = self.get_raw_arg(
"records",
is_dict=arg_config.get("is_dict", False),
has_default=arg_config.get("has_default", True),
)
if records is None:
return None
if not isinstance(records, IdxRecords):
records = IdxRecords(records)
idx_setters = records.get()
for k in idx_setters:
if k in self.arg_config and not self.arg_config[k].get("broadcast", False):
raise ValueError(f"Field {k} is not broadcastable and cannot be included in records")
rename_fields = arg_config.get("rename_fields", {})
new_idx_setters = {}
for k, v in idx_setters.items():
if k in rename_fields:
k = rename_fields[k]
new_idx_setters[k] = v
return new_idx_setters
def get_arg_default(self, arg_name: str) -> tp.Any:
"""Get argument default according to the argument config."""
arg_config = self.arg_config[arg_name]
arg = self.get_raw_arg_default(
arg_name,
is_dict=arg_config.get("is_dict", False),
)
if arg is not None:
if len(arg_config.get("map_enum_kwargs", {})) > 0:
arg = self.map_enum_value(arg, **arg_config["map_enum_kwargs"])
if arg_config.get("is_td", False):
arg = self.prepare_td_obj(
arg,
old_as_keys=arg_config.get("old_as_keys", True),
)
if arg_config.get("is_dt", False):
arg = self.prepare_dt_obj(
arg,
old_as_keys=arg_config.get("old_as_keys", True),
last_before=arg_config.get("last_before", None),
)
return arg
def get_arg(self, arg_name: str, use_idx_setter: bool = True, use_default: bool = True) -> tp.Any:
"""Get mapped argument according to the argument config."""
arg_config = self.arg_config[arg_name]
if use_idx_setter and self.idx_setters is not None and arg_name in self.idx_setters:
arg = self.idx_setters[arg_name]
else:
arg = self.get_raw_arg(
arg_name,
is_dict=arg_config.get("is_dict", False),
has_default=arg_config.get("has_default", True) if use_default else False,
)
if arg is not None:
if len(arg_config.get("map_enum_kwargs", {})) > 0:
arg = self.map_enum_value(arg, **arg_config["map_enum_kwargs"])
if arg_config.get("is_td", False):
arg = self.prepare_td_obj(arg)
if arg_config.get("is_dt", False):
arg = self.prepare_dt_obj(arg, last_before=arg_config.get("last_before", None))
return arg
def __getitem__(self, arg_name) -> tp.Any:
return self.get_arg(arg_name)
def __iter__(self):
raise TypeError(f"'{type(self).__name__}' object is not iterable")
@classmethod
def prepare_td_arr(cls, td_arr: tp.ArrayLike) -> tp.ArrayLike:
"""Prepare a timedelta array."""
if td_arr.dtype == object:
if td_arr.ndim in (0, 1):
td_arr = pd.to_timedelta(td_arr)
if isinstance(td_arr, pd.Timedelta):
td_arr = td_arr.to_timedelta64()
else:
td_arr = td_arr.values
else:
td_arr_cols = []
for col in range(td_arr.shape[1]):
td_arr_col = pd.to_timedelta(td_arr[:, col])
td_arr_cols.append(td_arr_col.values)
td_arr = column_stack_arrays(td_arr_cols)
return td_arr
@classmethod
def prepare_dt_arr(cls, dt_arr: tp.ArrayLike) -> tp.ArrayLike:
"""Prepare a datetime array."""
if dt_arr.dtype == object:
if dt_arr.ndim in (0, 1):
dt_arr = pd.to_datetime(dt_arr).tz_localize(None)
if isinstance(dt_arr, pd.Timestamp):
dt_arr = dt_arr.to_datetime64()
else:
dt_arr = dt_arr.values
else:
dt_arr_cols = []
for col in range(dt_arr.shape[1]):
dt_arr_col = pd.to_datetime(dt_arr[:, col]).tz_localize(None)
dt_arr_cols.append(dt_arr_col.values)
dt_arr = column_stack_arrays(dt_arr_cols)
return dt_arr
@classmethod
def td_arr_to_ns(cls, td_arr: tp.ArrayLike) -> tp.ArrayLike:
"""Prepare a timedelta array and convert it to nanoseconds."""
return dt.to_ns(cls.prepare_td_arr(td_arr))
@classmethod
def dt_arr_to_ns(cls, dt_arr: tp.ArrayLike) -> tp.ArrayLike:
"""Prepare a datetime array and convert it to nanoseconds."""
return dt.to_ns(cls.prepare_dt_arr(dt_arr))
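# Usage sketch (illustrative, not authoritative): how the array-preparation helpers above
# are expected to compose, assuming `dt.to_ns` converts timedelta64/datetime64 values to
# int64 nanoseconds. Values shown are illustrative only.
#
#   >>> td_arr = np.array(["1d", "2d"], dtype=object)
#   >>> BasePreparer.td_arr_to_ns(td_arr)     # pd.to_timedelta -> timedelta64[ns] -> int64 ns
#   array([ 86400000000000, 172800000000000])
#
#   >>> dt_arr = np.array(["2020-01-01", "2020-01-02"], dtype=object)
#   >>> BasePreparer.dt_arr_to_ns(dt_arr)     # pd.to_datetime -> datetime64[ns] -> int64 ns
#   array([1577836800000000000, 1577923200000000000])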
def prepare_post_arg(self, arg_name: str, value: tp.Optional[tp.ArrayLike] = None) -> object:
"""Prepare an argument after broadcasting and/or template substitution."""
if value is None:
if arg_name in self.post_args:
arg = self.post_args[arg_name]
else:
arg = getattr(self, "_pre_" + arg_name)
else:
arg = value
if arg is not None:
arg_config = self.arg_config[arg_name]
if arg_config.get("substitute_templates", False):
arg = substitute_templates(arg, self.template_context, eval_id=arg_name)
if "map_enum_kwargs" in arg_config:
arg = map_enum_fields(arg, **arg_config["map_enum_kwargs"])
if arg_config.get("is_td", False):
arg = self.td_arr_to_ns(arg)
if arg_config.get("is_dt", False):
arg = self.dt_arr_to_ns(arg)
if "type" in arg_config:
checks.assert_instance_of(arg, arg_config["type"], arg_name=arg_name)
if "subdtype" in arg_config:
checks.assert_subdtype(arg, arg_config["subdtype"], arg_name=arg_name)
return arg
@classmethod
def adapt_staticized_to_udf(cls, staticized: tp.Kwargs, func: tp.Union[str, tp.Callable], func_name: str) -> None:
"""Adapt `staticized` dictionary to a UDF."""
target_func_module = inspect.getmodule(staticized["func"])
if isinstance(func, tuple):
func, actual_func_name = func
else:
actual_func_name = None
if isinstance(func, (str, Path)):
if actual_func_name is None:
actual_func_name = func_name
if isinstance(func, str) and not func.endswith(".py") and hasattr(target_func_module, func):
staticized[f"{func_name}_block"] = func
return None
func = Path(func)
module_path = func.resolve()
else:
if actual_func_name is None:
actual_func_name = func.__name__
if inspect.getmodule(func) == target_func_module:
staticized[f"{func_name}_block"] = actual_func_name
return None
module = inspect.getmodule(func)
if not hasattr(module, "__file__"):
raise TypeError(f"{func_name} must be defined in a Python file")
module_path = Path(module.__file__).resolve()
if "import_lines" not in staticized:
staticized["import_lines"] = []
reload = staticized.get("reload", False)
staticized["import_lines"].extend(
[
f'{func_name}_path = r"{module_path}"',
f"globals().update(vbt.import_module_from_path({func_name}_path).__dict__, reload={reload})",
]
)
if actual_func_name != func_name:
staticized["import_lines"].append(f"{func_name} = {actual_func_name}")
@classmethod
def find_target_func(cls, target_func_name: str) -> tp.Callable:
"""Find target function by its name."""
raise NotImplementedError
@classmethod
def resolve_dynamic_target_func(cls, target_func_name: str, staticized: tp.KwargsLike) -> tp.Callable:
"""Resolve a dynamic target function."""
if staticized is None:
func = cls.find_target_func(target_func_name)
else:
if isinstance(staticized, dict):
staticized = dict(staticized)
module_path = suggest_module_path(
staticized.get("suggest_fname", target_func_name),
path=staticized.pop("path", None),
mkdir_kwargs=staticized.get("mkdir_kwargs", None),
)
if "new_func_name" not in staticized:
staticized["new_func_name"] = target_func_name
if staticized.pop("override", False) or not module_path.exists():
if "skip_func" not in staticized:
def _skip_func(out_lines, func_name):
to_skip = lambda x: f"def {func_name}" in x or x.startswith(f"{func_name}_path =")
return any(map(to_skip, out_lines))
staticized["skip_func"] = _skip_func
module_path = cut_and_save_func(path=module_path, **staticized)
if staticized.get("clear_cache", True):
remove_dir(module_path.parent / "__pycache__", with_contents=True, missing_ok=True)
reload = staticized.pop("reload", False)
module = import_module_from_path(module_path, reload=reload)
func = getattr(module, staticized["new_func_name"])
else:
func = staticized
return func
def set_seed(self) -> None:
"""Set seed."""
seed = self.seed
if seed is not None:
set_seed(seed)
# ############# Before broadcasting ############# #
@cachedproperty
def _pre_template_context(self) -> tp.Kwargs:
"""Argument `template_context` before broadcasting."""
return merge_dicts(dict(preparer=self), self["template_context"])
# ############# Broadcasting ############# #
@cachedproperty
def pre_args(self) -> tp.Kwargs:
"""Arguments before broadcasting."""
pre_args = dict()
for k, v in self.arg_config.items():
if v.get("broadcast", False):
pre_args[k] = getattr(self, "_pre_" + k)
return pre_args
@cachedproperty
def args_to_broadcast(self) -> dict:
"""Arguments to broadcast."""
return merge_dicts(self.idx_setters, self.pre_args, self.broadcast_named_args)
@cachedproperty
def def_broadcast_kwargs(self) -> tp.Kwargs:
"""Default keyword arguments for broadcasting."""
return dict(
to_pd=False,
keep_flex=dict(cash_earnings=self.keep_inout_flex, _def=True),
wrapper_kwargs=dict(
freq=self._pre_freq,
group_by=self.group_by,
),
return_wrapper=True,
template_context=self._pre_template_context,
)
@cachedproperty
def broadcast_kwargs(self) -> tp.Kwargs:
"""Argument `broadcast_kwargs`."""
arg_broadcast_kwargs = defaultdict(dict)
for k, v in self.arg_config.items():
if v.get("broadcast", False):
broadcast_kwargs = v.get("broadcast_kwargs", None)
if broadcast_kwargs is None:
broadcast_kwargs = {}
for k2, v2 in broadcast_kwargs.items():
arg_broadcast_kwargs[k2][k] = v2
for k in self.args_to_broadcast:
new_fill_value = None
if k in self.pre_args:
fill_default = self.arg_config[k].get("fill_default", True)
if self.idx_setters is not None and k in self.idx_setters:
new_fill_value = self.get_arg(k, use_idx_setter=False, use_default=fill_default)
elif fill_default and self.arg_config[k].get("has_default", True):
new_fill_value = self.get_arg_default(k)
elif k in self.broadcast_named_args:
if self.idx_setters is not None and k in self.idx_setters:
new_fill_value = self.broadcast_named_args[k]
if new_fill_value is not None:
if not np.isscalar(new_fill_value):
raise TypeError(f"Argument '{k}' (and its default) must be a scalar when also provided via records")
if "reindex_kwargs" not in arg_broadcast_kwargs:
arg_broadcast_kwargs["reindex_kwargs"] = {}
if k not in arg_broadcast_kwargs["reindex_kwargs"]:
arg_broadcast_kwargs["reindex_kwargs"][k] = {}
arg_broadcast_kwargs["reindex_kwargs"][k]["fill_value"] = new_fill_value
return merge_dicts(
self.def_broadcast_kwargs,
dict(arg_broadcast_kwargs),
self["broadcast_kwargs"],
)
@cachedproperty
def broadcast_result(self) -> tp.Any:
"""Result of broadcasting."""
return broadcast(self.args_to_broadcast, **self.broadcast_kwargs)
@cachedproperty
def post_args(self) -> tp.Kwargs:
"""Arguments after broadcasting."""
return self.broadcast_result[0]
@cachedproperty
def post_broadcast_named_args(self) -> tp.Kwargs:
"""Custom arguments after broadcasting."""
if self.broadcast_named_args is None:
return dict()
post_broadcast_named_args = dict()
for k, v in self.post_args.items():
if k in self.broadcast_named_args:
post_broadcast_named_args[k] = v
elif self.idx_setters is not None and k in self.idx_setters and k not in self.pre_args:
post_broadcast_named_args[k] = v
return post_broadcast_named_args
@cachedproperty
def wrapper(self) -> ArrayWrapper:
"""Array wrapper."""
return self.broadcast_result[1]
@cachedproperty
def target_shape(self) -> tp.Shape:
"""Target shape."""
return self.wrapper.shape_2d
@cachedproperty
def index(self) -> tp.Array1d:
"""Index in nanosecond format."""
return self.wrapper.ns_index
@cachedproperty
def freq(self) -> int:
"""Frequency in nanosecond format."""
return self.wrapper.ns_freq
# ############# Template substitution ############# #
@cachedproperty
def template_context(self) -> tp.Kwargs:
"""Argument `template_context`."""
builtin_args = {}
for k, v in self.arg_config.items():
if v.get("broadcast", False):
builtin_args[k] = getattr(self, k)
return merge_dicts(
dict(
wrapper=self.wrapper,
target_shape=self.target_shape,
index=self.index,
freq=self.freq,
),
builtin_args,
self.post_broadcast_named_args,
self._pre_template_context,
)
# ############# Result ############# #
@cachedproperty
def target_func(self) -> tp.Optional[tp.Callable]:
"""Target function."""
return None
@cachedproperty
def target_arg_map(self) -> tp.Kwargs:
"""Map of the target arguments to the preparer attributes."""
return dict()
@cachedproperty
def target_args(self) -> tp.Optional[tp.Kwargs]:
"""Arguments to be passed to the target function."""
if self.target_func is not None:
target_arg_map = self.target_arg_map
func_arg_names = get_func_arg_names(self.target_func)
target_args = {}
for k in func_arg_names:
arg_attr = target_arg_map.get(k, k)
if arg_attr is not None and hasattr(self, arg_attr):
target_args[k] = getattr(self, arg_attr)
return target_args
return None
# ############# Docs ############# #
@classmethod
def build_arg_config_doc(cls, source_cls: tp.Optional[type] = None) -> str:
"""Build argument config documentation."""
if source_cls is None:
source_cls = BasePreparer
return string.Template(inspect.cleandoc(get_dict_attr(source_cls, "arg_config").__doc__)).substitute(
{"arg_config": cls.arg_config.prettify(), "cls_name": cls.__name__},
)
@classmethod
def override_arg_config_doc(cls, __pdoc__: dict, source_cls: tp.Optional[type] = None) -> None:
"""Call this method on each subclass that overrides `BasePreparer.arg_config`."""
__pdoc__[cls.__name__ + ".arg_config"] = cls.build_arg_config_doc(source_cls=source_cls)
BasePreparer.override_arg_config_doc(__pdoc__)
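# Usage sketch (illustrative, not authoritative): the intended extension pattern for
# subclasses, using the same decorators as above. `my_arg_config` and `MyPreparer` are
# hypothetical names, not part of the package.
#
#   my_arg_config = ReadonlyConfig(
#       dict(
#           size=dict(broadcast=True, has_default=True),
#       )
#   )
#
#   @attach_arg_properties
#   @override_arg_config(my_arg_config)
#   class MyPreparer(BasePreparer):
#       """Preparer with an additional broadcastable argument `size`."""
#
#   MyPreparer.override_arg_config_doc(__pdoc__)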
</file>
<file path="base/reshaping.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Functions for reshaping arrays.
Reshape functions transform a Pandas object/NumPy array in some way."""
import functools
import itertools
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base import indexes, wrapping, indexing
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils import checks
from vectorbtpro.utils.attr_ import DefineMixin, define
from vectorbtpro.utils.config import resolve_dict, merge_dicts
from vectorbtpro.utils.params import combine_params, Param
from vectorbtpro.utils.parsing import get_func_arg_names
from vectorbtpro.utils.template import CustomTemplate
__all__ = [
"to_1d_shape",
"to_2d_shape",
"repeat_shape",
"tile_shape",
"to_1d_array",
"to_2d_array",
"to_2d_pr_array",
"to_2d_pc_array",
"to_1d_array_nb",
"to_2d_array_nb",
"to_2d_pr_array_nb",
"to_2d_pc_array_nb",
"broadcast_shapes",
"broadcast_array_to",
"broadcast_arrays",
"repeat",
"tile",
"align_pd_arrays",
"BCO",
"Default",
"Ref",
"broadcast",
"broadcast_to",
]
def to_tuple_shape(shape: tp.ShapeLike) -> tp.Shape:
"""Convert a shape-like object to a tuple."""
if checks.is_int(shape):
return (int(shape),)
return tuple(shape)
def to_1d_shape(shape: tp.ShapeLike) -> tp.Shape:
"""Convert a shape-like object to a 1-dim shape."""
shape = to_tuple_shape(shape)
if len(shape) == 0:
return (1,)
if len(shape) == 1:
return shape
if len(shape) == 2 and shape[1] == 1:
return (shape[0],)
raise ValueError(f"Cannot reshape a {len(shape)}-dimensional shape to 1 dimension")
def to_2d_shape(shape: tp.ShapeLike, expand_axis: int = 1) -> tp.Shape:
"""Convert a shape-like object to a 2-dim shape."""
shape = to_tuple_shape(shape)
if len(shape) == 0:
return 1, 1
if len(shape) == 1:
if expand_axis == 1:
return shape[0], 1
else:
return 1, shape[0]
if len(shape) == 2:
return shape
raise ValueError(f"Cannot reshape a {len(shape)}-dimensional shape to 2 dimensions")
def repeat_shape(shape: tp.ShapeLike, n: int, axis: int = 1) -> tp.Shape:
"""Repeat shape `n` times along the specified axis."""
shape = to_tuple_shape(shape)
if len(shape) <= axis:
shape = tuple([shape[i] if i < len(shape) else 1 for i in range(axis + 1)])
return *shape[:axis], shape[axis] * n, *shape[axis + 1 :]
def tile_shape(shape: tp.ShapeLike, n: int, axis: int = 1) -> tp.Shape:
"""Tile shape `n` times along the specified axis.
Identical to `repeat_shape`. Exists purely for naming consistency."""
return repeat_shape(shape, n, axis=axis)
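# Usage sketch (illustrative): `repeat_shape`/`tile_shape` multiply the size of one axis,
# padding missing axes with 1 first.
#
#   >>> repeat_shape((3,), 2, axis=1)
#   (3, 2)
#   >>> repeat_shape((3, 4), 2, axis=1)
#   (3, 8)
#   >>> tile_shape((3, 4), 2, axis=0)
#   (6, 4)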
def index_to_series(obj: tp.Index, reset_index: bool = False) -> tp.Series:
"""Convert Index to Series."""
if reset_index:
return obj.to_series(index=pd.RangeIndex(stop=len(obj)))
return obj.to_series()
def index_to_frame(obj: tp.Index, reset_index: bool = False) -> tp.Frame:
"""Convert Index to DataFrame."""
if not isinstance(obj, pd.MultiIndex):
return index_to_series(obj, reset_index=reset_index).to_frame()
return obj.to_frame(index=not reset_index)
def mapping_to_series(obj: tp.MappingLike) -> tp.Series:
"""Convert a mapping-like object to Series."""
if checks.is_namedtuple(obj):
obj = obj._asdict()
return pd.Series(obj)
def to_any_array(obj: tp.ArrayLike, raw: bool = False, convert_index: bool = True) -> tp.AnyArray:
"""Convert any array-like object to an array.
Pandas objects are kept as-is unless `raw` is True."""
from vectorbtpro.indicators.factory import IndicatorBase
if isinstance(obj, IndicatorBase):
obj = obj.main_output
if not raw:
if checks.is_any_array(obj):
if convert_index and checks.is_index(obj):
return index_to_series(obj)
return obj
if checks.is_mapping_like(obj):
return mapping_to_series(obj)
return np.asarray(obj)
def to_pd_array(obj: tp.ArrayLike, convert_index: bool = True) -> tp.PandasArray:
"""Convert any array-like object to a Pandas object."""
from vectorbtpro.indicators.factory import IndicatorBase
if isinstance(obj, IndicatorBase):
obj = obj.main_output
if checks.is_pandas(obj):
if convert_index and checks.is_index(obj):
return index_to_series(obj)
return obj
if checks.is_mapping_like(obj):
return mapping_to_series(obj)
obj = np.asarray(obj)
if obj.ndim == 0:
obj = obj[None]
if obj.ndim == 1:
return pd.Series(obj)
if obj.ndim == 2:
return pd.DataFrame(obj)
raise ValueError("Wrong number of dimensions: cannot convert to Series or DataFrame")
def soft_to_ndim(obj: tp.ArrayLike, ndim: int, raw: bool = False) -> tp.AnyArray:
"""Try to softly bring `obj` to the specified number of dimensions `ndim` (max 2)."""
obj = to_any_array(obj, raw=raw)
if ndim == 1:
if obj.ndim == 2:
if obj.shape[1] == 1:
if checks.is_frame(obj):
return obj.iloc[:, 0]
return obj[:, 0] # downgrade
if ndim == 2:
if obj.ndim == 1:
if checks.is_series(obj):
return obj.to_frame()
return obj[:, None] # upgrade
return obj # do nothing
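# Usage sketch (illustrative): `soft_to_ndim` changes dimensionality only when it can do
# so losslessly; otherwise the object is returned unchanged.
#
#   >>> soft_to_ndim(pd.DataFrame({"a": [1, 2]}), 1)   # one column -> pd.Series
#   >>> soft_to_ndim(pd.Series([1, 2]), 2)             # -> one-column pd.DataFrame
#   >>> soft_to_ndim(np.array([[1, 2], [3, 4]]), 1)    # two columns -> returned as-is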
def to_1d(obj: tp.ArrayLike, raw: bool = False) -> tp.AnyArray1d:
"""Reshape argument to one dimension.
If `raw` is True, returns NumPy array.
If 2-dim, will collapse along axis 1 (i.e., DataFrame with one column to Series)."""
obj = to_any_array(obj, raw=raw)
if obj.ndim == 2:
if obj.shape[1] == 1:
if checks.is_frame(obj):
return obj.iloc[:, 0]
return obj[:, 0]
if obj.ndim == 1:
return obj
elif obj.ndim == 0:
return obj.reshape((1,))
raise ValueError(f"Cannot reshape a {obj.ndim}-dimensional array to 1 dimension")
to_1d_array = functools.partial(to_1d, raw=True)
"""`to_1d` with `raw` enabled."""
def to_2d(obj: tp.ArrayLike, raw: bool = False, expand_axis: int = 1) -> tp.AnyArray2d:
"""Reshape argument to two dimensions.
If `raw` is True, returns NumPy array.
If 1-dim, will expand along axis 1 (i.e., Series to DataFrame with one column)."""
obj = to_any_array(obj, raw=raw)
if obj.ndim == 2:
return obj
elif obj.ndim == 1:
if checks.is_series(obj):
if expand_axis == 0:
return pd.DataFrame(obj.values[None, :], columns=obj.index)
elif expand_axis == 1:
return obj.to_frame()
return np.expand_dims(obj, expand_axis)
elif obj.ndim == 0:
return obj.reshape((1, 1))
raise ValueError(f"Cannot reshape a {obj.ndim}-dimensional array to 2 dimensions")
to_2d_array = functools.partial(to_2d, raw=True)
"""`to_2d` with `raw` enabled."""
to_2d_pr_array = functools.partial(to_2d_array, expand_axis=1)
"""`to_2d_array` with `expand_axis=1`."""
to_2d_pc_array = functools.partial(to_2d_array, expand_axis=0)
"""`to_2d_array` with `expand_axis=0`."""
@register_jitted(cache=True)
def to_1d_array_nb(obj: tp.Array) -> tp.Array1d:
"""Resize array to one dimension."""
if obj.ndim == 0:
return np.expand_dims(obj, axis=0)
if obj.ndim == 1:
return obj
if obj.ndim == 2 and obj.shape[1] == 1:
return obj[:, 0]
raise ValueError("Array cannot be resized to one dimension")
@register_jitted(cache=True)
def to_2d_array_nb(obj: tp.Array, expand_axis: int = 1) -> tp.Array2d:
"""Resize array to two dimensions."""
if obj.ndim == 0:
return np.expand_dims(np.expand_dims(obj, axis=0), axis=0)
if obj.ndim == 1:
return np.expand_dims(obj, axis=expand_axis)
if obj.ndim == 2:
return obj
raise ValueError("Array cannot be resized to two dimensions")
@register_jitted(cache=True)
def to_2d_pr_array_nb(obj: tp.Array) -> tp.Array2d:
"""`to_2d_array_nb` with `expand_axis=1`."""
return to_2d_array_nb(obj, expand_axis=1)
@register_jitted(cache=True)
def to_2d_pc_array_nb(obj: tp.Array) -> tp.Array2d:
"""`to_2d_array_nb` with `expand_axis=0`."""
return to_2d_array_nb(obj, expand_axis=0)
def to_dict(obj: tp.ArrayLike, orient: str = "dict") -> dict:
"""Convert object to dict."""
obj = to_pd_array(obj)
if orient == "index_series":
return {obj.index[i]: obj.iloc[i] for i in range(len(obj.index))}
return obj.to_dict(orient)
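# Usage sketch (illustrative): `orient="index_series"` maps each index label to its value;
# any other orient is forwarded to the pandas `to_dict` of the converted object.
#
#   >>> sr = pd.Series([1, 2], index=["x", "y"])
#   >>> to_dict(sr, orient="index_series")    # -> {'x': 1, 'y': 2}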
def repeat(
obj: tp.ArrayLike,
n: int,
axis: int = 1,
raw: bool = False,
ignore_ranges: tp.Optional[bool] = None,
) -> tp.AnyArray:
"""Repeat `obj` `n` times along the specified axis."""
obj = to_any_array(obj, raw=raw)
if axis == 0:
if checks.is_pandas(obj):
new_index = indexes.repeat_index(obj.index, n, ignore_ranges=ignore_ranges)
return wrapping.ArrayWrapper.from_obj(obj).wrap(np.repeat(obj.values, n, axis=0), index=new_index)
return np.repeat(obj, n, axis=0)
elif axis == 1:
obj = to_2d(obj)
if checks.is_pandas(obj):
new_columns = indexes.repeat_index(obj.columns, n, ignore_ranges=ignore_ranges)
return wrapping.ArrayWrapper.from_obj(obj).wrap(np.repeat(obj.values, n, axis=1), columns=new_columns)
return np.repeat(obj, n, axis=1)
else:
raise ValueError(f"Only axes 0 and 1 are supported, not {axis}")
def tile(
obj: tp.ArrayLike,
n: int,
axis: int = 1,
raw: bool = False,
ignore_ranges: tp.Optional[bool] = None,
) -> tp.AnyArray:
"""Tile `obj` `n` times along the specified axis."""
obj = to_any_array(obj, raw=raw)
if axis == 0:
if obj.ndim == 2:
if checks.is_pandas(obj):
new_index = indexes.tile_index(obj.index, n, ignore_ranges=ignore_ranges)
return wrapping.ArrayWrapper.from_obj(obj).wrap(np.tile(obj.values, (n, 1)), index=new_index)
return np.tile(obj, (n, 1))
if checks.is_pandas(obj):
new_index = indexes.tile_index(obj.index, n, ignore_ranges=ignore_ranges)
return wrapping.ArrayWrapper.from_obj(obj).wrap(np.tile(obj.values, n), index=new_index)
return np.tile(obj, n)
elif axis == 1:
obj = to_2d(obj)
if checks.is_pandas(obj):
new_columns = indexes.tile_index(obj.columns, n, ignore_ranges=ignore_ranges)
return wrapping.ArrayWrapper.from_obj(obj).wrap(np.tile(obj.values, (1, n)), columns=new_columns)
return np.tile(obj, (1, n))
else:
raise ValueError(f"Only axes 0 and 1 are supported, not {axis}")
def broadcast_shapes(
*shapes: tp.ArrayLike,
axis: tp.Optional[tp.MaybeSequence[int]] = None,
expand_axis: tp.Optional[tp.MaybeSequence[int]] = None,
) -> tp.Tuple[tp.Shape, ...]:
"""Broadcast shape-like objects using vectorbt's broadcasting rules."""
from vectorbtpro._settings import settings
broadcasting_cfg = settings["broadcasting"]
if expand_axis is None:
expand_axis = broadcasting_cfg["expand_axis"]
is_2d = False
for i, shape in enumerate(shapes):
shape = to_tuple_shape(shape)
if len(shape) == 2:
is_2d = True
break
new_shapes = []
for i, shape in enumerate(shapes):
shape = to_tuple_shape(shape)
if is_2d:
if checks.is_sequence(expand_axis):
_expand_axis = expand_axis[i]
else:
_expand_axis = expand_axis
new_shape = to_2d_shape(shape, expand_axis=_expand_axis)
else:
new_shape = to_1d_shape(shape)
if axis is not None:
if checks.is_sequence(axis):
_axis = axis[i]
else:
_axis = axis
if _axis is not None:
if _axis == 0:
if is_2d:
new_shape = (new_shape[0], 1)
else:
new_shape = (new_shape[0],)
elif _axis == 1:
if is_2d:
new_shape = (1, new_shape[1])
else:
new_shape = (1,)
else:
raise ValueError(f"Only axes 0 and 1 are supported, not {_axis}")
new_shapes.append(new_shape)
return tuple(np.broadcast_shapes(*new_shapes))
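# Usage sketch (illustrative): once any 2-dim shape is involved, a 1-dim shape is treated
# as a column and thus broadcasts along rows, assuming the default `expand_axis` of 1 in
# the broadcasting settings.
#
#   >>> broadcast_shapes((3,), (3, 4))    # plain np.broadcast_shapes would fail here
#   (3, 4)
#   >>> broadcast_shapes((3,), (1, 4))
#   (3, 4)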
def broadcast_array_to(
arr: tp.ArrayLike,
target_shape: tp.ShapeLike,
axis: tp.Optional[int] = None,
expand_axis: tp.Optional[int] = None,
) -> tp.Array:
"""Broadcast an array-like object to a target shape using vectorbt's broadcasting rules."""
from vectorbtpro._settings import settings
broadcasting_cfg = settings["broadcasting"]
if expand_axis is None:
expand_axis = broadcasting_cfg["expand_axis"]
arr = np.asarray(arr)
target_shape = to_tuple_shape(target_shape)
if len(target_shape) not in (1, 2):
raise ValueError(f"Target shape must have either 1 or 2 dimensions, not {len(target_shape)}")
if len(target_shape) == 2:
new_arr = to_2d_array(arr, expand_axis=expand_axis)
else:
new_arr = to_1d_array(arr)
if axis is not None:
if axis == 0:
if len(target_shape) == 2:
target_shape = (target_shape[0], new_arr.shape[1])
else:
target_shape = (target_shape[0],)
elif axis == 1:
target_shape = (new_arr.shape[0], target_shape[1])
else:
raise ValueError(f"Only axes 0 and 1 are supported, not {axis}")
return np.broadcast_to(new_arr, target_shape)
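# Usage sketch (illustrative): a 1-dim input is expanded to a column before
# `np.broadcast_to` (assuming the default `expand_axis` of 1).
#
#   >>> broadcast_array_to([1, 2, 3], (3, 2))
#   array([[1, 1],
#          [2, 2],
#          [3, 3]])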
def broadcast_arrays(
*arrs: tp.ArrayLike,
target_shape: tp.Optional[tp.ShapeLike] = None,
axis: tp.Optional[tp.MaybeSequence[int]] = None,
expand_axis: tp.Optional[tp.MaybeSequence[int]] = None,
) -> tp.Tuple[tp.Array, ...]:
"""Broadcast array-like objects using vectorbt's broadcasting rules.
Optionally to a target shape."""
if target_shape is None:
shapes = []
for arr in arrs:
shapes.append(np.asarray(arr).shape)
target_shape = broadcast_shapes(*shapes, axis=axis, expand_axis=expand_axis)
new_arrs = []
for i, arr in enumerate(arrs):
if axis is not None:
if checks.is_sequence(axis):
_axis = axis[i]
else:
_axis = axis
else:
_axis = None
if expand_axis is not None:
if checks.is_sequence(expand_axis):
_expand_axis = expand_axis[i]
else:
_expand_axis = expand_axis
else:
_expand_axis = None
new_arr = broadcast_array_to(arr, target_shape, axis=_axis, expand_axis=_expand_axis)
new_arrs.append(new_arr)
return tuple(new_arrs)
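# Usage sketch (illustrative): `broadcast_arrays` resolves a common target shape via
# `broadcast_shapes` and then broadcasts every input to it (same `expand_axis` assumption
# as above).
#
#   >>> a, b = broadcast_arrays(np.array([1, 2, 3]), np.array([[10, 20]]))
#   >>> a.shape, b.shape
#   ((3, 2), (3, 2))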
IndexFromLike = tp.Union[None, str, int, tp.Any]
"""Any object that can be coerced into a `index_from` argument."""
def broadcast_index(
objs: tp.Sequence[tp.AnyArray],
to_shape: tp.Shape,
index_from: IndexFromLike = None,
axis: int = 0,
ignore_sr_names: tp.Optional[bool] = None,
ignore_ranges: tp.Optional[bool] = None,
check_index_names: tp.Optional[bool] = None,
**clean_index_kwargs,
) -> tp.Optional[tp.Index]:
"""Produce a broadcast index/columns.
Args:
objs (iterable of array_like): Array-like objects.
to_shape (tuple of int): Target shape.
index_from (any): Broadcasting rule for this index/these columns.
Accepts the following values:
* 'keep' or None - keep the original index/columns of the objects in `objs`
* 'stack' - stack different indexes/columns using `vectorbtpro.base.indexes.stack_indexes`
* 'strict' - ensure that all Pandas objects have the same index/columns
* 'reset' - reset any index/columns (they become a simple range)
* integer - use the index/columns of the i-th object in `objs`
* everything else will be converted to `pd.Index`
axis (int): Set to 0 for index and 1 for columns.
ignore_sr_names (bool): Whether to ignore Series names if they are in conflict.
Conflicting Series names are those that are different but not None.
ignore_ranges (bool): Whether to ignore indexes of type `pd.RangeIndex`.
check_index_names (bool): See `vectorbtpro.utils.checks.is_index_equal`.
**clean_index_kwargs: Keyword arguments passed to `vectorbtpro.base.indexes.clean_index`.
For defaults, see `vectorbtpro._settings.broadcasting`.
!!! note
Series names are treated as columns with a single element but without a name.
If a column level without a name loses its meaning, it is better to convert the Series to a
one-column DataFrame prior to broadcasting. If the name of a Series is not that important,
it is better to drop it altogether by setting it to None.
"""
from vectorbtpro._settings import settings
broadcasting_cfg = settings["broadcasting"]
if ignore_sr_names is None:
ignore_sr_names = broadcasting_cfg["ignore_sr_names"]
if check_index_names is None:
check_index_names = broadcasting_cfg["check_index_names"]
index_str = "columns" if axis == 1 else "index"
to_shape_2d = (to_shape[0], 1) if len(to_shape) == 1 else to_shape
maxlen = to_shape_2d[1] if axis == 1 else to_shape_2d[0]
new_index = None
objs = list(objs)
if index_from is None or (isinstance(index_from, str) and index_from.lower() == "keep"):
return None
if isinstance(index_from, int):
if not checks.is_pandas(objs[index_from]):
raise TypeError(f"Argument under index {index_from} must be a pandas object")
new_index = indexes.get_index(objs[index_from], axis)
elif isinstance(index_from, str):
if index_from.lower() == "reset":
new_index = pd.RangeIndex(start=0, stop=maxlen, step=1)
elif index_from.lower() in ("stack", "strict"):
last_index = None
index_conflict = False
for obj in objs:
if checks.is_pandas(obj):
index = indexes.get_index(obj, axis)
if last_index is not None:
if not checks.is_index_equal(index, last_index, check_names=check_index_names):
index_conflict = True
last_index = index
continue
if not index_conflict:
new_index = last_index
else:
for obj in objs:
if checks.is_pandas(obj):
index = indexes.get_index(obj, axis)
if axis == 1 and checks.is_series(obj) and ignore_sr_names:
continue
if checks.is_default_index(index):
continue
if new_index is None:
new_index = index
else:
if checks.is_index_equal(index, new_index, check_names=check_index_names):
continue
if index_from.lower() == "strict":
raise ValueError(
f"Arrays have different index. Broadcasting {index_str} "
f"is not allowed when {index_str}_from=strict"
)
if len(index) != len(new_index):
if len(index) > 1 and len(new_index) > 1:
raise ValueError("Indexes could not be broadcast together")
if len(index) > len(new_index):
new_index = indexes.repeat_index(new_index, len(index), ignore_ranges=ignore_ranges)
elif len(index) < len(new_index):
index = indexes.repeat_index(index, len(new_index), ignore_ranges=ignore_ranges)
new_index = indexes.stack_indexes([new_index, index], **clean_index_kwargs)
else:
raise ValueError(f"Invalid value '{index_from}' for {'columns' if axis == 1 else 'index'}_from")
else:
if not isinstance(index_from, pd.Index):
index_from = pd.Index(index_from)
new_index = index_from
if new_index is not None:
if maxlen > len(new_index):
if isinstance(index_from, str) and index_from.lower() == "strict":
raise ValueError(f"Broadcasting {index_str} is not allowed when {index_str}_from=strict")
if maxlen > 1 and len(new_index) > 1:
raise ValueError("Indexes could not be broadcast together")
new_index = indexes.repeat_index(new_index, maxlen, ignore_ranges=ignore_ranges)
else:
new_index = pd.RangeIndex(start=0, stop=maxlen, step=1)
return new_index
def wrap_broadcasted(
new_obj: tp.Array,
old_obj: tp.Optional[tp.AnyArray] = None,
axis: tp.Optional[int] = None,
is_pd: bool = False,
new_index: tp.Optional[tp.Index] = None,
new_columns: tp.Optional[tp.Index] = None,
ignore_ranges: tp.Optional[bool] = None,
) -> tp.AnyArray:
"""If the newly brodcasted array was originally a Pandas object, make it Pandas object again
and assign it the newly broadcast index/columns."""
if is_pd:
if axis == 0:
new_columns = None
elif axis == 1:
new_index = None
if old_obj is not None and checks.is_pandas(old_obj):
if new_index is None:
old_index = indexes.get_index(old_obj, 0)
if old_obj.shape[0] == new_obj.shape[0]:
new_index = old_index
else:
new_index = indexes.repeat_index(old_index, new_obj.shape[0], ignore_ranges=ignore_ranges)
if new_columns is None:
old_columns = indexes.get_index(old_obj, 1)
new_ncols = new_obj.shape[1] if new_obj.ndim == 2 else 1
if len(old_columns) == new_ncols:
new_columns = old_columns
else:
new_columns = indexes.repeat_index(old_columns, new_ncols, ignore_ranges=ignore_ranges)
if new_obj.ndim == 2:
return pd.DataFrame(new_obj, index=new_index, columns=new_columns)
if new_columns is not None and len(new_columns) == 1:
name = new_columns[0]
if name == 0:
name = None
else:
name = None
return pd.Series(new_obj, index=new_index, name=name)
return new_obj
def align_pd_arrays(
*objs: tp.AnyArray,
align_index: bool = True,
align_columns: bool = True,
to_index: tp.Optional[tp.Index] = None,
to_columns: tp.Optional[tp.Index] = None,
axis: tp.Optional[tp.MaybeSequence[int]] = None,
reindex_kwargs: tp.KwargsLikeSequence = None,
) -> tp.MaybeTuple[tp.ArrayLike]:
"""Align Pandas arrays against common index and/or column levels using reindexing
and `vectorbtpro.base.indexes.align_indexes` respectively."""
objs = list(objs)
if align_index:
indexes_to_align = []
for i in range(len(objs)):
if axis is not None:
if checks.is_sequence(axis):
_axis = axis[i]
else:
_axis = axis
else:
_axis = None
if _axis in (None, 0):
if checks.is_pandas(objs[i]):
if not checks.is_default_index(objs[i].index):
indexes_to_align.append(i)
if (len(indexes_to_align) > 0 and to_index is not None) or len(indexes_to_align) > 1:
if to_index is None:
new_index = None
index_changed = False
for i in indexes_to_align:
arg_index = objs[i].index
if new_index is None:
new_index = arg_index
else:
if not checks.is_index_equal(new_index, arg_index):
if new_index.dtype != arg_index.dtype:
raise ValueError("Indexes to be aligned must have the same data type")
new_index = new_index.union(arg_index)
index_changed = True
else:
new_index = to_index
index_changed = True
if index_changed:
for i in indexes_to_align:
if to_index is None or not checks.is_index_equal(objs[i].index, to_index):
if objs[i].index.has_duplicates:
raise ValueError(f"Index at position {i} contains duplicates")
if not objs[i].index.is_monotonic_increasing:
raise ValueError(f"Index at position {i} is not monotonically increasing")
_reindex_kwargs = resolve_dict(reindex_kwargs, i=i)
was_bool = (isinstance(objs[i], pd.Series) and objs[i].dtype == "bool") or (
isinstance(objs[i], pd.DataFrame) and (objs[i].dtypes == "bool").all()
)
objs[i] = objs[i].reindex(new_index, **_reindex_kwargs)
is_object = (isinstance(objs[i], pd.Series) and objs[i].dtype == "object") or (
isinstance(objs[i], pd.DataFrame) and (objs[i].dtypes == "object").all()
)
if was_bool and is_object:
objs[i] = objs[i].astype(None)
if align_columns:
columns_to_align = []
for i in range(len(objs)):
if axis is not None:
if checks.is_sequence(axis):
_axis = axis[i]
else:
_axis = axis
else:
_axis = None
if _axis in (None, 1):
if checks.is_frame(objs[i]) and len(objs[i].columns) > 1:
if not checks.is_default_index(objs[i].columns):
columns_to_align.append(i)
if (len(columns_to_align) > 0 and to_columns is not None) or len(columns_to_align) > 1:
indexes_ = [objs[i].columns for i in columns_to_align]
if to_columns is not None:
indexes_.append(to_columns)
if len(set(map(len, indexes_))) > 1:
col_indices = indexes.align_indexes(*indexes_)
for i in columns_to_align:
objs[i] = objs[i].iloc[:, col_indices[columns_to_align.index(i)]]
if len(objs) == 1:
return objs[0]
return tuple(objs)
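# Usage sketch (illustrative): index alignment takes the union of all non-default indexes
# and reindexes each object against it, so missing positions become NaN under the default
# reindexing options.
#
#   >>> sr1 = pd.Series([1, 2], index=["a", "b"])
#   >>> sr2 = pd.Series([3, 4], index=["b", "c"])
#   >>> new_sr1, new_sr2 = align_pd_arrays(sr1, sr2)
#   >>> list(new_sr1.index)    # -> ['a', 'b', 'c']; new_sr1 becomes [1.0, 2.0, NaN]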
@define
class BCO(DefineMixin):
"""Class that represents an object passed to `broadcast`.
If any value is None, it mostly defaults to the global value passed to `broadcast`."""
value: tp.Any = define.field()
"""Value of the object."""
axis: tp.Optional[int] = define.field(default=None)
"""Axis to broadcast.
Set to None to broadcast all axes."""
to_pd: tp.Optional[bool] = define.field(default=None)
"""Whether to convert the output array to a Pandas object."""
keep_flex: tp.Optional[bool] = define.field(default=None)
"""Whether to keep the raw version of the output for flexible indexing.
Only makes sure that the array can broadcast to the target shape."""
min_ndim: tp.Optional[int] = define.field(default=None)
"""Minimum number of dimensions."""
expand_axis: tp.Optional[int] = define.field(default=None)
"""Axis to expand if the array is 1-dim but the target shape is 2-dim."""
post_func: tp.Optional[tp.Callable] = define.field(default=None)
"""Function to post-process the output array."""
require_kwargs: tp.Optional[tp.Kwargs] = define.field(default=None)
"""Keyword arguments passed to `np.require`."""
reindex_kwargs: tp.Optional[tp.Kwargs] = define.field(default=None)
"""Keyword arguments passed to `pd.DataFrame.reindex`."""
merge_kwargs: tp.Optional[tp.Kwargs] = define.field(default=None)
"""Keyword arguments passed to `vectorbtpro.base.merging.column_stack_merge`."""
context: tp.KwargsLike = define.field(default=None)
"""Context used in evaluation of templates.
Will be merged over `template_context`."""
@define
class Default(DefineMixin):
"""Class for wrapping default values."""
value: tp.Any = define.field()
"""Default value."""
@define
class Ref(DefineMixin):
"""Class for wrapping references to other values."""
key: tp.Hashable = define.field()
"""Reference to another key."""
def resolve_ref(dct: dict, k: tp.Hashable, inside_bco: bool = False, keep_wrap_default: bool = False) -> tp.Any:
"""Resolve a potential reference."""
v = dct[k]
is_default = False
if isinstance(v, Default):
v = v.value
is_default = True
if isinstance(v, Ref):
new_v = resolve_ref(dct, v.key, inside_bco=inside_bco)
if keep_wrap_default and is_default:
return Default(new_v)
return new_v
if isinstance(v, BCO) and inside_bco:
v = v.value
is_default = False
if isinstance(v, Default):
v = v.value
is_default = True
if isinstance(v, Ref):
new_v = resolve_ref(dct, v.key, inside_bco=inside_bco)
if keep_wrap_default and is_default:
return Default(new_v)
return new_v
return v
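# Usage sketch (illustrative): `Ref` lets one key point at another key's value, optionally
# preserving the `Default` wrapper.
#
#   >>> dct = {"fast": 10, "slow": Ref("fast"), "signal": Default(Ref("fast"))}
#   >>> resolve_ref(dct, "slow")                            # -> 10
#   >>> resolve_ref(dct, "signal", keep_wrap_default=True)  # -> Default wrapping 10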
def broadcast(
*objs,
to_shape: tp.Optional[tp.ShapeLike] = None,
align_index: tp.Optional[bool] = None,
align_columns: tp.Optional[bool] = None,
index_from: tp.Optional[IndexFromLike] = None,
columns_from: tp.Optional[IndexFromLike] = None,
to_frame: tp.Optional[bool] = None,
axis: tp.Optional[tp.MaybeMappingSequence[int]] = None,
to_pd: tp.Optional[tp.MaybeMappingSequence[bool]] = None,
keep_flex: tp.MaybeMappingSequence[tp.Optional[bool]] = None,
min_ndim: tp.MaybeMappingSequence[tp.Optional[int]] = None,
expand_axis: tp.MaybeMappingSequence[tp.Optional[int]] = None,
post_func: tp.MaybeMappingSequence[tp.Optional[tp.Callable]] = None,
require_kwargs: tp.MaybeMappingSequence[tp.Optional[tp.Kwargs]] = None,
reindex_kwargs: tp.MaybeMappingSequence[tp.Optional[tp.Kwargs]] = None,
merge_kwargs: tp.MaybeMappingSequence[tp.Optional[tp.Kwargs]] = None,
tile: tp.Union[None, int, tp.IndexLike] = None,
random_subset: tp.Optional[int] = None,
seed: tp.Optional[int] = None,
keep_wrap_default: tp.Optional[bool] = None,
return_wrapper: bool = False,
wrapper_kwargs: tp.KwargsLike = None,
ignore_sr_names: tp.Optional[bool] = None,
ignore_ranges: tp.Optional[bool] = None,
check_index_names: tp.Optional[bool] = None,
clean_index_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
) -> tp.Any:
"""Bring any array-like object in `objs` to the same shape by using NumPy-like broadcasting.
See [Broadcasting](https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
!!! important
The major difference to NumPy is that one-dimensional arrays will always broadcast against the row axis!
Can broadcast Pandas objects by broadcasting their index/columns with `broadcast_index`.
Args:
*objs: Objects to broadcast.
If the first and only argument is a mapping, will return a dict.
Allows using `BCO`, `Ref`, `Default`, `vectorbtpro.utils.params.Param`,
`vectorbtpro.base.indexing.index_dict`, `vectorbtpro.base.indexing.IdxSetter`,
`vectorbtpro.base.indexing.IdxSetterFactory`, and templates.
If an index dictionary, fills using `vectorbtpro.base.wrapping.ArrayWrapper.fill_and_set`.
to_shape (tuple of int): Target shape. If set, will broadcast every object in `objs` to `to_shape`.
align_index (bool): Whether to align index of Pandas objects using union.
Pass None to use the default.
align_columns (bool): Whether to align columns of Pandas objects using multi-index.
Pass None to use the default.
index_from (any): Broadcasting rule for index.
Pass None to use the default.
columns_from (any): Broadcasting rule for columns.
Pass None to use the default.
to_frame (bool): Whether to convert all Series to DataFrames.
axis (int, sequence or mapping): See `BCO.axis`.
to_pd (bool, sequence or mapping): See `BCO.to_pd`.
If None, converts only if there is at least one Pandas object among them.
keep_flex (bool, sequence or mapping): See `BCO.keep_flex`.
min_ndim (int, sequence or mapping): See `BCO.min_ndim`.
If None, becomes 2 if `keep_flex` is True, otherwise 1.
expand_axis (int, sequence or mapping): See `BCO.expand_axis`.
post_func (callable, sequence or mapping): See `BCO.post_func`.
Applied only when `keep_flex` is False.
require_kwargs (dict, sequence or mapping): See `BCO.require_kwargs`.
This dict will be merged with any argument-specific dict. If all of its keys are valid
arguments of `np.require`, it will be applied to all objects.
reindex_kwargs (dict, sequence or mapping): See `BCO.reindex_kwargs`.
This dict will be merged with any argument-specific dict. If all of its keys are valid
arguments of `pd.DataFrame.reindex`, it will be applied to all objects.
merge_kwargs (dict, sequence or mapping): See `BCO.merge_kwargs`.
This dict will be merged with any argument-specific dict. If all of its keys are valid
arguments of `pd.DataFrame.merge`, it will be applied to all objects.
tile (int or index_like): Tile the final object by the number of times or index.
random_subset (int): Select a random subset of parameter values.
Seed can be set using NumPy before calling this function.
seed (int): Seed to make output deterministic.
keep_wrap_default (bool): Whether to keep wrapping with `vectorbtpro.base.reshaping.Default`.
return_wrapper (bool): Whether to also return the wrapper associated with the operation.
wrapper_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.wrapping.ArrayWrapper`.
ignore_sr_names (bool): See `broadcast_index`.
ignore_ranges (bool): See `broadcast_index`.
check_index_names (bool): See `broadcast_index`.
clean_index_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.indexes.clean_index`.
template_context (dict): Context used to substitute templates.
For defaults, see `vectorbtpro._settings.broadcasting`.
Any keyword argument that can be associated with an object can be passed as
* a constant that is applied to all objects,
* a sequence with value per object, and
* a mapping with value per object name and the special key `_def` denoting the default value.
Additionally, any object can be passed wrapped with `BCO`, whose attributes will override
any of the above arguments if they are not None.
Usage:
* Without broadcasting index and columns:
```pycon
>>> from vectorbtpro import *
>>> v = 0
>>> a = np.array([1, 2, 3])
>>> sr = pd.Series([1, 2, 3], index=pd.Index(['x', 'y', 'z']), name='a')
>>> df = pd.DataFrame(
... [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
... index=pd.Index(['x2', 'y2', 'z2']),
... columns=pd.Index(['a2', 'b2', 'c2']),
... )
>>> for i in vbt.broadcast(
... v, a, sr, df,
... index_from='keep',
... columns_from='keep',
... align_index=False
... ): print(i)
0 1 2
0 0 0 0
1 0 0 0
2 0 0 0
0 1 2
0 1 2 3
1 1 2 3
2 1 2 3
a a a
x 1 1 1
y 2 2 2
z 3 3 3
a2 b2 c2
x2 1 2 3
y2 4 5 6
z2 7 8 9
```
* Take index and columns from the argument at specific position:
```pycon
>>> for i in vbt.broadcast(
... v, a, sr, df,
... index_from=2,
... columns_from=3,
... align_index=False
... ): print(i)
a2 b2 c2
x 0 0 0
y 0 0 0
z 0 0 0
a2 b2 c2
x 1 2 3
y 1 2 3
z 1 2 3
a2 b2 c2
x 1 1 1
y 2 2 2
z 3 3 3
a2 b2 c2
x 1 2 3
y 4 5 6
z 7 8 9
```
* Broadcast index and columns through stacking:
```pycon
>>> for i in vbt.broadcast(
... v, a, sr, df,
... index_from='stack',
... columns_from='stack',
... align_index=False
... ): print(i)
a2 b2 c2
x x2 0 0 0
y y2 0 0 0
z z2 0 0 0
a2 b2 c2
x x2 1 2 3
y y2 1 2 3
z z2 1 2 3
a2 b2 c2
x x2 1 1 1
y y2 2 2 2
z z2 3 3 3
a2 b2 c2
x x2 1 2 3
y y2 4 5 6
z z2 7 8 9
```
* Set index and columns manually:
```pycon
>>> for i in vbt.broadcast(
... v, a, sr, df,
... index_from=['a', 'b', 'c'],
... columns_from=['d', 'e', 'f'],
... align_index=False
... ): print(i)
d e f
a 0 0 0
b 0 0 0
c 0 0 0
d e f
a 1 2 3
b 1 2 3
c 1 2 3
d e f
a 1 1 1
b 2 2 2
c 3 3 3
d e f
a 1 2 3
b 4 5 6
c 7 8 9
```
* Passing arguments as a mapping returns a mapping:
```pycon
>>> vbt.broadcast(
... dict(v=v, a=a, sr=sr, df=df),
... index_from='stack',
... align_index=False
... )
{'v': a2 b2 c2
x x2 0 0 0
y y2 0 0 0
z z2 0 0 0,
'a': a2 b2 c2
x x2 1 2 3
y y2 1 2 3
z z2 1 2 3,
'sr': a2 b2 c2
x x2 1 1 1
y y2 2 2 2
z z2 3 3 3,
'df': a2 b2 c2
x x2 1 2 3
y y2 4 5 6
z z2 7 8 9}
```
* Keep all results in a format suitable for flexible indexing apart from one:
```pycon
>>> vbt.broadcast(
... dict(v=v, a=a, sr=sr, df=df),
... index_from='stack',
... keep_flex=dict(_def=True, df=False),
... require_kwargs=dict(df=dict(dtype=float)),
... align_index=False
... )
{'v': array([0]),
'a': array([1, 2, 3]),
'sr': array([[1],
[2],
[3]]),
'df': a2 b2 c2
x x2 1.0 2.0 3.0
y y2 4.0 5.0 6.0
z z2 7.0 8.0 9.0}
```
* Specify arguments per object using `BCO`:
```pycon
>>> df_bco = vbt.BCO(df, keep_flex=False, require_kwargs=dict(dtype=float))
>>> vbt.broadcast(
... dict(v=v, a=a, sr=sr, df=df_bco),
... index_from='stack',
... keep_flex=True,
... align_index=False
... )
{'v': array([0]),
'a': array([1, 2, 3]),
'sr': array([[1],
[2],
[3]]),
'df': a2 b2 c2
x x2 1.0 2.0 3.0
y y2 4.0 5.0 6.0
z z2 7.0 8.0 9.0}
```
* Introduce a parameter that should build a Cartesian product of its values and other objects:
```pycon
>>> df_bco = vbt.BCO(df, keep_flex=False, require_kwargs=dict(dtype=float))
>>> p_bco = vbt.BCO(pd.Param([1, 2, 3], name='my_p'))
>>> vbt.broadcast(
... dict(v=v, a=a, sr=sr, df=df_bco, p=p_bco),
... index_from='stack',
... keep_flex=True,
... align_index=False
... )
{'v': array([0]),
'a': array([1, 2, 3, 1, 2, 3, 1, 2, 3]),
'sr': array([[1],
[2],
[3]]),
'df': my_p 1 2 3
a2 b2 c2 a2 b2 c2 a2 b2 c2
x x2 1.0 2.0 3.0 1.0 2.0 3.0 1.0 2.0 3.0
y y2 4.0 5.0 6.0 4.0 5.0 6.0 4.0 5.0 6.0
z z2 7.0 8.0 9.0 7.0 8.0 9.0 7.0 8.0 9.0,
'p': array([[1, 1, 1, 2, 2, 2, 3, 3, 3],
[1, 1, 1, 2, 2, 2, 3, 3, 3],
[1, 1, 1, 2, 2, 2, 3, 3, 3]])}
```
* Build a Cartesian product of all parameters:
```pycon
>>> vbt.broadcast(
... dict(
... a=vbt.Param([1, 2, 3]),
... b=vbt.Param(['x', 'y']),
... c=vbt.Param([False, True])
... )
... )
{'a': array([[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]]),
'b': array([['x', 'x', 'y', 'y', 'x', 'x', 'y', 'y', 'x', 'x', 'y', 'y']], dtype='<U1'),
'c': array([[False, True, False, True, False, True, False, True, False, True, False, True]])}
```
* Build a Cartesian product of two groups of parameters - (a, d) and (b, c):
```pycon
>>> vbt.broadcast(
... dict(
... a=vbt.Param([1, 2, 3], level=0),
... b=vbt.Param(['x', 'y'], level=1),
... d=vbt.Param([100., 200., 300.], level=0),
... c=vbt.Param([False, True], level=1)
... )
... )
{'a': array([[1, 1, 2, 2, 3, 3]]),
'b': array([['x', 'y', 'x', 'y', 'x', 'y']], dtype='<U1'),
'd': array([[100., 100., 200., 200., 300., 300.]]),
'c': array([[False, True, False, True, False, True]])}
```
* Select a random subset of parameter combinations:
```pycon
>>> vbt.broadcast(
... dict(
... a=vbt.Param([1, 2, 3]),
... b=vbt.Param(['x', 'y']),
... c=vbt.Param([False, True])
... ),
... random_subset=5,
... seed=42
... )
{'a': array([[1, 2, 3, 3, 3]]),
'b': array([['x', 'x', 'x', 'x', 'y']], dtype='<U1'),
'c': array([[False, True, False, True, False]])}
```
"""
# Get defaults
from vectorbtpro._settings import settings
broadcasting_cfg = settings["broadcasting"]
if align_index is None:
align_index = broadcasting_cfg["align_index"]
if align_columns is None:
align_columns = broadcasting_cfg["align_columns"]
if index_from is None:
index_from = broadcasting_cfg["index_from"]
if columns_from is None:
columns_from = broadcasting_cfg["columns_from"]
if keep_wrap_default is None:
keep_wrap_default = broadcasting_cfg["keep_wrap_default"]
require_kwargs_per_obj = True
if require_kwargs is not None and checks.is_mapping(require_kwargs):
require_arg_names = get_func_arg_names(np.require)
if set(require_kwargs) <= set(require_arg_names):
require_kwargs_per_obj = False
reindex_kwargs_per_obj = True
if reindex_kwargs is not None and checks.is_mapping(reindex_kwargs):
reindex_arg_names = get_func_arg_names(pd.DataFrame.reindex)
if set(reindex_kwargs) <= set(reindex_arg_names):
reindex_kwargs_per_obj = False
merge_kwargs_per_obj = True
if merge_kwargs is not None and checks.is_mapping(merge_kwargs):
merge_arg_names = get_func_arg_names(pd.DataFrame.merge)
if set(merge_kwargs) <= set(merge_arg_names):
merge_kwargs_per_obj = False
if clean_index_kwargs is None:
clean_index_kwargs = {}
if checks.is_mapping(objs[0]) and not isinstance(objs[0], indexing.index_dict):
if len(objs) > 1:
raise ValueError("Only one argument is allowed when passing a mapping")
all_keys = list(dict(objs[0]).keys())
objs = list(objs[0].values())
return_dict = True
else:
objs = list(objs)
all_keys = list(range(len(objs)))
return_dict = False
def _resolve_arg(obj: tp.Any, arg_name: str, global_value: tp.Any, default_value: tp.Any) -> tp.Any:
if isinstance(obj, BCO) and getattr(obj, arg_name) is not None:
return getattr(obj, arg_name)
if checks.is_mapping(global_value):
return global_value.get(k, global_value.get("_def", default_value))
if checks.is_sequence(global_value):
return global_value[i]
return global_value
# Build BCO instances
none_keys = set()
default_keys = set()
param_keys = set()
special_keys = set()
bco_instances = {}
pool = dict(zip(all_keys, objs))
for i, k in enumerate(all_keys):
obj = objs[i]
if isinstance(obj, Default):
obj = obj.value
default_keys.add(k)
if isinstance(obj, Ref):
obj = resolve_ref(pool, k)
if isinstance(obj, BCO):
value = obj.value
else:
value = obj
if isinstance(value, Default):
value = value.value
default_keys.add(k)
if isinstance(value, Ref):
value = resolve_ref(pool, k, inside_bco=True)
if value is None:
none_keys.add(k)
continue
_axis = _resolve_arg(obj, "axis", axis, None)
_to_pd = _resolve_arg(obj, "to_pd", to_pd, None)
_keep_flex = _resolve_arg(obj, "keep_flex", keep_flex, None)
if _keep_flex is None:
_keep_flex = broadcasting_cfg["keep_flex"]
_min_ndim = _resolve_arg(obj, "min_ndim", min_ndim, None)
if _min_ndim is None:
_min_ndim = broadcasting_cfg["min_ndim"]
_expand_axis = _resolve_arg(obj, "expand_axis", expand_axis, None)
if _expand_axis is None:
_expand_axis = broadcasting_cfg["expand_axis"]
_post_func = _resolve_arg(obj, "post_func", post_func, None)
if isinstance(obj, BCO) and obj.require_kwargs is not None:
_require_kwargs = obj.require_kwargs
else:
_require_kwargs = None
if checks.is_mapping(require_kwargs) and require_kwargs_per_obj:
_require_kwargs = merge_dicts(
require_kwargs.get("_def", None),
require_kwargs.get(k, None),
_require_kwargs,
)
elif checks.is_sequence(require_kwargs) and require_kwargs_per_obj:
_require_kwargs = merge_dicts(require_kwargs[i], _require_kwargs)
else:
_require_kwargs = merge_dicts(require_kwargs, _require_kwargs)
if isinstance(obj, BCO) and obj.reindex_kwargs is not None:
_reindex_kwargs = obj.reindex_kwargs
else:
_reindex_kwargs = None
if checks.is_mapping(reindex_kwargs) and reindex_kwargs_per_obj:
_reindex_kwargs = merge_dicts(
reindex_kwargs.get("_def", None),
reindex_kwargs.get(k, None),
_reindex_kwargs,
)
elif checks.is_sequence(reindex_kwargs) and reindex_kwargs_per_obj:
_reindex_kwargs = merge_dicts(reindex_kwargs[i], _reindex_kwargs)
else:
_reindex_kwargs = merge_dicts(reindex_kwargs, _reindex_kwargs)
if isinstance(obj, BCO) and obj.merge_kwargs is not None:
_merge_kwargs = obj.merge_kwargs
else:
_merge_kwargs = None
if checks.is_mapping(merge_kwargs) and merge_kwargs_per_obj:
_merge_kwargs = merge_dicts(
merge_kwargs.get("_def", None),
merge_kwargs.get(k, None),
_merge_kwargs,
)
elif checks.is_sequence(merge_kwargs) and merge_kwargs_per_obj:
_merge_kwargs = merge_dicts(merge_kwargs[i], _merge_kwargs)
else:
_merge_kwargs = merge_dicts(merge_kwargs, _merge_kwargs)
if isinstance(obj, BCO):
_context = merge_dicts(template_context, obj.context)
else:
_context = template_context
if isinstance(value, Param):
param_keys.add(k)
elif isinstance(value, (indexing.index_dict, indexing.IdxSetter, indexing.IdxSetterFactory, CustomTemplate)):
special_keys.add(k)
else:
value = to_any_array(value)
bco_instances[k] = BCO(
value,
axis=_axis,
to_pd=_to_pd,
keep_flex=_keep_flex,
min_ndim=_min_ndim,
expand_axis=_expand_axis,
post_func=_post_func,
require_kwargs=_require_kwargs,
reindex_kwargs=_reindex_kwargs,
merge_kwargs=_merge_kwargs,
context=_context,
)
# Check whether we should broadcast Pandas metadata and work on 2-dim data
is_pd = False
is_2d = False
old_objs = {}
obj_axis = {}
obj_reindex_kwargs = {}
for k, bco_obj in bco_instances.items():
if k in none_keys or k in param_keys or k in special_keys:
continue
obj = bco_obj.value
if obj.ndim > 1:
is_2d = True
if checks.is_pandas(obj):
is_pd = True
if bco_obj.to_pd is not None and bco_obj.to_pd:
is_pd = True
old_objs[k] = obj
obj_axis[k] = bco_obj.axis
obj_reindex_kwargs[k] = bco_obj.reindex_kwargs
if to_shape is not None:
if isinstance(to_shape, int):
to_shape = (to_shape,)
if len(to_shape) > 1:
is_2d = True
if to_frame is not None:
is_2d = to_frame
if to_pd is not None:
is_pd = to_pd or (return_wrapper and is_pd)
# Align pandas arrays
if index_from is not None and not isinstance(index_from, (int, str, pd.Index)):
index_from = pd.Index(index_from)
if columns_from is not None and not isinstance(columns_from, (int, str, pd.Index)):
columns_from = pd.Index(columns_from)
aligned_objs = align_pd_arrays(
*old_objs.values(),
align_index=align_index,
align_columns=align_columns,
to_index=index_from if isinstance(index_from, pd.Index) else None,
to_columns=columns_from if isinstance(columns_from, pd.Index) else None,
axis=list(obj_axis.values()),
reindex_kwargs=list(obj_reindex_kwargs.values()),
)
if not isinstance(aligned_objs, tuple):
aligned_objs = (aligned_objs,)
aligned_objs = dict(zip(old_objs.keys(), aligned_objs))
# Convert to NumPy
ready_objs = {}
for k, obj in aligned_objs.items():
_expand_axis = bco_instances[k].expand_axis
new_obj = np.asarray(obj)
if is_2d and new_obj.ndim == 1:
if isinstance(obj, pd.Series):
new_obj = new_obj[:, None]
else:
new_obj = np.expand_dims(new_obj, _expand_axis)
ready_objs[k] = new_obj
# Get final shape
if to_shape is None:
try:
to_shape = broadcast_shapes(
*map(lambda x: x.shape, ready_objs.values()),
axis=list(obj_axis.values()),
)
except ValueError:
arr_shapes = {}
for i, k in enumerate(bco_instances):
if k in none_keys or k in param_keys or k in special_keys:
continue
if len(ready_objs[k].shape) > 0:
arr_shapes[k] = ready_objs[k].shape
raise ValueError("Could not broadcast shapes: %s" % str(arr_shapes))
if not isinstance(to_shape, tuple):
to_shape = (to_shape,)
if len(to_shape) == 0:
to_shape = (1,)
to_shape_2d = to_shape if len(to_shape) > 1 else (*to_shape, 1)
if is_pd:
# Decide on index and columns
# NOTE: Important to pass aligned_objs, not ready_objs, to preserve original shape info
new_index = broadcast_index(
[v for k, v in aligned_objs.items() if obj_axis[k] in (None, 0)],
to_shape,
index_from=index_from,
axis=0,
ignore_sr_names=ignore_sr_names,
ignore_ranges=ignore_ranges,
check_index_names=check_index_names,
**clean_index_kwargs,
)
new_columns = broadcast_index(
[v for k, v in aligned_objs.items() if obj_axis[k] in (None, 1)],
to_shape,
index_from=columns_from,
axis=1,
ignore_sr_names=ignore_sr_names,
ignore_ranges=ignore_ranges,
check_index_names=check_index_names,
**clean_index_kwargs,
)
else:
new_index = pd.RangeIndex(stop=to_shape_2d[0])
new_columns = pd.RangeIndex(stop=to_shape_2d[1])
# Build a product
param_product = None
param_columns = None
n_params = 0
if len(param_keys) > 0:
# Combine parameters
param_dct = {}
for k, bco_obj in bco_instances.items():
if k not in param_keys:
continue
param_dct[k] = bco_obj.value
param_product, param_columns = combine_params(
param_dct,
random_subset=random_subset,
seed=seed,
clean_index_kwargs=clean_index_kwargs,
)
n_params = len(param_columns)
# Combine parameter columns with new columns
if param_columns is not None and new_columns is not None:
new_columns = indexes.combine_indexes([param_columns, new_columns], **clean_index_kwargs)
# Tile
if tile is not None:
if isinstance(tile, int):
if new_columns is not None:
new_columns = indexes.tile_index(new_columns, tile)
else:
if new_columns is not None:
new_columns = indexes.combine_indexes([tile, new_columns], **clean_index_kwargs)
tile = len(tile)
n_params = max(n_params, 1) * tile
# Build wrapper
if n_params == 0:
new_shape = to_shape
else:
new_shape = (to_shape_2d[0], to_shape_2d[1] * n_params)
wrapper = wrapping.ArrayWrapper.from_shape(
new_shape,
**merge_dicts(
dict(
index=new_index,
columns=new_columns,
),
wrapper_kwargs,
),
)
def _adjust_dims(new_obj, _keep_flex, _min_ndim, _expand_axis):
if _min_ndim is None:
if _keep_flex:
_min_ndim = 2
else:
_min_ndim = 1
if _min_ndim not in (1, 2):
raise ValueError("Argument min_ndim must be either 1 or 2")
if _min_ndim in (1, 2) and new_obj.ndim == 0:
new_obj = new_obj[None]
if _min_ndim == 2 and new_obj.ndim == 1:
if len(to_shape) == 1:
new_obj = new_obj[:, None]
else:
new_obj = np.expand_dims(new_obj, _expand_axis)
return new_obj
# Perform broadcasting
aligned_objs2 = {}
new_objs = {}
for i, k in enumerate(all_keys):
if k in none_keys or k in special_keys:
continue
_keep_flex = bco_instances[k].keep_flex
_min_ndim = bco_instances[k].min_ndim
_axis = bco_instances[k].axis
_expand_axis = bco_instances[k].expand_axis
_merge_kwargs = bco_instances[k].merge_kwargs
_context = bco_instances[k].context
must_reset_index = _merge_kwargs.get("reset_index", None) not in (None, False)
_reindex_kwargs = resolve_dict(bco_instances[k].reindex_kwargs)
_fill_value = _reindex_kwargs.get("fill_value", np.nan)
if k in param_keys:
# Broadcast parameters
from vectorbtpro.base.merging import column_stack_merge
if _axis == 0:
raise ValueError("Parameters do not support broadcasting with axis=0")
obj = param_product[k]
new_obj = []
any_needs_broadcasting = False
all_forced_broadcast = True
for o in obj:
if isinstance(o, (indexing.index_dict, indexing.IdxSetter, indexing.IdxSetterFactory)):
o = wrapper.fill_and_set(
o,
fill_value=_fill_value,
keep_flex=_keep_flex,
)
elif isinstance(o, CustomTemplate):
context = merge_dicts(
dict(
bco_instances=bco_instances,
wrapper=wrapper,
obj_name=k,
bco=bco_instances[k],
),
_context,
)
o = o.substitute(context, eval_id="broadcast")
o = to_2d_array(o)
if not _keep_flex:
needs_broadcasting = True
elif o.shape[0] > 1:
needs_broadcasting = True
elif o.shape[1] > 1 and o.shape[1] != to_shape_2d[1]:
needs_broadcasting = True
else:
needs_broadcasting = False
if needs_broadcasting:
any_needs_broadcasting = True
o = broadcast_array_to(o, to_shape_2d, axis=_axis)
elif o.size == 1:
all_forced_broadcast = False
o = np.repeat(o, to_shape_2d[1], axis=1)
else:
all_forced_broadcast = False
new_obj.append(o)
if any_needs_broadcasting and not all_forced_broadcast:
new_obj2 = []
for o in new_obj:
if o.shape[1] != to_shape_2d[1] or (not must_reset_index and o.shape[0] != to_shape_2d[0]):
o = broadcast_array_to(o, to_shape_2d, axis=_axis)
new_obj2.append(o)
new_obj = new_obj2
obj = column_stack_merge(new_obj, **_merge_kwargs)
if tile is not None:
obj = np.tile(obj, (1, tile))
old_obj = obj
new_obj = obj
else:
# Broadcast regular objects
old_obj = aligned_objs[k]
new_obj = ready_objs[k]
if _axis in (None, 0) and new_obj.ndim >= 1 and new_obj.shape[0] > 1 and new_obj.shape[0] != to_shape[0]:
raise ValueError(f"Could not broadcast argument {k} of shape {new_obj.shape} to {to_shape}")
if _axis in (None, 1) and new_obj.ndim == 2 and new_obj.shape[1] > 1 and new_obj.shape[1] != to_shape[1]:
raise ValueError(f"Could not broadcast argument {k} of shape {new_obj.shape} to {to_shape}")
if _keep_flex:
if n_params > 0 and _axis in (None, 1):
if len(to_shape) == 1:
if new_obj.ndim == 1 and new_obj.shape[0] > 1:
new_obj = new_obj[:, None] # product changes is_2d behavior
else:
if new_obj.ndim == 1 and new_obj.shape[0] > 1:
new_obj = np.tile(new_obj, n_params)
elif new_obj.ndim == 2 and new_obj.shape[1] > 1:
new_obj = np.tile(new_obj, (1, n_params))
else:
new_obj = broadcast_array_to(new_obj, to_shape, axis=_axis)
if n_params > 0 and _axis in (None, 1):
if new_obj.ndim == 1:
new_obj = new_obj[:, None] # product changes is_2d behavior
new_obj = np.tile(new_obj, (1, n_params))
new_obj = _adjust_dims(new_obj, _keep_flex, _min_ndim, _expand_axis)
aligned_objs2[k] = old_obj
new_objs[k] = new_obj
# Resolve special objects
new_objs2 = {}
for i, k in enumerate(all_keys):
if k in none_keys:
continue
if k in special_keys:
bco = bco_instances[k]
if isinstance(bco.value, (indexing.index_dict, indexing.IdxSetter, indexing.IdxSetterFactory)):
_is_pd = bco.to_pd
if _is_pd is None:
_is_pd = is_pd
_keep_flex = bco.keep_flex
_min_ndim = bco.min_ndim
_expand_axis = bco.expand_axis
_reindex_kwargs = resolve_dict(bco.reindex_kwargs)
_fill_value = _reindex_kwargs.get("fill_value", np.nan)
new_obj = wrapper.fill_and_set(
bco.value,
fill_value=_fill_value,
keep_flex=_keep_flex,
)
if not _is_pd and not _keep_flex:
new_obj = new_obj.values
new_obj = _adjust_dims(new_obj, _keep_flex, _min_ndim, _expand_axis)
elif isinstance(bco.value, CustomTemplate):
context = merge_dicts(
dict(
bco_instances=bco_instances,
new_objs=new_objs,
wrapper=wrapper,
obj_name=k,
bco=bco,
),
bco.context,
)
new_obj = bco.value.substitute(context, eval_id="broadcast")
else:
raise TypeError(f"Special type {type(bco.value)} is not supported")
else:
new_obj = new_objs[k]
# Force to match requirements
new_obj = np.require(new_obj, **resolve_dict(bco_instances[k].require_kwargs))
new_objs2[k] = new_obj
# Perform wrapping and post-processing
new_objs3 = {}
for i, k in enumerate(all_keys):
if k in none_keys:
continue
new_obj = new_objs2[k]
_axis = bco_instances[k].axis
_keep_flex = bco_instances[k].keep_flex
if not _keep_flex:
# Wrap array
_is_pd = bco_instances[k].to_pd
if _is_pd is None:
_is_pd = is_pd
new_obj = wrap_broadcasted(
new_obj,
old_obj=aligned_objs2[k] if k not in special_keys else None,
axis=_axis,
is_pd=_is_pd,
new_index=new_index,
new_columns=new_columns,
ignore_ranges=ignore_ranges,
)
# Post-process array
_post_func = bco_instances[k].post_func
if _post_func is not None:
new_obj = _post_func(new_obj)
new_objs3[k] = new_obj
# Prepare outputs
return_objs = []
for k in all_keys:
if k not in none_keys:
if k in default_keys and keep_wrap_default:
return_objs.append(Default(new_objs3[k]))
else:
return_objs.append(new_objs3[k])
else:
if k in default_keys and keep_wrap_default:
return_objs.append(Default(None))
else:
return_objs.append(None)
if return_dict:
return_objs = dict(zip(all_keys, return_objs))
else:
return_objs = tuple(return_objs)
if len(return_objs) > 1 or return_dict:
if return_wrapper:
return return_objs, wrapper
return return_objs
if return_wrapper:
return return_objs[0], wrapper
return return_objs[0]
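# Illustrative sketch (not part of the original module): with more than one input, `broadcast`
# returns the broadcast objects as a tuple, or as a dict when `return_dict=True`, and appends
# the shared wrapper when `return_wrapper=True`. Assuming the mapping form of named inputs:
#
#     objs, wrapper = broadcast(dict(a=1, b=pd.Series([1, 2, 3])), return_dict=True, return_wrapper=True)
#     # objs["a"] and objs["b"] are broadcast against each other, while `wrapper` carries the
#     # resulting index, columns, and shape metadata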
def broadcast_to(
arg1: tp.ArrayLike,
arg2: tp.Union[tp.ArrayLike, tp.ShapeLike, wrapping.ArrayWrapper],
to_pd: tp.Optional[bool] = None,
index_from: tp.Optional[IndexFromLike] = None,
columns_from: tp.Optional[IndexFromLike] = None,
**kwargs,
) -> tp.Any:
"""Broadcast `arg1` to `arg2`.
Argument `arg2` can be a shape, an instance of `vectorbtpro.base.wrapping.ArrayWrapper`,
or any array-like object.
Pass None to `index_from`/`columns_from` to use index/columns of the second argument.
Keyword arguments `**kwargs` are passed to `broadcast`.
Usage:
```pycon
>>> from vectorbtpro import *
>>> from vectorbtpro.base.reshaping import broadcast_to
>>> a = np.array([1, 2, 3])
>>> sr = pd.Series([4, 5, 6], index=pd.Index(['x', 'y', 'z']), name='a')
>>> broadcast_to(a, sr)
x 1
y 2
z 3
Name: a, dtype: int64
>>> broadcast_to(sr, a)
array([4, 5, 6])
```
"""
if checks.is_int(arg2) or isinstance(arg2, tuple):
arg2 = to_tuple_shape(arg2)
if isinstance(arg2, tuple):
to_shape = arg2
elif isinstance(arg2, wrapping.ArrayWrapper):
to_pd = True
if index_from is None:
index_from = arg2.index
if columns_from is None:
columns_from = arg2.columns
to_shape = arg2.shape
else:
arg2 = to_any_array(arg2)
if to_pd is None:
to_pd = checks.is_pandas(arg2)
if to_pd:
# Take index and columns from arg2
if index_from is None:
index_from = indexes.get_index(arg2, 0)
if columns_from is None:
columns_from = indexes.get_index(arg2, 1)
to_shape = arg2.shape
return broadcast(
arg1,
to_shape=to_shape,
to_pd=to_pd,
index_from=index_from,
columns_from=columns_from,
**kwargs,
)
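# Illustrative sketch (not from the original docstring): `arg2` may also be an ArrayWrapper,
# in which case its shape, index, and columns serve as the broadcasting target:
#
#     wrapper = wrapping.ArrayWrapper(["x", "y", "z"], ["a"], 1)
#     broadcast_to(0, wrapper)  # -> roughly a Series of zeros indexed by ["x", "y", "z"]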
def broadcast_to_array_of(arg1: tp.ArrayLike, arg2: tp.ArrayLike) -> tp.Array:
"""Broadcast `arg1` to the shape `(1, *arg2.shape)`.
`arg1` must be either a scalar, a 1-dim array, or an array with one more dimension than `arg2`.
Usage:
```pycon
>>> from vectorbtpro import *
>>> from vectorbtpro.base.reshaping import broadcast_to_array_of
>>> broadcast_to_array_of([0.1, 0.2], np.empty((2, 2)))
[[[0.1 0.1]
[0.1 0.1]]
[[0.2 0.2]
[0.2 0.2]]]
```
"""
arg1 = np.asarray(arg1)
arg2 = np.asarray(arg2)
if arg1.ndim == arg2.ndim + 1:
if arg1.shape[1:] == arg2.shape:
return arg1
# From here on arg1 can be only a 1-dim array
if arg1.ndim == 0:
arg1 = to_1d(arg1)
checks.assert_ndim(arg1, 1)
if arg2.ndim == 0:
return arg1
for i in range(arg2.ndim):
arg1 = np.expand_dims(arg1, axis=-1)
return np.tile(arg1, (1, *arg2.shape))
def broadcast_to_axis_of(
arg1: tp.ArrayLike,
arg2: tp.ArrayLike,
axis: int,
require_kwargs: tp.KwargsLike = None,
) -> tp.Array:
"""Broadcast `arg1` to an axis of `arg2`.
If `arg2` has fewer dimensions than requested, broadcasts `arg1` to a single number.
For other keyword arguments, see `broadcast`."""
if require_kwargs is None:
require_kwargs = {}
arg2 = to_any_array(arg2)
if arg2.ndim < axis + 1:
return broadcast_array_to(arg1, (1,))[0] # to a single number
arg1 = broadcast_array_to(arg1, (arg2.shape[axis],))
arg1 = np.require(arg1, **require_kwargs)
return arg1
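# Illustrative sketch: a scalar broadcast to axis 1 of a 2-dim array becomes an array of
# length `arg2.shape[1]`, while a missing axis collapses the result to a single number:
#
#     broadcast_to_axis_of(2.0, np.empty((3, 4)), 1)  # -> array([2., 2., 2., 2.])
#     broadcast_to_axis_of(2.0, np.empty(3), 1)       # -> 2.0 (axis 1 does not exist)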
def broadcast_combs(
*objs: tp.ArrayLike,
axis: int = 1,
comb_func: tp.Callable = itertools.product,
**broadcast_kwargs,
) -> tp.Any:
"""Align an axis of each array using a combinatoric function and broadcast their indexes.
Usage:
```pycon
>>> from vectorbtpro import *
>>> from vectorbtpro.base.reshaping import broadcast_combs
>>> df = pd.DataFrame([[1, 2, 3], [3, 4, 5]], columns=pd.Index(['a', 'b', 'c'], name='df_param'))
>>> df2 = pd.DataFrame([[6, 7], [8, 9]], columns=pd.Index(['d', 'e'], name='df2_param'))
>>> sr = pd.Series([10, 11], name='f')
>>> new_df, new_df2, new_sr = broadcast_combs(df, df2, sr)
>>> new_df
df_param a b c
df2_param d e d e d e
0 1 1 2 2 3 3
1 3 3 4 4 5 5
>>> new_df2
df_param a b c
df2_param d e d e d e
0 6 7 6 7 6 7
1 8 9 8 9 8 9
>>> new_sr
df_param a b c
df2_param d e d e d e
0 10 10 10 10 10 10
1 11 11 11 11 11 11
```
"""
if broadcast_kwargs is None:
broadcast_kwargs = {}
objs = list(objs)
if len(objs) < 2:
raise ValueError("At least two arguments are required")
for i in range(len(objs)):
obj = to_any_array(objs[i])
if axis == 1:
obj = to_2d(obj)
objs[i] = obj
indices = []
for obj in objs:
indices.append(np.arange(len(indexes.get_index(to_pd_array(obj), axis))))
new_indices = list(map(list, zip(*list(comb_func(*indices)))))
results = []
for i, obj in enumerate(objs):
if axis == 1:
if checks.is_pandas(obj):
results.append(obj.iloc[:, new_indices[i]])
else:
results.append(obj[:, new_indices[i]])
else:
if checks.is_pandas(obj):
results.append(obj.iloc[new_indices[i]])
else:
results.append(obj[new_indices[i]])
if axis == 1:
broadcast_kwargs = merge_dicts(dict(columns_from="stack"), broadcast_kwargs)
else:
broadcast_kwargs = merge_dicts(dict(index_from="stack"), broadcast_kwargs)
return broadcast(*results, **broadcast_kwargs)
def get_multiindex_series(obj: tp.SeriesFrame) -> tp.Series:
"""Get Series with a multi-index.
If DataFrame has been passed, must at maximum have one row or column."""
checks.assert_instance_of(obj, (pd.Series, pd.DataFrame))
if checks.is_frame(obj):
if obj.shape[0] == 1:
obj = obj.iloc[0, :]
elif obj.shape[1] == 1:
obj = obj.iloc[:, 0]
else:
raise ValueError("Supported are either Series or DataFrame with one column/row")
checks.assert_instance_of(obj.index, pd.MultiIndex)
return obj
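# Illustrative sketch: a one-column DataFrame with a multi-index is squeezed into a Series:
#
#     index = pd.MultiIndex.from_tuples([(1, "a"), (1, "b")])
#     get_multiindex_series(pd.DataFrame([[1], [2]], index=index))  # -> pd.Series([1, 2], index=index)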
def unstack_to_array(
obj: tp.SeriesFrame,
levels: tp.Optional[tp.MaybeLevelSequence] = None,
sort: bool = True,
return_indexes: bool = False,
) -> tp.Union[tp.Array, tp.Tuple[tp.Array, tp.List[tp.Index]]]:
"""Reshape `obj` based on its multi-index into a multi-dimensional array.
Use `levels` to specify what index levels to unstack and in which order.
Usage:
```pycon
>>> from vectorbtpro import *
>>> from vectorbtpro.base.reshaping import unstack_to_array
>>> index = pd.MultiIndex.from_arrays(
... [[1, 1, 2, 2], [3, 4, 3, 4], ['a', 'b', 'c', 'd']])
>>> sr = pd.Series([1, 2, 3, 4], index=index)
>>> unstack_to_array(sr).shape
(2, 2, 4)
>>> unstack_to_array(sr)
[[[ 1. nan nan nan]
[nan 2. nan nan]]
[[nan nan 3. nan]
[nan nan nan 4.]]]
>>> unstack_to_array(sr, levels=(2, 0))
[[ 1. nan]
[ 2. nan]
[nan 3.]
[nan 4.]]
```
"""
sr = get_multiindex_series(obj)
if sr.index.duplicated().any():
raise ValueError("Index contains duplicate entries, cannot reshape")
new_index_list = []
value_indices_list = []
if levels is None:
levels = range(sr.index.nlevels)
if isinstance(levels, (int, str)):
levels = (levels,)
for level in levels:
level_values = indexes.select_levels(sr.index, level)
new_index = level_values.unique()
if sort:
new_index = new_index.sort_values()
new_index_list.append(new_index)
index_map = pd.Series(range(len(new_index)), index=new_index)
value_indices = index_map.loc[level_values]
value_indices_list.append(value_indices)
a = np.full(list(map(len, new_index_list)), np.nan)
a[tuple(zip(value_indices_list))] = sr.values
if return_indexes:
return a, new_index_list
return a
def make_symmetric(obj: tp.SeriesFrame, sort: bool = True) -> tp.Frame:
"""Make `obj` symmetric.
The index and columns of the resulting DataFrame will be identical.
Requires the index and columns to have the same number of levels.
Pass `sort=False` if the index and columns should not be sorted but merely concatenated
with duplicates removed.
Usage:
```pycon
>>> from vectorbtpro import *
>>> from vectorbtpro.base.reshaping import make_symmetric
>>> df = pd.DataFrame([[1, 2], [3, 4]], index=['a', 'b'], columns=['c', 'd'])
>>> make_symmetric(df)
a b c d
a NaN NaN 1.0 2.0
b NaN NaN 3.0 4.0
c 1.0 3.0 NaN NaN
d 2.0 4.0 NaN NaN
```
"""
from vectorbtpro.base.merging import concat_arrays
checks.assert_instance_of(obj, (pd.Series, pd.DataFrame))
df = to_2d(obj)
if isinstance(df.index, pd.MultiIndex) or isinstance(df.columns, pd.MultiIndex):
checks.assert_instance_of(df.index, pd.MultiIndex)
checks.assert_instance_of(df.columns, pd.MultiIndex)
checks.assert_array_equal(df.index.nlevels, df.columns.nlevels)
names1, names2 = tuple(df.index.names), tuple(df.columns.names)
else:
names1, names2 = df.index.name, df.columns.name
if names1 == names2:
new_name = names1
else:
if isinstance(df.index, pd.MultiIndex):
new_name = tuple(zip(*[names1, names2]))
else:
new_name = (names1, names2)
if sort:
idx_vals = np.unique(concat_arrays((df.index, df.columns))).tolist()
else:
idx_vals = list(dict.fromkeys(concat_arrays((df.index, df.columns))))
df_index = df.index.copy()
df_columns = df.columns.copy()
if isinstance(df.index, pd.MultiIndex):
unique_index = pd.MultiIndex.from_tuples(idx_vals, names=new_name)
df_index.names = new_name
df_columns.names = new_name
else:
unique_index = pd.Index(idx_vals, name=new_name)
df_index.name = new_name
df_columns.name = new_name
df = df.copy(deep=False)
df.index = df_index
df.columns = df_columns
df_out_dtype = np.promote_types(df.values.dtype, np.min_scalar_type(np.nan))
df_out = pd.DataFrame(index=unique_index, columns=unique_index, dtype=df_out_dtype)
df_out.loc[:, :] = df
df_out[df_out.isnull()] = df.transpose()
return df_out
def unstack_to_df(
obj: tp.SeriesFrame,
index_levels: tp.Optional[tp.MaybeLevelSequence] = None,
column_levels: tp.Optional[tp.MaybeLevelSequence] = None,
symmetric: bool = False,
sort: bool = True,
) -> tp.Frame:
"""Reshape `obj` based on its multi-index into a DataFrame.
Use `index_levels` to specify which index levels will form the new index, and `column_levels`
for the new columns. Set `symmetric` to True to make the DataFrame symmetric.
Usage:
```pycon
>>> from vectorbtpro import *
>>> from vectorbtpro.base.reshaping import unstack_to_df
>>> index = pd.MultiIndex.from_arrays(
... [[1, 1, 2, 2], [3, 4, 3, 4], ['a', 'b', 'c', 'd']],
... names=['x', 'y', 'z'])
>>> sr = pd.Series([1, 2, 3, 4], index=index)
>>> unstack_to_df(sr, index_levels=(0, 1), column_levels=2)
z a b c d
x y
1 3 1.0 NaN NaN NaN
1 4 NaN 2.0 NaN NaN
2 3 NaN NaN 3.0 NaN
2 4 NaN NaN NaN 4.0
```
"""
sr = get_multiindex_series(obj)
if sr.index.nlevels > 2:
if index_levels is None:
raise ValueError("index_levels must be specified")
if column_levels is None:
raise ValueError("column_levels must be specified")
else:
if index_levels is None:
index_levels = 0
if column_levels is None:
column_levels = 1
unstacked, (new_index, new_columns) = unstack_to_array(
sr,
levels=(index_levels, column_levels),
sort=sort,
return_indexes=True,
)
df = pd.DataFrame(unstacked, index=new_index, columns=new_columns)
if symmetric:
return make_symmetric(df, sort=sort)
return df
</file>
<file path="base/wrapping.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Classes for wrapping NumPy arrays into Series/DataFrames."""
import numpy as np
import pandas as pd
from pandas.core.groupby import GroupBy as PandasGroupBy
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import indexes, reshaping
from vectorbtpro.base.grouping.base import Grouper
from vectorbtpro.base.indexes import stack_indexes, concat_indexes, IndexApplier
from vectorbtpro.base.indexing import IndexingError, ExtPandasIndexer, index_dict, IdxSetter, IdxSetterFactory, IdxDict
from vectorbtpro.base.resampling.base import Resampler
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.array_ import is_range, cast_to_min_precision, cast_to_max_precision
from vectorbtpro.utils.attr_ import AttrResolverMixin, AttrResolverMixinT
from vectorbtpro.utils.chunking import ChunkMeta, iter_chunk_meta, get_chunk_meta_key, ArraySelector, ArraySlicer
from vectorbtpro.utils.config import Configured, merge_dicts, resolve_dict
from vectorbtpro.utils.decorators import hybrid_method, cached_method, cached_property
from vectorbtpro.utils.execution import Task, execute
from vectorbtpro.utils.params import ItemParamable
from vectorbtpro.utils.parsing import get_func_arg_names
from vectorbtpro.utils.warnings_ import warn
if tp.TYPE_CHECKING:
from vectorbtpro.base.accessors import BaseIDXAccessor as BaseIDXAccessorT
from vectorbtpro.generic.splitting.base import Splitter as SplitterT
else:
BaseIDXAccessorT = "BaseIDXAccessor"
SplitterT = "Splitter"
__all__ = [
"ArrayWrapper",
"Wrapping",
]
HasWrapperT = tp.TypeVar("HasWrapperT", bound="HasWrapper")
class HasWrapper(ExtPandasIndexer, ItemParamable):
"""Abstract class that manages a wrapper."""
@property
def unwrapped(self) -> tp.Any:
"""Unwrapped object."""
raise NotImplementedError
@hybrid_method
def should_wrap(cls_or_self) -> bool:
"""Whether to wrap where applicable."""
return True
@property
def wrapper(self) -> "ArrayWrapper":
"""Array wrapper of the type `ArrayWrapper`."""
raise NotImplementedError
@property
def column_only_select(self) -> bool:
"""Whether to perform indexing on columns only."""
raise NotImplementedError
@property
def range_only_select(self) -> bool:
"""Whether to perform indexing on rows using slices only."""
raise NotImplementedError
@property
def group_select(self) -> bool:
"""Whether to allow indexing on groups."""
raise NotImplementedError
def regroup(self: HasWrapperT, group_by: tp.GroupByLike, **kwargs) -> HasWrapperT:
"""Regroup this instance."""
raise NotImplementedError
def ungroup(self: HasWrapperT, **kwargs) -> HasWrapperT:
"""Ungroup this instance."""
return self.regroup(False, **kwargs)
# ############# Selection ############# #
def select_col(
self: HasWrapperT,
column: tp.Any = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> HasWrapperT:
"""Select one column/group.
`column` can be a column/group label, or an integer position if the label lookup fails."""
_self = self.regroup(group_by, **kwargs)
def _check_out_dim(out: HasWrapperT) -> HasWrapperT:
if out.wrapper.get_ndim() == 2:
if out.wrapper.get_shape_2d()[1] == 1:
if out.column_only_select:
return out.iloc[0]
return out.iloc[:, 0]
if _self.wrapper.grouper.is_grouped():
raise TypeError("Could not select one group: multiple groups returned")
else:
raise TypeError("Could not select one column: multiple columns returned")
return out
if column is None:
if _self.wrapper.get_ndim() == 2 and _self.wrapper.get_shape_2d()[1] == 1:
column = 0
if column is not None:
if _self.wrapper.grouper.is_grouped():
if _self.wrapper.grouped_ndim == 1:
raise TypeError("This instance already contains one group of data")
if column not in _self.wrapper.get_columns():
if isinstance(column, int):
if _self.column_only_select:
return _check_out_dim(_self.iloc[column])
return _check_out_dim(_self.iloc[:, column])
raise KeyError(f"Group '{column}' not found")
else:
if _self.wrapper.ndim == 1:
raise TypeError("This instance already contains one column of data")
if column not in _self.wrapper.columns:
if isinstance(column, int):
if _self.column_only_select:
return _check_out_dim(_self.iloc[column])
return _check_out_dim(_self.iloc[:, column])
raise KeyError(f"Column '{column}' not found")
return _check_out_dim(_self[column])
if _self.wrapper.grouper.is_grouped():
if _self.wrapper.grouped_ndim == 1:
return _self
raise TypeError("Only one group is allowed. Use indexing or column argument.")
if _self.wrapper.ndim == 1:
return _self
raise TypeError("Only one column is allowed. Use indexing or column argument.")
@hybrid_method
def select_col_from_obj(
cls_or_self,
obj: tp.Optional[tp.SeriesFrame],
column: tp.Any = None,
obj_ungrouped: bool = False,
group_by: tp.GroupByLike = None,
wrapper: tp.Optional["ArrayWrapper"] = None,
**kwargs,
) -> tp.MaybeSeries:
"""Select one column/group from a Pandas object.
`column` can be a column/group label, or an integer position if the label lookup fails."""
if obj is None:
return None
if not isinstance(cls_or_self, type):
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
_wrapper = wrapper.regroup(group_by, **kwargs)
def _check_out_dim(out: tp.SeriesFrame, from_df: bool) -> tp.Series:
bad_shape = False
if from_df and isinstance(out, pd.DataFrame):
if len(out.columns) == 1:
return out.iloc[:, 0]
bad_shape = True
if not from_df and isinstance(out, pd.Series):
if len(out) == 1:
return out.iloc[0]
bad_shape = True
if bad_shape:
if _wrapper.grouper.is_grouped():
raise TypeError("Could not select one group: multiple groups returned")
else:
raise TypeError("Could not select one column: multiple columns returned")
return out
if column is None:
if _wrapper.get_ndim() == 2 and _wrapper.get_shape_2d()[1] == 1:
column = 0
if column is not None:
if _wrapper.grouper.is_grouped():
if _wrapper.grouped_ndim == 1:
raise TypeError("This instance already contains one group of data")
if obj_ungrouped:
mask = _wrapper.grouper.group_by == column
if not mask.any():
raise KeyError(f"Group '{column}' not found")
if isinstance(obj, pd.DataFrame):
return obj.loc[:, mask]
return obj.loc[mask]
else:
if column not in _wrapper.get_columns():
if isinstance(column, int):
if isinstance(obj, pd.DataFrame):
return _check_out_dim(obj.iloc[:, column], True)
return _check_out_dim(obj.iloc[column], False)
raise KeyError(f"Group '{column}' not found")
else:
if _wrapper.ndim == 1:
raise TypeError("This instance already contains one column of data")
if column not in _wrapper.columns:
if isinstance(column, int):
if isinstance(obj, pd.DataFrame):
return _check_out_dim(obj.iloc[:, column], True)
return _check_out_dim(obj.iloc[column], False)
raise KeyError(f"Column '{column}' not found")
if isinstance(obj, pd.DataFrame):
return _check_out_dim(obj[column], True)
return _check_out_dim(obj[column], False)
if not _wrapper.grouper.is_grouped():
if _wrapper.ndim == 1:
return obj
raise TypeError("Only one column is allowed. Use indexing or column argument.")
if _wrapper.grouped_ndim == 1:
return obj
raise TypeError("Only one group is allowed. Use indexing or column argument.")
# ############# Splitting ############# #
def split(
self,
*args,
splitter_cls: tp.Optional[tp.Type[SplitterT]] = None,
wrap: tp.Optional[bool] = None,
**kwargs,
) -> tp.Any:
"""Split this instance.
Uses `vectorbtpro.generic.splitting.base.Splitter.split_and_take`."""
from vectorbtpro.generic.splitting.base import Splitter
if splitter_cls is None:
splitter_cls = Splitter
if wrap is None:
wrap = self.should_wrap()
wrapped_self = self if wrap else self.unwrapped
return splitter_cls.split_and_take(self.wrapper.index, wrapped_self, *args, **kwargs)
def split_apply(
self,
apply_func: tp.Union[str, tp.Callable],
*args,
splitter_cls: tp.Optional[tp.Type[SplitterT]] = None,
wrap: tp.Optional[bool] = None,
**kwargs,
) -> tp.Any:
"""Split this instance and apply a function to each split.
Uses `vectorbtpro.generic.splitting.base.Splitter.split_and_apply`."""
from vectorbtpro.generic.splitting.base import Splitter, Takeable
if isinstance(apply_func, str):
apply_func = getattr(type(self), apply_func)
if splitter_cls is None:
splitter_cls = Splitter
if wrap is None:
wrap = self.should_wrap()
wrapped_self = self if wrap else self.unwrapped
return splitter_cls.split_and_apply(self.wrapper.index, apply_func, Takeable(wrapped_self), *args, **kwargs)
# ############# Chunking ############# #
def chunk(
self: HasWrapperT,
axis: tp.Optional[int] = None,
min_size: tp.Optional[int] = None,
n_chunks: tp.Union[None, int, str] = None,
chunk_len: tp.Union[None, int, str] = None,
chunk_meta: tp.Optional[tp.Iterable[ChunkMeta]] = None,
select: bool = False,
wrap: tp.Optional[bool] = None,
return_chunk_meta: bool = False,
) -> tp.Iterator[tp.Union[HasWrapperT, tp.Tuple[ChunkMeta, HasWrapperT]]]:
"""Chunk this instance.
If `axis` is None, it becomes 0 if the instance is one-dimensional and 1 otherwise.
For arguments related to chunking meta, see `vectorbtpro.utils.chunking.iter_chunk_meta`."""
if axis is None:
axis = 0 if self.wrapper.ndim == 1 else 1
if self.wrapper.ndim == 1 and axis == 1:
raise TypeError("Axis 1 is not supported for one dimension")
checks.assert_in(axis, (0, 1))
size = self.wrapper.shape_2d[axis]
if wrap is None:
wrap = self.should_wrap()
wrapped_self = self if wrap else self.unwrapped
if chunk_meta is None:
chunk_meta = iter_chunk_meta(
size=size,
min_size=min_size,
n_chunks=n_chunks,
chunk_len=chunk_len,
)
for _chunk_meta in chunk_meta:
if select:
array_taker = ArraySelector(axis=axis)
else:
array_taker = ArraySlicer(axis=axis)
if return_chunk_meta:
yield _chunk_meta, array_taker.take(wrapped_self, _chunk_meta)
else:
yield array_taker.take(wrapped_self, _chunk_meta)
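# Illustrative sketch, assuming `obj` is any HasWrapper subclass instance: chunking yields
# sub-instances (or raw arrays when `wrap=False`), optionally preceded by their chunk metadata:
#
#     for chunk_meta, chunk in obj.chunk(n_chunks=2, return_chunk_meta=True):
#         ...  # `chunk` covers a slice of columns (axis 1) for 2-dim instances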
def chunk_apply(
self: HasWrapperT,
apply_func: tp.Union[str, tp.Callable],
*args,
chunk_kwargs: tp.KwargsLike = None,
execute_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MergeableResults:
"""Chunk this instance and apply a function to each chunk.
If `apply_func` is a string, it is resolved as the name of a method of this instance.
For arguments related to chunking, see `Wrapping.chunk`."""
if isinstance(apply_func, str):
apply_func = getattr(type(self), apply_func)
if chunk_kwargs is None:
chunk_arg_names = set(get_func_arg_names(self.chunk))
chunk_kwargs = {}
for k in list(kwargs.keys()):
if k in chunk_arg_names:
chunk_kwargs[k] = kwargs.pop(k)
if execute_kwargs is None:
execute_kwargs = {}
chunks = self.chunk(return_chunk_meta=True, **chunk_kwargs)
tasks = []
keys = []
for _chunk_meta, chunk in chunks:
tasks.append(Task(apply_func, chunk, *args, **kwargs))
keys.append(get_chunk_meta_key(_chunk_meta))
keys = pd.Index(keys, name="chunk_indices")
return execute(tasks, size=len(tasks), keys=keys, **execute_kwargs)
# ############# Iteration ############# #
def get_item_keys(self, group_by: tp.GroupByLike = None) -> tp.Index:
"""Get keys for `Wrapping.items`."""
_self = self.regroup(group_by=group_by)
if _self.group_select and _self.wrapper.grouper.is_grouped():
return _self.wrapper.get_columns()
return _self.wrapper.columns
def items(
self,
group_by: tp.GroupByLike = None,
apply_group_by: bool = False,
keep_2d: bool = False,
key_as_index: bool = False,
wrap: tp.Optional[bool] = None,
) -> tp.Items:
"""Iterate over columns or groups (if grouped and `Wrapping.group_select` is True).
If `apply_group_by` is False, `group_by` acts as a grouping instruction for the iteration only,
not for the final object. In this case, an error is raised if the instance is already grouped
and that grouping would have to change."""
if wrap is None:
wrap = self.should_wrap()
def _resolve_v(self):
return self if wrap else self.unwrapped
if group_by is None or apply_group_by:
_self = self.regroup(group_by=group_by)
if _self.group_select and _self.wrapper.grouper.is_grouped():
columns = _self.wrapper.get_columns()
ndim = _self.wrapper.get_ndim()
else:
columns = _self.wrapper.columns
ndim = _self.wrapper.ndim
if ndim == 1:
if key_as_index:
yield columns, _resolve_v(_self)
else:
yield columns[0], _resolve_v(_self)
else:
for i in range(len(columns)):
if key_as_index:
key = columns[[i]]
else:
key = columns[i]
if _self.column_only_select:
if keep_2d:
yield key, _resolve_v(_self.iloc[i : i + 1])
else:
yield key, _resolve_v(_self.iloc[i])
else:
if keep_2d:
yield key, _resolve_v(_self.iloc[:, i : i + 1])
else:
yield key, _resolve_v(_self.iloc[:, i])
else:
if self.group_select and self.wrapper.grouper.is_grouped():
raise ValueError("Cannot change grouping")
wrapper = self.wrapper.regroup(group_by=group_by)
if wrapper.get_ndim() == 1:
if key_as_index:
yield wrapper.get_columns(), _resolve_v(self)
else:
yield wrapper.get_columns()[0], _resolve_v(self)
else:
for group, group_idxs in wrapper.grouper.iter_groups(key_as_index=key_as_index):
if self.column_only_select:
if keep_2d or len(group_idxs) > 1:
yield group, _resolve_v(self.iloc[group_idxs])
else:
yield group, _resolve_v(self.iloc[group_idxs[0]])
else:
if keep_2d or len(group_idxs) > 1:
yield group, _resolve_v(self.iloc[:, group_idxs])
else:
yield group, _resolve_v(self.iloc[:, group_idxs[0]])
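# Illustrative sketch, assuming `obj` is a 2-dim HasWrapper subclass instance: iteration yields
# one (column, sub-instance) pair per column, or one pair per group when grouping is enabled
# and `group_select` is True:
#
#     for col, col_obj in obj.items():
#         ...  # `col_obj` holds the data of a single column/group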
ArrayWrapperT = tp.TypeVar("ArrayWrapperT", bound="ArrayWrapper")
class ArrayWrapper(Configured, HasWrapper, IndexApplier):
"""Class that stores index, columns, and shape metadata for wrapping NumPy arrays.
Tightly integrated with `vectorbtpro.base.grouping.base.Grouper` for grouping columns.
If the underlying object is a Series, pass `[sr.name]` as `columns`.
`**kwargs` are passed to `vectorbtpro.base.grouping.base.Grouper`.
!!! note
This class is meant to be immutable. To change any attribute, use `ArrayWrapper.replace`.
Use methods that begin with `get_` to get group-aware results."""
@classmethod
def from_obj(cls: tp.Type[ArrayWrapperT], obj: tp.ArrayLike, **kwargs) -> ArrayWrapperT:
"""Derive metadata from an object."""
from vectorbtpro.base.reshaping import to_pd_array
from vectorbtpro.data.base import Data
if isinstance(obj, Data):
obj = obj.symbol_wrapper
if isinstance(obj, Wrapping):
obj = obj.wrapper
if isinstance(obj, ArrayWrapper):
return obj.replace(**kwargs)
pd_obj = to_pd_array(obj)
index = indexes.get_index(pd_obj, 0)
columns = indexes.get_index(pd_obj, 1)
ndim = pd_obj.ndim
kwargs.pop("index", None)
kwargs.pop("columns", None)
kwargs.pop("ndim", None)
return cls(index, columns, ndim, **kwargs)
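# Illustrative sketch: metadata is taken directly from the pandas object:
#
#     df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
#     wrapper = ArrayWrapper.from_obj(df)
#     # wrapper.index equals df.index, wrapper.columns equals df.columns, wrapper.ndim is 2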
@classmethod
def from_shape(
cls: tp.Type[ArrayWrapperT],
shape: tp.ShapeLike,
index: tp.Optional[tp.IndexLike] = None,
columns: tp.Optional[tp.IndexLike] = None,
ndim: tp.Optional[int] = None,
*args,
**kwargs,
) -> ArrayWrapperT:
"""Derive metadata from shape."""
shape = reshaping.to_tuple_shape(shape)
if index is None:
index = pd.RangeIndex(stop=shape[0])
if columns is None:
columns = pd.RangeIndex(stop=shape[1] if len(shape) > 1 else 1)
if ndim is None:
ndim = len(shape)
return cls(index, columns, ndim, *args, **kwargs)
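# Illustrative sketch: a plain shape produces default RangeIndex metadata:
#
#     ArrayWrapper.from_shape((3, 2))  # index=RangeIndex(3), columns=RangeIndex(2), ndim=2
#     ArrayWrapper.from_shape(3)       # index=RangeIndex(3), columns=RangeIndex(1), ndim=1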
@staticmethod
def extract_init_kwargs(**kwargs) -> tp.Tuple[tp.Kwargs, tp.Kwargs]:
"""Extract keyword arguments that can be passed to `ArrayWrapper` or `Grouper`."""
wrapper_arg_names = get_func_arg_names(ArrayWrapper.__init__)
grouper_arg_names = get_func_arg_names(Grouper.__init__)
init_kwargs = dict()
for k in list(kwargs.keys()):
if k in wrapper_arg_names or k in grouper_arg_names:
init_kwargs[k] = kwargs.pop(k)
return init_kwargs, kwargs
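# Illustrative sketch: keyword arguments recognized by `ArrayWrapper` or `Grouper` are split
# off, while everything else is returned unchanged:
#
#     init_kwargs, rest = ArrayWrapper.extract_init_kwargs(freq="D", group_by=True, foo=1)
#     # init_kwargs -> {"freq": "D", "group_by": True}, rest -> {"foo": 1}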
@classmethod
def resolve_stack_kwargs(cls, *wrappers: tp.MaybeTuple[ArrayWrapperT], **kwargs) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `ArrayWrapper` after stacking."""
if len(wrappers) == 1:
wrappers = wrappers[0]
wrappers = list(wrappers)
common_keys = set()
for wrapper in wrappers:
common_keys = common_keys.union(set(wrapper.config.keys()))
if "grouper" not in kwargs:
common_keys = common_keys.union(set(wrapper.grouper.config.keys()))
common_keys.remove("grouper")
init_wrapper = wrappers[0]
for i in range(1, len(wrappers)):
wrapper = wrappers[i]
for k in common_keys:
if k not in kwargs:
same_k = True
try:
if k in wrapper.config:
if not checks.is_deep_equal(init_wrapper.config[k], wrapper.config[k]):
same_k = False
elif "grouper" not in kwargs and k in wrapper.grouper.config:
if not checks.is_deep_equal(init_wrapper.grouper.config[k], wrapper.grouper.config[k]):
same_k = False
else:
same_k = False
except KeyError as e:
same_k = False
if not same_k:
raise ValueError(f"Objects to be merged must have compatible '{k}'. Pass to override.")
for k in common_keys:
if k not in kwargs:
if k in init_wrapper.config:
kwargs[k] = init_wrapper.config[k]
elif "grouper" not in kwargs and k in init_wrapper.grouper.config:
kwargs[k] = init_wrapper.grouper.config[k]
else:
raise ValueError(f"Objects to be merged must have compatible '{k}'. Pass to override.")
return kwargs
@hybrid_method
def row_stack(
cls_or_self: tp.MaybeType[ArrayWrapperT],
*wrappers: tp.MaybeTuple[ArrayWrapperT],
index: tp.Optional[tp.IndexLike] = None,
columns: tp.Optional[tp.IndexLike] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
group_by: tp.GroupByLike = None,
stack_columns: bool = True,
index_concat_method: tp.MaybeTuple[tp.Union[str, tp.Callable]] = "append",
keys: tp.Optional[tp.IndexLike] = None,
clean_index_kwargs: tp.KwargsLike = None,
verify_integrity: bool = True,
**kwargs,
) -> ArrayWrapperT:
"""Stack multiple `ArrayWrapper` instances along rows.
Concatenates indexes using `vectorbtpro.base.indexes.concat_indexes`.
Frequency must be the same across all indexes. A custom frequency can be provided via `freq`.
If column levels in some instances differ, they will be stacked upon each other.
Custom columns can be provided via `columns`.
If `group_by` is None, all instances must be either grouped or not, and they must
contain the same group values and labels.
All instances must contain the same keys and values in their configs and configs of their
grouper instances, apart from those arguments provided explicitly via `kwargs`."""
if not isinstance(cls_or_self, type):
wrappers = (cls_or_self, *wrappers)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(wrappers) == 1:
wrappers = wrappers[0]
wrappers = list(wrappers)
for wrapper in wrappers:
if not checks.is_instance_of(wrapper, ArrayWrapper):
raise TypeError("Each object to be merged must be an instance of ArrayWrapper")
if index is None:
index = concat_indexes(
[wrapper.index for wrapper in wrappers],
index_concat_method=index_concat_method,
keys=keys,
clean_index_kwargs=clean_index_kwargs,
verify_integrity=verify_integrity,
axis=0,
)
elif not isinstance(index, pd.Index):
index = pd.Index(index)
kwargs["index"] = index
if freq is None:
new_freq = None
for wrapper in wrappers:
if new_freq is None:
new_freq = wrapper.freq
else:
if new_freq is not None and wrapper.freq is not None and new_freq != wrapper.freq:
raise ValueError("Objects to be merged must have the same frequency")
freq = new_freq
kwargs["freq"] = freq
if columns is None:
new_columns = None
for wrapper in wrappers:
if new_columns is None:
new_columns = wrapper.columns
else:
if not checks.is_index_equal(new_columns, wrapper.columns):
if not stack_columns:
raise ValueError("Objects to be merged must have the same columns")
new_columns = stack_indexes(
(new_columns, wrapper.columns),
**resolve_dict(clean_index_kwargs),
)
columns = new_columns
elif not isinstance(columns, pd.Index):
columns = pd.Index(columns)
kwargs["columns"] = columns
if "grouper" in kwargs:
if not checks.is_index_equal(columns, kwargs["grouper"].index):
raise ValueError("Columns and grouper index must match")
if group_by is not None:
kwargs["group_by"] = group_by
else:
if group_by is None:
grouped = None
for wrapper in wrappers:
wrapper_grouped = wrapper.grouper.is_grouped()
if grouped is None:
grouped = wrapper_grouped
else:
if grouped is not wrapper_grouped:
raise ValueError("Objects to be merged must be either grouped or not")
if grouped:
new_group_by = None
for wrapper in wrappers:
wrapper_groups, wrapper_grouped_index = wrapper.grouper.get_groups_and_index()
wrapper_group_by = wrapper_grouped_index[wrapper_groups]
if new_group_by is None:
new_group_by = wrapper_group_by
else:
if not checks.is_index_equal(new_group_by, wrapper_group_by):
raise ValueError("Objects to be merged must have the same groups")
group_by = new_group_by
else:
group_by = False
kwargs["group_by"] = group_by
if "ndim" not in kwargs:
ndim = None
for wrapper in wrappers:
if ndim is None or wrapper.ndim > 1:
ndim = wrapper.ndim
kwargs["ndim"] = ndim
return cls(**ArrayWrapper.resolve_stack_kwargs(*wrappers, **kwargs))
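# Illustrative sketch: stacking two single-column wrappers along rows appends their indexes
# and keeps the shared columns:
#
#     w1 = ArrayWrapper(pd.date_range("2020-01-01", periods=2), ["a"], 1)
#     w2 = ArrayWrapper(pd.date_range("2020-01-03", periods=2), ["a"], 1)
#     ArrayWrapper.row_stack(w1, w2).index  # -> roughly the four dates appended into one index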
@hybrid_method
def column_stack(
cls_or_self: tp.MaybeType[ArrayWrapperT],
*wrappers: tp.MaybeTuple[ArrayWrapperT],
index: tp.Optional[tp.IndexLike] = None,
columns: tp.Optional[tp.IndexLike] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
group_by: tp.GroupByLike = None,
union_index: bool = True,
col_concat_method: tp.MaybeTuple[tp.Union[str, tp.Callable]] = "append",
group_concat_method: tp.MaybeTuple[tp.Union[str, tp.Callable]] = ("append", "factorize_each"),
keys: tp.Optional[tp.IndexLike] = None,
clean_index_kwargs: tp.KwargsLike = None,
verify_integrity: bool = True,
**kwargs,
) -> ArrayWrapperT:
"""Stack multiple `ArrayWrapper` instances along columns.
If all wrappers share the same index, that index is used. If indexes differ and
`union_index` is True, they are merged into a single index using a set union operation;
otherwise, an error is raised. The merged index must be free of duplicates and mixed data types,
and must be monotonically increasing. A custom index can be provided via `index`.
Frequency must be the same across all indexes. A custom frequency can be provided via `freq`.
Concatenates columns and groups using `vectorbtpro.base.indexes.concat_indexes`.
If any of the instances has `column_only_select` enabled, the final wrapper will also enable it.
If any of the instances has `group_select` or other grouping-related flags disabled, the final
wrapper will also disable them.
All instances must contain the same keys and values in their configs and configs of their
grouper instances, apart from those arguments provided explicitly via `kwargs`."""
if not isinstance(cls_or_self, type):
wrappers = (cls_or_self, *wrappers)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(wrappers) == 1:
wrappers = wrappers[0]
wrappers = list(wrappers)
for wrapper in wrappers:
if not checks.is_instance_of(wrapper, ArrayWrapper):
raise TypeError("Each object to be merged must be an instance of ArrayWrapper")
for wrapper in wrappers:
if wrapper.index.has_duplicates:
raise ValueError("Index of some objects to be merged contains duplicates")
if index is None:
new_index = None
for wrapper in wrappers:
if new_index is None:
new_index = wrapper.index
else:
if not checks.is_index_equal(new_index, wrapper.index):
if not union_index:
raise ValueError(
"Objects to be merged must have the same index. "
"Use union_index=True to merge index as well."
)
else:
if new_index.dtype != wrapper.index.dtype:
raise ValueError("Indexes to be merged must have the same data type")
new_index = new_index.union(wrapper.index)
if not new_index.is_monotonic_increasing:
raise ValueError("Merged index must be monotonically increasing")
index = new_index
elif not isinstance(index, pd.Index):
index = pd.Index(index)
kwargs["index"] = index
if freq is None:
new_freq = None
for wrapper in wrappers:
if new_freq is None:
new_freq = wrapper.freq
else:
if new_freq is not None and wrapper.freq is not None and new_freq != wrapper.freq:
raise ValueError("Objects to be merged must have the same frequency")
freq = new_freq
kwargs["freq"] = freq
if columns is None:
columns = concat_indexes(
[wrapper.columns for wrapper in wrappers],
index_concat_method=col_concat_method,
keys=keys,
clean_index_kwargs=clean_index_kwargs,
verify_integrity=verify_integrity,
axis=1,
)
elif not isinstance(columns, pd.Index):
columns = pd.Index(columns)
kwargs["columns"] = columns
if "grouper" in kwargs:
if not checks.is_index_equal(columns, kwargs["grouper"].index):
raise ValueError("Columns and grouper index must match")
if group_by is not None:
kwargs["group_by"] = group_by
else:
if group_by is None:
any_grouped = False
for wrapper in wrappers:
if wrapper.grouper.is_grouped():
any_grouped = True
break
if any_grouped:
group_by = concat_indexes(
[wrapper.grouper.get_stretched_index() for wrapper in wrappers],
index_concat_method=group_concat_method,
keys=keys,
clean_index_kwargs=clean_index_kwargs,
verify_integrity=verify_integrity,
axis=2,
)
else:
group_by = False
kwargs["group_by"] = group_by
if "ndim" not in kwargs:
kwargs["ndim"] = 2
if "grouped_ndim" not in kwargs:
kwargs["grouped_ndim"] = None
if "column_only_select" not in kwargs:
column_only_select = None
for wrapper in wrappers:
if column_only_select is None or wrapper.column_only_select:
column_only_select = wrapper.column_only_select
kwargs["column_only_select"] = column_only_select
if "range_only_select" not in kwargs:
range_only_select = None
for wrapper in wrappers:
if range_only_select is None or wrapper.range_only_select:
range_only_select = wrapper.range_only_select
kwargs["range_only_select"] = range_only_select
if "group_select" not in kwargs:
group_select = None
for wrapper in wrappers:
if group_select is None or not wrapper.group_select:
group_select = wrapper.group_select
kwargs["group_select"] = group_select
if "grouper" not in kwargs:
if "allow_enable" not in kwargs:
allow_enable = None
for wrapper in wrappers:
if allow_enable is None or not wrapper.grouper.allow_enable:
allow_enable = wrapper.grouper.allow_enable
kwargs["allow_enable"] = allow_enable
if "allow_disable" not in kwargs:
allow_disable = None
for wrapper in wrappers:
if allow_disable is None or not wrapper.grouper.allow_disable:
allow_disable = wrapper.grouper.allow_disable
kwargs["allow_disable"] = allow_disable
if "allow_modify" not in kwargs:
allow_modify = None
for wrapper in wrappers:
if allow_modify is None or not wrapper.grouper.allow_modify:
allow_modify = wrapper.grouper.allow_modify
kwargs["allow_modify"] = allow_modify
return cls(**ArrayWrapper.resolve_stack_kwargs(*wrappers, **kwargs))
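# Illustrative sketch: stacking along columns appends the columns and, by default, takes the
# union of differing (duplicate-free) indexes:
#
#     w1 = ArrayWrapper(pd.RangeIndex(3), ["a"], 1)
#     w2 = ArrayWrapper(pd.RangeIndex(3), ["b"], 1)
#     ArrayWrapper.column_stack(w1, w2).columns  # -> roughly Index(["a", "b"])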
def __init__(
self,
index: tp.IndexLike,
columns: tp.Optional[tp.IndexLike] = None,
ndim: tp.Optional[int] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
parse_index: tp.Optional[bool] = None,
column_only_select: tp.Optional[bool] = None,
range_only_select: tp.Optional[bool] = None,
group_select: tp.Optional[bool] = None,
grouped_ndim: tp.Optional[int] = None,
grouper: tp.Optional[Grouper] = None,
**kwargs,
) -> None:
checks.assert_not_none(index, arg_name="index")
index = dt.prepare_dt_index(index, parse_index=parse_index)
if columns is None:
columns = [None]
if not isinstance(columns, pd.Index):
columns = pd.Index(columns)
if ndim is None:
if len(columns) == 1 and not isinstance(columns, pd.MultiIndex):
ndim = 1
else:
ndim = 2
else:
if len(columns) > 1:
ndim = 2
grouper_arg_names = get_func_arg_names(Grouper.__init__)
grouper_kwargs = dict()
for k in list(kwargs.keys()):
if k in grouper_arg_names:
grouper_kwargs[k] = kwargs.pop(k)
if grouper is None:
grouper = Grouper(columns, **grouper_kwargs)
elif not checks.is_index_equal(columns, grouper.index) or len(grouper_kwargs) > 0:
grouper = grouper.replace(index=columns, **grouper_kwargs)
HasWrapper.__init__(self)
Configured.__init__(
self,
index=index,
columns=columns,
ndim=ndim,
freq=freq,
parse_index=parse_index,
column_only_select=column_only_select,
range_only_select=range_only_select,
group_select=group_select,
grouped_ndim=grouped_ndim,
grouper=grouper,
**kwargs,
)
self._index = index
self._columns = columns
self._ndim = ndim
self._freq = freq
self._parse_index = parse_index
self._column_only_select = column_only_select
self._range_only_select = range_only_select
self._group_select = group_select
self._grouper = grouper
self._grouped_ndim = grouped_ndim
def indexing_func_meta(
self: ArrayWrapperT,
pd_indexing_func: tp.PandasIndexingFunc,
index: tp.Optional[tp.IndexLike] = None,
columns: tp.Optional[tp.IndexLike] = None,
column_only_select: tp.Optional[bool] = None,
range_only_select: tp.Optional[bool] = None,
group_select: tp.Optional[bool] = None,
return_slices: bool = True,
return_none_slices: bool = True,
return_scalars: bool = True,
group_by: tp.GroupByLike = None,
wrapper_kwargs: tp.KwargsLike = None,
) -> dict:
"""Perform indexing on `ArrayWrapper` and also return metadata.
Takes into account column grouping.
Flipping rows and columns is not allowed. If one row is selected, the result will still be
a Series when indexing a Series and a DataFrame when indexing a DataFrame.
Set `column_only_select` to True to index the array wrapper as a Series of columns/groups.
This way, selection along the index (axis 0) can be avoided. Set `range_only_select` to True to
allow selecting rows only via slices. Set `group_select` to True to allow selecting groups;
otherwise, indexing is performed on columns even if grouping is enabled. The latter takes effect
only if grouping is enabled.
Returns a dict with the new array wrapper, row/column/group indices, and flags indicating
whether each of these axes has been changed.
If `return_slices` is True (default), indices will be returned as a slice if they were
identified as a range. If `return_none_slices` is True (default), indices will be returned as a slice
`(None, None, None)` if the axis hasn't been changed.
!!! note
If `column_only_select` is True, make sure to index the array wrapper
as a Series of columns rather than a DataFrame. For example, the operation
`.iloc[:, :2]` should become `.iloc[:2]`. Operations are not allowed if the
object is already a Series and thus has only one column/group."""
if column_only_select is None:
column_only_select = self.column_only_select
if range_only_select is None:
range_only_select = self.range_only_select
if group_select is None:
group_select = self.group_select
if wrapper_kwargs is None:
wrapper_kwargs = {}
_self = self.regroup(group_by)
group_select = group_select and _self.grouper.is_grouped()
if index is None:
index = _self.index
if not isinstance(index, pd.Index):
index = pd.Index(index)
if columns is None:
if group_select:
columns = _self.get_columns()
else:
columns = _self.columns
if not isinstance(columns, pd.Index):
columns = pd.Index(columns)
if group_select:
# Groups as columns
i_wrapper = ArrayWrapper(index, columns, _self.get_ndim())
else:
# Columns as columns
i_wrapper = ArrayWrapper(index, columns, _self.ndim)
n_rows = len(index)
n_cols = len(columns)
def _resolve_arr(arr, n):
if checks.is_np_array(arr) and is_range(arr):
if arr[0] == 0 and arr[-1] == n - 1:
if return_none_slices:
return slice(None, None, None), False
return arr, False
if return_slices:
return slice(arr[0], arr[-1] + 1, None), True
return arr, True
if isinstance(arr, np.integer):
arr = arr.item()
columns_changed = True
if isinstance(arr, int):
if arr == 0 and n == 1:
columns_changed = False
if not return_scalars:
arr = np.array([arr])
return arr, columns_changed
if column_only_select:
if i_wrapper.ndim == 1:
raise IndexingError("Columns only: This instance already contains one column of data")
try:
col_mapper = pd_indexing_func(i_wrapper.wrap_reduced(np.arange(n_cols), columns=columns))
except pd.core.indexing.IndexingError as e:
warn("Columns only: Make sure to treat this instance as a Series of columns rather than a DataFrame")
raise e
if checks.is_series(col_mapper):
new_columns = col_mapper.index
col_idxs = col_mapper.values
new_ndim = 2
else:
new_columns = columns[[col_mapper]]
col_idxs = col_mapper
new_ndim = 1
new_index = index
row_idxs = np.arange(len(index))
else:
init_row_mapper_values = reshaping.broadcast_array_to(np.arange(n_rows)[:, None], (n_rows, n_cols))
init_row_mapper = i_wrapper.wrap(init_row_mapper_values, index=index, columns=columns)
row_mapper = pd_indexing_func(init_row_mapper)
if i_wrapper.ndim == 1:
if not checks.is_series(row_mapper):
row_idxs = np.array([row_mapper])
new_index = index[row_idxs]
else:
row_idxs = row_mapper.values
new_index = indexes.get_index(row_mapper, 0)
col_idxs = 0
new_columns = columns
new_ndim = 1
else:
init_col_mapper_values = reshaping.broadcast_array_to(np.arange(n_cols)[None], (n_rows, n_cols))
init_col_mapper = i_wrapper.wrap(init_col_mapper_values, index=index, columns=columns)
col_mapper = pd_indexing_func(init_col_mapper)
if checks.is_frame(col_mapper):
# Multiple rows and columns selected
row_idxs = row_mapper.values[:, 0]
col_idxs = col_mapper.values[0]
new_index = indexes.get_index(row_mapper, 0)
new_columns = indexes.get_index(col_mapper, 1)
new_ndim = 2
elif checks.is_series(col_mapper):
multi_index = isinstance(index, pd.MultiIndex)
multi_columns = isinstance(columns, pd.MultiIndex)
multi_name = isinstance(col_mapper.name, tuple)
if multi_index and multi_name and col_mapper.name in index:
one_row = True
elif not multi_index and not multi_name and col_mapper.name in index:
one_row = True
else:
one_row = False
if multi_columns and multi_name and col_mapper.name in columns:
one_col = True
elif not multi_columns and not multi_name and col_mapper.name in columns:
one_col = True
else:
one_col = False
if (one_row and one_col) or (not one_row and not one_col):
one_row = np.all(row_mapper.values == row_mapper.values.item(0))
one_col = np.all(col_mapper.values == col_mapper.values.item(0))
if (one_row and one_col) or (not one_row and not one_col):
raise IndexingError("Could not parse indexing operation")
if one_row:
# One row selected
row_idxs = row_mapper.values[[0]]
col_idxs = col_mapper.values
new_index = index[row_idxs]
new_columns = indexes.get_index(col_mapper, 0)
new_ndim = 2
else:
# One column selected
row_idxs = row_mapper.values
col_idxs = col_mapper.values[0]
new_index = indexes.get_index(row_mapper, 0)
new_columns = columns[[col_idxs]]
new_ndim = 1
else:
# One row and column selected
row_idxs = np.array([row_mapper])
col_idxs = col_mapper
new_index = index[row_idxs]
new_columns = columns[[col_idxs]]
new_ndim = 1
if _self.grouper.is_grouped():
# Grouping enabled
if np.asarray(row_idxs).ndim == 0:
raise IndexingError("Flipping index and columns is not allowed")
if group_select:
# Selection based on groups
# Get indices of columns corresponding to selected groups
group_idxs = col_idxs
col_idxs, new_groups = _self.grouper.select_groups(group_idxs)
ungrouped_columns = _self.columns[col_idxs]
if new_ndim == 1 and len(ungrouped_columns) == 1:
ungrouped_ndim = 1
col_idxs = col_idxs[0]
else:
ungrouped_ndim = 2
row_idxs, rows_changed = _resolve_arr(row_idxs, _self.shape[0])
if range_only_select and rows_changed:
if not isinstance(row_idxs, slice):
raise ValueError("Rows can be selected only by slicing")
if row_idxs.step not in (1, None):
raise ValueError("Slice for selecting rows must have a step of 1 or None")
col_idxs, columns_changed = _resolve_arr(col_idxs, _self.shape_2d[1])
group_idxs, groups_changed = _resolve_arr(group_idxs, _self.get_shape_2d()[1])
return dict(
new_wrapper=_self.replace(
**merge_dicts(
dict(
index=new_index,
columns=ungrouped_columns,
ndim=ungrouped_ndim,
grouped_ndim=new_ndim,
group_by=new_columns[new_groups],
),
wrapper_kwargs,
)
),
row_idxs=row_idxs,
rows_changed=rows_changed,
col_idxs=col_idxs,
columns_changed=columns_changed,
group_idxs=group_idxs,
groups_changed=groups_changed,
)
# Selection based on columns
group_idxs = _self.grouper.get_groups()[col_idxs]
new_group_by = _self.grouper.group_by[reshaping.to_1d_array(col_idxs)]
row_idxs, rows_changed = _resolve_arr(row_idxs, _self.shape[0])
if range_only_select and rows_changed:
if not isinstance(row_idxs, slice):
raise ValueError("Rows can be selected only by slicing")
if row_idxs.step not in (1, None):
raise ValueError("Slice for selecting rows must have a step of 1 or None")
col_idxs, columns_changed = _resolve_arr(col_idxs, _self.shape_2d[1])
group_idxs, groups_changed = _resolve_arr(group_idxs, _self.get_shape_2d()[1])
return dict(
new_wrapper=_self.replace(
**merge_dicts(
dict(
index=new_index,
columns=new_columns,
ndim=new_ndim,
grouped_ndim=None,
group_by=new_group_by,
),
wrapper_kwargs,
)
),
row_idxs=row_idxs,
rows_changed=rows_changed,
col_idxs=col_idxs,
columns_changed=columns_changed,
group_idxs=group_idxs,
groups_changed=groups_changed,
)
# Grouping disabled
row_idxs, rows_changed = _resolve_arr(row_idxs, _self.shape[0])
if range_only_select and rows_changed:
if not isinstance(row_idxs, slice):
raise ValueError("Rows can be selected only by slicing")
if row_idxs.step not in (1, None):
raise ValueError("Slice for selecting rows must have a step of 1 or None")
col_idxs, columns_changed = _resolve_arr(col_idxs, _self.shape_2d[1])
return dict(
new_wrapper=_self.replace(
**merge_dicts(
dict(
index=new_index,
columns=new_columns,
ndim=new_ndim,
grouped_ndim=None,
group_by=None,
),
wrapper_kwargs,
)
),
row_idxs=row_idxs,
rows_changed=rows_changed,
col_idxs=col_idxs,
columns_changed=columns_changed,
group_idxs=col_idxs,
groups_changed=columns_changed,
)
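# Illustrative sketch, assuming `wrapper` is a 2-dim, ungrouped ArrayWrapper: the returned
# metadata describes how each axis was affected by the pandas indexing operation:
#
#     meta = wrapper.indexing_func_meta(lambda x: x.iloc[:, [0]])
#     # meta contains "new_wrapper", "row_idxs", "rows_changed", "col_idxs",
#     # "columns_changed", "group_idxs", and "groups_changed"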
def indexing_func(self: ArrayWrapperT, *args, **kwargs) -> ArrayWrapperT:
"""Perform indexing on `ArrayWrapper`."""
return self.indexing_func_meta(*args, **kwargs)["new_wrapper"]
@staticmethod
def select_from_flex_array(
arr: tp.ArrayLike,
row_idxs: tp.Union[int, tp.Array1d, slice] = None,
col_idxs: tp.Union[int, tp.Array1d, slice] = None,
rows_changed: bool = True,
columns_changed: bool = True,
rotate_rows: bool = False,
rotate_cols: bool = True,
) -> tp.Array2d:
"""Select rows and columns from a flexible array.
Always returns a 2-dim NumPy array."""
new_arr = arr_2d = reshaping.to_2d_array(arr)
if row_idxs is not None and rows_changed:
if arr_2d.shape[0] > 1:
if isinstance(row_idxs, slice):
max_idx = row_idxs.stop - 1
else:
row_idxs = reshaping.to_1d_array(row_idxs)
max_idx = np.max(row_idxs)
if arr_2d.shape[0] <= max_idx:
if rotate_rows and not isinstance(row_idxs, slice):
new_arr = new_arr[row_idxs % arr_2d.shape[0], :]
else:
new_arr = new_arr[row_idxs, :]
else:
new_arr = new_arr[row_idxs, :]
if col_idxs is not None and columns_changed:
if arr_2d.shape[1] > 1:
if isinstance(col_idxs, slice):
max_idx = col_idxs.stop - 1
else:
col_idxs = reshaping.to_1d_array(col_idxs)
max_idx = np.max(col_idxs)
if arr_2d.shape[1] <= max_idx:
if rotate_cols and not isinstance(col_idxs, slice):
new_arr = new_arr[:, col_idxs % arr_2d.shape[1]]
else:
new_arr = new_arr[:, col_idxs]
else:
new_arr = new_arr[:, col_idxs]
return new_arr
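# Example (illustrative sketch, not part of the library source): a one-row array is treated
# as broadcastable, so row selection is skipped, while an out-of-bounds column index is
# wrapped around via modulo because `rotate_cols=True` by default.
# >>> import numpy as np
# >>> arr = np.array([[1, 2, 3]])
# >>> ArrayWrapper.select_from_flex_array(arr, row_idxs=np.array([0, 2]), col_idxs=np.array([4]))
# array([[2]])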
def get_resampler(self, *args, **kwargs) -> tp.Union[Resampler, tp.PandasResampler]:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.get_resampler`."""
return self.index_acc.get_resampler(*args, **kwargs)
def resample_meta(self: ArrayWrapperT, *args, wrapper_kwargs: tp.KwargsLike = None, **kwargs) -> dict:
"""Perform resampling on `ArrayWrapper` and also return metadata.
`*args` and `**kwargs` are passed to `ArrayWrapper.get_resampler`."""
resampler = self.get_resampler(*args, **kwargs)
if isinstance(resampler, Resampler):
_resampler = resampler
else:
_resampler = Resampler.from_pd_resampler(resampler)
if wrapper_kwargs is None:
wrapper_kwargs = {}
if "index" not in wrapper_kwargs:
wrapper_kwargs["index"] = _resampler.target_index
if "freq" not in wrapper_kwargs:
wrapper_kwargs["freq"] = dt.infer_index_freq(wrapper_kwargs["index"], freq=_resampler.target_freq)
new_wrapper = self.replace(**wrapper_kwargs)
return dict(resampler=resampler, new_wrapper=new_wrapper)
def resample(self: ArrayWrapperT, *args, **kwargs) -> ArrayWrapperT:
"""Perform resampling on `ArrayWrapper`.
Uses `ArrayWrapper.resample_meta`."""
return self.resample_meta(*args, **kwargs)["new_wrapper"]
@property
def wrapper(self) -> "ArrayWrapper":
return self
@property
def index(self) -> tp.Index:
"""Index."""
return self._index
@cached_property(whitelist=True)
def index_acc(self) -> BaseIDXAccessorT:
"""Get index accessor of the type `vectorbtpro.base.accessors.BaseIDXAccessor`."""
from vectorbtpro.base.accessors import BaseIDXAccessor
return BaseIDXAccessor(self.index, freq=self._freq)
@property
def ns_index(self) -> tp.Array1d:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.to_ns`."""
return self.index_acc.to_ns()
def get_period_ns_index(self, *args, **kwargs) -> tp.Array1d:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.to_period_ns`."""
return self.index_acc.to_period_ns(*args, **kwargs)
@property
def columns(self) -> tp.Index:
"""Columns."""
return self._columns
def get_columns(self, group_by: tp.GroupByLike = None) -> tp.Index:
"""Get group-aware `ArrayWrapper.columns`."""
return self.resolve(group_by=group_by).columns
@property
def name(self) -> tp.Any:
"""Name."""
if self.ndim == 1:
if self.columns[0] == 0:
return None
return self.columns[0]
return None
def get_name(self, group_by: tp.GroupByLike = None) -> tp.Any:
"""Get group-aware `ArrayWrapper.name`."""
return self.resolve(group_by=group_by).name
@property
def ndim(self) -> int:
"""Number of dimensions."""
return self._ndim
def get_ndim(self, group_by: tp.GroupByLike = None) -> int:
"""Get group-aware `ArrayWrapper.ndim`."""
return self.resolve(group_by=group_by).ndim
@property
def shape(self) -> tp.Shape:
"""Shape."""
if self.ndim == 1:
return (len(self.index),)
return len(self.index), len(self.columns)
def get_shape(self, group_by: tp.GroupByLike = None) -> tp.Shape:
"""Get group-aware `ArrayWrapper.shape`."""
return self.resolve(group_by=group_by).shape
@property
def shape_2d(self) -> tp.Shape:
"""Shape as if the instance was two-dimensional."""
if self.ndim == 1:
return self.shape[0], 1
return self.shape
def get_shape_2d(self, group_by: tp.GroupByLike = None) -> tp.Shape:
"""Get group-aware `ArrayWrapper.shape_2d`."""
return self.resolve(group_by=group_by).shape_2d
def get_freq(self, *args, **kwargs) -> tp.Union[None, float, tp.PandasFrequency]:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.get_freq`."""
return self.index_acc.get_freq(*args, **kwargs)
@property
def freq(self) -> tp.Optional[tp.PandasFrequency]:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.freq`."""
return self.index_acc.freq
@property
def ns_freq(self) -> tp.Optional[int]:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.ns_freq`."""
return self.index_acc.ns_freq
@property
def any_freq(self) -> tp.Union[None, float, tp.PandasFrequency]:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.any_freq`."""
return self.index_acc.any_freq
@property
def periods(self) -> int:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.periods`."""
return self.index_acc.periods
@property
def dt_periods(self) -> float:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.dt_periods`."""
return self.index_acc.dt_periods
def arr_to_timedelta(self, *args, **kwargs) -> tp.Union[pd.Index, tp.MaybeArray]:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.arr_to_timedelta`."""
return self.index_acc.arr_to_timedelta(*args, **kwargs)
@property
def parse_index(self) -> tp.Optional[bool]:
"""Whether to try to convert the index into a datetime index.
Applied during the initialization and passed to `vectorbtpro.utils.datetime_.prepare_dt_index`."""
return self._parse_index
@property
def column_only_select(self) -> bool:
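"""Whether to perform indexing on columns only. Falls back to the `column_only_select` key of the global `wrapping` settings if not set on the instance."""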
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
column_only_select = self._column_only_select
if column_only_select is None:
column_only_select = wrapping_cfg["column_only_select"]
return column_only_select
@property
def range_only_select(self) -> bool:
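"""Whether rows can be selected only as a contiguous range (slice). Falls back to the `range_only_select` key of the global `wrapping` settings if not set on the instance."""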
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
range_only_select = self._range_only_select
if range_only_select is None:
range_only_select = wrapping_cfg["range_only_select"]
return range_only_select
@property
def group_select(self) -> bool:
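"""Whether to allow selecting groups when grouping is enabled. Falls back to the `group_select` key of the global `wrapping` settings if not set on the instance."""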
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
group_select = self._group_select
if group_select is None:
group_select = wrapping_cfg["group_select"]
return group_select
@property
def grouper(self) -> Grouper:
"""Column grouper."""
return self._grouper
@property
def grouped_ndim(self) -> int:
"""Number of dimensions under column grouping."""
if self._grouped_ndim is None:
if self.grouper.is_grouped():
return 2 if self.grouper.get_group_count() > 1 else 1
return self.ndim
return self._grouped_ndim
@cached_method(whitelist=True)
def regroup(self: ArrayWrapperT, group_by: tp.GroupByLike, **kwargs) -> ArrayWrapperT:
"""Regroup this instance.
Only creates a new instance if grouping has changed, otherwise returns itself."""
if self.grouper.is_grouping_changed(group_by=group_by):
self.grouper.check_group_by(group_by=group_by)
grouped_ndim = None
if self.grouper.is_grouped(group_by=group_by):
if not self.grouper.is_group_count_changed(group_by=group_by):
grouped_ndim = self.grouped_ndim
return self.replace(grouped_ndim=grouped_ndim, group_by=group_by, **kwargs)
if len(kwargs) > 0:
return self.replace(**kwargs)
return self # important for keeping cache
def flip(self: ArrayWrapperT, **kwargs) -> ArrayWrapperT:
"""Flip index and columns."""
if "grouper" not in kwargs:
kwargs["grouper"] = None
return self.replace(index=self.columns, columns=self.index, **kwargs)
@cached_method(whitelist=True)
def resolve(self: ArrayWrapperT, group_by: tp.GroupByLike = None, **kwargs) -> ArrayWrapperT:
"""Resolve this instance.
Replaces columns and other metadata with groups."""
_self = self.regroup(group_by=group_by, **kwargs)
if _self.grouper.is_grouped():
return _self.replace(
columns=_self.grouper.get_index(),
ndim=_self.grouped_ndim,
grouped_ndim=None,
group_by=None,
)
return _self # important for keeping cache
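# Example (illustrative sketch; assumes the constructor accepts `group_by` the same way
# `replace` does above): resolving a grouped wrapper swaps its columns for the group labels.
# >>> wrapper = ArrayWrapper(pd.date_range("2020", periods=3), pd.Index(["a", "b", "c"]))
# >>> wrapper.get_columns(group_by=[0, 0, 1])  # -> index of the two group labels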
def get_index_grouper(self, *args, **kwargs) -> Grouper:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`."""
return self.index_acc.get_grouper(*args, **kwargs)
def wrap(
self,
arr: tp.ArrayLike,
group_by: tp.GroupByLike = None,
index: tp.Optional[tp.IndexLike] = None,
columns: tp.Optional[tp.IndexLike] = None,
zero_to_none: tp.Optional[bool] = None,
force_2d: bool = False,
fillna: tp.Optional[tp.Scalar] = None,
dtype: tp.Optional[tp.PandasDTypeLike] = None,
min_precision: tp.Union[None, int, str] = None,
max_precision: tp.Union[None, int, str] = None,
prec_float_only: tp.Optional[bool] = None,
prec_check_bounds: tp.Optional[bool] = None,
prec_strict: tp.Optional[bool] = None,
to_timedelta: bool = False,
to_index: bool = False,
silence_warnings: tp.Optional[bool] = None,
) -> tp.SeriesFrame:
"""Wrap a NumPy array using the stored metadata.
Runs the following pipeline:
1) Converts to NumPy array
2) Fills NaN (optional)
3) Wraps using index, columns, and dtype (optional)
4) Converts to index (optional)
5) Converts to timedelta using `ArrayWrapper.arr_to_timedelta` (optional)"""
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
if zero_to_none is None:
zero_to_none = wrapping_cfg["zero_to_none"]
if min_precision is None:
min_precision = wrapping_cfg["min_precision"]
if max_precision is None:
max_precision = wrapping_cfg["max_precision"]
if prec_float_only is None:
prec_float_only = wrapping_cfg["prec_float_only"]
if prec_check_bounds is None:
prec_check_bounds = wrapping_cfg["prec_check_bounds"]
if prec_strict is None:
prec_strict = wrapping_cfg["prec_strict"]
if silence_warnings is None:
silence_warnings = wrapping_cfg["silence_warnings"]
_self = self.resolve(group_by=group_by)
if index is None:
index = _self.index
if not isinstance(index, pd.Index):
index = pd.Index(index)
if columns is None:
columns = _self.columns
if not isinstance(columns, pd.Index):
columns = pd.Index(columns)
if len(columns) == 1:
name = columns[0]
if zero_to_none and name == 0: # was a Series before
name = None
else:
name = None
def _apply_dtype(obj):
if dtype is None:
return obj
return obj.astype(dtype, errors="ignore")
def _wrap(arr):
orig_arr = arr
arr = np.asarray(arr)
if fillna is not None:
arr[pd.isnull(arr)] = fillna
shape_2d = (arr.shape[0] if arr.ndim > 0 else 1, arr.shape[1] if arr.ndim > 1 else 1)
target_shape_2d = (len(index), len(columns))
if shape_2d != target_shape_2d:
if isinstance(orig_arr, (pd.Series, pd.DataFrame)):
arr = reshaping.align_pd_arrays(orig_arr, to_index=index, to_columns=columns).values
arr = reshaping.broadcast_array_to(arr, target_shape_2d)
arr = reshaping.soft_to_ndim(arr, self.ndim)
if min_precision is not None:
arr = cast_to_min_precision(
arr,
min_precision,
float_only=prec_float_only,
)
if max_precision is not None:
arr = cast_to_max_precision(
arr,
max_precision,
float_only=prec_float_only,
check_bounds=prec_check_bounds,
strict=prec_strict,
)
if arr.ndim == 1:
if force_2d:
return _apply_dtype(pd.DataFrame(arr[:, None], index=index, columns=columns))
return _apply_dtype(pd.Series(arr, index=index, name=name))
if arr.ndim == 2:
if not force_2d and arr.shape[1] == 1 and _self.ndim == 1:
return _apply_dtype(pd.Series(arr[:, 0], index=index, name=name))
return _apply_dtype(pd.DataFrame(arr, index=index, columns=columns))
raise ValueError(f"{arr.ndim}-d input is not supported")
out = _wrap(arr)
if to_index:
# Convert to index
if checks.is_series(out):
out = out.map(lambda x: self.index[x] if x != -1 else np.nan)
else:
out = out.applymap(lambda x: self.index[x] if x != -1 else np.nan)
if to_timedelta:
# Convert to timedelta
out = self.arr_to_timedelta(out, silence_warnings=silence_warnings)
return out
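# Example (illustrative sketch, not part of the library source), mirroring the pipeline above:
# a full-shape array keeps the stored index and columns, while a single column is broadcast.
# >>> wrapper = ArrayWrapper(pd.date_range("2020", periods=5), pd.Index(["a", "b", "c"]))
# >>> wrapper.wrap(np.ones((5, 3)))  # -> DataFrame with the stored index and columns
# >>> wrapper.wrap(np.ones((5, 1)))  # -> broadcast across all three columns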
def wrap_reduced(
self,
arr: tp.ArrayLike,
group_by: tp.GroupByLike = None,
name_or_index: tp.NameIndex = None,
columns: tp.Optional[tp.IndexLike] = None,
force_1d: bool = False,
fillna: tp.Optional[tp.Scalar] = None,
dtype: tp.Optional[tp.PandasDTypeLike] = None,
to_timedelta: bool = False,
to_index: bool = False,
silence_warnings: tp.Optional[bool] = None,
) -> tp.MaybeSeriesFrame:
"""Wrap result of reduction.
`name_or_index` can be the name of the resulting series if reducing to a scalar per column,
or the index of the resulting series/dataframe if reducing to an array per column.
`columns` can be set to override object's default columns.
See `ArrayWrapper.wrap` for the pipeline."""
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
if silence_warnings is None:
silence_warnings = wrapping_cfg["silence_warnings"]
_self = self.resolve(group_by=group_by)
if columns is None:
columns = _self.columns
if not isinstance(columns, pd.Index):
columns = pd.Index(columns)
if to_index:
if dtype is None:
dtype = int_
if fillna is None:
fillna = -1
def _apply_dtype(obj):
if dtype is None:
return obj
return obj.astype(dtype, errors="ignore")
def _wrap_reduced(arr):
nonlocal name_or_index
if isinstance(arr, dict):
arr = reshaping.to_pd_array(arr)
if isinstance(arr, pd.Series):
if not checks.is_index_equal(arr.index, columns):
arr = arr.iloc[indexes.align_indexes(arr.index, columns)[0]]
arr = np.asarray(arr)
if force_1d and arr.ndim == 0:
arr = arr[None]
if fillna is not None:
if arr.ndim == 0:
if pd.isnull(arr):
arr = fillna
else:
arr[pd.isnull(arr)] = fillna
if arr.ndim == 0:
# Scalar per Series/DataFrame
return _apply_dtype(pd.Series(arr[None]))[0]
if arr.ndim == 1:
if not force_1d and _self.ndim == 1:
if arr.shape[0] == 1:
# Scalar per Series/DataFrame with one column
return _apply_dtype(pd.Series(arr))[0]
# Array per Series
sr_name = columns[0]
if sr_name == 0:
sr_name = None
if isinstance(name_or_index, str):
name_or_index = None
return _apply_dtype(pd.Series(arr, index=name_or_index, name=sr_name))
# Scalar per column in DataFrame
if arr.shape[0] == 1 and len(columns) > 1:
arr = reshaping.broadcast_array_to(arr, len(columns))
return _apply_dtype(pd.Series(arr, index=columns, name=name_or_index))
if arr.ndim == 2:
if arr.shape[1] == 1 and _self.ndim == 1:
arr = reshaping.soft_to_ndim(arr, 1)
# Array per Series
sr_name = columns[0]
if sr_name == 0:
sr_name = None
if isinstance(name_or_index, str):
name_or_index = None
return _apply_dtype(pd.Series(arr, index=name_or_index, name=sr_name))
# Array per column in DataFrame
if isinstance(name_or_index, str):
name_or_index = None
if arr.shape[0] == 1 and len(columns) > 1:
arr = reshaping.broadcast_array_to(arr, (arr.shape[0], len(columns)))
return _apply_dtype(pd.DataFrame(arr, index=name_or_index, columns=columns))
raise ValueError(f"{arr.ndim}-d input is not supported")
out = _wrap_reduced(arr)
if to_index:
# Convert to index
if checks.is_series(out):
out = out.map(lambda x: self.index[x] if x != -1 else np.nan)
elif checks.is_frame(out):
out = out.applymap(lambda x: self.index[x] if x != -1 else np.nan)
else:
out = self.index[out] if out != -1 else np.nan
if to_timedelta:
# Convert to timedelta
out = self.arr_to_timedelta(out, silence_warnings=silence_warnings)
return out
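# Example (illustrative sketch, not part of the library source): one scalar per column becomes
# a Series indexed by the columns, while a 0-dim input stays a scalar.
# >>> wrapper = ArrayWrapper(pd.date_range("2020", periods=5), pd.Index(["a", "b", "c"]))
# >>> wrapper.wrap_reduced(np.array([1.0, 2.0, 3.0]), name_or_index="sum")  # -> Series indexed by the columns
# >>> wrapper.wrap_reduced(np.array(6.0))  # -> plain scalar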
def concat_arrs(
self,
*objs: tp.ArrayLike,
group_by: tp.GroupByLike = None,
wrap: bool = True,
**kwargs,
) -> tp.AnyArray1d:
"""Stack reduced objects along columns and wrap the final object."""
from vectorbtpro.base.merging import concat_arrays
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
new_objs = []
for obj in objs:
new_objs.append(reshaping.to_1d_array(obj))
stacked_obj = concat_arrays(new_objs)
if wrap:
_self = self.resolve(group_by=group_by)
return _self.wrap_reduced(stacked_obj, **kwargs)
return stacked_obj
def row_stack_arrs(
self,
*objs: tp.ArrayLike,
group_by: tp.GroupByLike = None,
wrap: bool = True,
**kwargs,
) -> tp.AnyArray:
"""Stack objects along rows and wrap the final object."""
from vectorbtpro.base.merging import row_stack_arrays
_self = self.resolve(group_by=group_by)
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
new_objs = []
for obj in objs:
obj = reshaping.to_2d_array(obj)
if obj.shape[1] != _self.shape_2d[1]:
if obj.shape[1] != 1:
raise ValueError(f"Cannot broadcast {obj.shape[1]} to {_self.shape_2d[1]} columns")
obj = np.repeat(obj, _self.shape_2d[1], axis=1)
new_objs.append(obj)
stacked_obj = row_stack_arrays(new_objs)
if wrap:
return _self.wrap(stacked_obj, **kwargs)
return stacked_obj
def column_stack_arrs(
self,
*objs: tp.ArrayLike,
reindex_kwargs: tp.KwargsLike = None,
group_by: tp.GroupByLike = None,
wrap: bool = True,
**kwargs,
) -> tp.AnyArray2d:
"""Stack objects along columns and wrap the final object.
`reindex_kwargs` will be passed to
[pandas.DataFrame.reindex](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.reindex.html)."""
from vectorbtpro.base.merging import column_stack_arrays
_self = self.resolve(group_by=group_by)
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
new_objs = []
for obj in objs:
if not checks.is_index_equal(obj.index, _self.index, check_names=False):
was_bool = (isinstance(obj, pd.Series) and obj.dtype == "bool") or (
isinstance(obj, pd.DataFrame) and (obj.dtypes == "bool").all()
)
obj = obj.reindex(_self.index, **resolve_dict(reindex_kwargs))
is_object = (isinstance(obj, pd.Series) and obj.dtype == "object") or (
isinstance(obj, pd.DataFrame) and (obj.dtypes == "object").all()
)
if was_bool and is_object:
obj = obj.astype(None)
new_objs.append(reshaping.to_2d_array(obj))
stacked_obj = column_stack_arrays(new_objs)
if wrap:
return _self.wrap(stacked_obj, **kwargs)
return stacked_obj
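# Example (illustrative sketch, not part of the library source): pandas objects are reindexed
# to the wrapper's index before being stacked along columns.
# >>> wrapper = ArrayWrapper(pd.date_range("2020", periods=3), pd.Index(["a", "b"]))
# >>> sr1 = pd.Series([1.0, 2.0, 3.0], index=wrapper.index)
# >>> sr2 = pd.Series([4.0, 5.0], index=wrapper.index[1:])
# >>> wrapper.column_stack_arrs(sr1, sr2)  # -> (3 x 2) DataFrame; sr2 gets NaN in the first row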
def dummy(self, group_by: tp.GroupByLike = None, **kwargs) -> tp.SeriesFrame:
"""Create a dummy Series/DataFrame."""
_self = self.resolve(group_by=group_by)
return _self.wrap(np.empty(_self.shape), **kwargs)
def fill(self, fill_value: tp.Scalar = np.nan, group_by: tp.GroupByLike = None, **kwargs) -> tp.SeriesFrame:
"""Fill a Series/DataFrame."""
_self = self.resolve(group_by=group_by)
return _self.wrap(np.full(_self.shape_2d, fill_value), **kwargs)
def fill_reduced(self, fill_value: tp.Scalar = np.nan, group_by: tp.GroupByLike = None, **kwargs) -> tp.SeriesFrame:
"""Fill a reduced Series/DataFrame."""
_self = self.resolve(group_by=group_by)
return _self.wrap_reduced(np.full(_self.shape_2d[1], fill_value), **kwargs)
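# Example (illustrative sketch, not part of the library source):
# >>> wrapper = ArrayWrapper(pd.date_range("2020", periods=3), pd.Index(["a", "b"]))
# >>> wrapper.fill(0.0)          # -> (3 x 2) DataFrame of zeros
# >>> wrapper.fill_reduced(0.0)  # -> Series of zeros indexed by the columns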
def apply_to_index(
self: ArrayWrapperT,
apply_func: tp.Callable,
*args,
axis: tp.Optional[int] = None,
**kwargs,
) -> ArrayWrapperT:
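"""Apply a function to the index (axis 0) or columns (axis 1) and return a new wrapper. Defaults to axis 1 for two-dimensional wrappers and axis 0 for one-dimensional ones."""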
if axis is None:
axis = 0 if self.ndim == 1 else 1
if self.ndim == 1 and axis == 1:
raise TypeError("Axis 1 is not supported for one dimension")
checks.assert_in(axis, (0, 1))
if axis == 1:
return self.replace(columns=apply_func(self.columns, *args, **kwargs))
return self.replace(index=apply_func(self.index, *args, **kwargs))
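# Example (illustrative sketch, not part of the library source): with a 2-dim wrapper,
# the function is applied to the columns by default.
# >>> wrapper = ArrayWrapper(pd.date_range("2020", periods=3), pd.Index(["a", "b"]))
# >>> wrapper.apply_to_index(lambda columns: columns.str.upper()).columns  # -> ["A", "B"]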
def get_index_points(self, *args, **kwargs) -> tp.Array1d:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.get_points`."""
return self.index_acc.get_points(*args, **kwargs)
def get_index_ranges(self, *args, **kwargs) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.get_ranges`."""
return self.index_acc.get_ranges(*args, **kwargs)
def fill_and_set(
self,
idx_setter: tp.Union[index_dict, IdxSetter, IdxSetterFactory],
keep_flex: bool = False,
fill_value: tp.Scalar = np.nan,
**kwargs,
) -> tp.AnyArray:
"""Fill a new array using an index object such as `vectorbtpro.base.indexing.index_dict`.
Will be wrapped with `vectorbtpro.base.indexing.IdxSetter` if not already.
Will call `vectorbtpro.base.indexing.IdxSetter.fill_and_set`.
Usage:
* Set a single row:
```pycon
>>> from vectorbtpro import *
>>> index = pd.date_range("2020", periods=5)
>>> columns = pd.Index(["a", "b", "c"])
>>> wrapper = vbt.ArrayWrapper(index, columns)
>>> wrapper.fill_and_set(vbt.index_dict({
... 1: 2
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 2.0 2.0 2.0
2020-01-03 NaN NaN NaN
2020-01-04 NaN NaN NaN
2020-01-05 NaN NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... "2020-01-02": 2
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 2.0 2.0 2.0
2020-01-03 NaN NaN NaN
2020-01-04 NaN NaN NaN
2020-01-05 NaN NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... "2020-01-02": [1, 2, 3]
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 1.0 2.0 3.0
2020-01-03 NaN NaN NaN
2020-01-04 NaN NaN NaN
2020-01-05 NaN NaN NaN
```
* Set multiple rows:
```pycon
>>> wrapper.fill_and_set(vbt.index_dict({
... (1, 3): [2, 3]
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 2.0 2.0 2.0
2020-01-03 NaN NaN NaN
2020-01-04 3.0 3.0 3.0
2020-01-05 NaN NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... ("2020-01-02", "2020-01-04"): [[1, 2, 3], [4, 5, 6]]
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 1.0 2.0 3.0
2020-01-03 NaN NaN NaN
2020-01-04 4.0 5.0 6.0
2020-01-05 NaN NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... ("2020-01-02", "2020-01-04"): [[1, 2, 3]]
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 1.0 2.0 3.0
2020-01-03 NaN NaN NaN
2020-01-04 1.0 2.0 3.0
2020-01-05 NaN NaN NaN
```
* Set rows using slices:
```pycon
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.hslice(1, 3): 2
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 2.0 2.0 2.0
2020-01-03 2.0 2.0 2.0
2020-01-04 NaN NaN NaN
2020-01-05 NaN NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.hslice("2020-01-02", "2020-01-04"): 2
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 2.0 2.0 2.0
2020-01-03 2.0 2.0 2.0
2020-01-04 NaN NaN NaN
2020-01-05 NaN NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... ((0, 2), (3, 5)): [[1], [2]]
... }))
a b c
2020-01-01 1.0 1.0 1.0
2020-01-02 1.0 1.0 1.0
2020-01-03 NaN NaN NaN
2020-01-04 2.0 2.0 2.0
2020-01-05 2.0 2.0 2.0
>>> wrapper.fill_and_set(vbt.index_dict({
... ((0, 2), (3, 5)): [[1, 2, 3], [4, 5, 6]]
... }))
a b c
2020-01-01 1.0 2.0 3.0
2020-01-02 1.0 2.0 3.0
2020-01-03 NaN NaN NaN
2020-01-04 4.0 5.0 6.0
2020-01-05 4.0 5.0 6.0
```
* Set rows using index points:
```pycon
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.pointidx(every="2D"): 2
... }))
a b c
2020-01-01 2.0 2.0 2.0
2020-01-02 NaN NaN NaN
2020-01-03 2.0 2.0 2.0
2020-01-04 NaN NaN NaN
2020-01-05 2.0 2.0 2.0
```
* Set rows using index ranges:
```pycon
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.rangeidx(
... start=("2020-01-01", "2020-01-03"),
... end=("2020-01-02", "2020-01-05")
... ): 2
... }))
a b c
2020-01-01 2.0 2.0 2.0
2020-01-02 NaN NaN NaN
2020-01-03 2.0 2.0 2.0
2020-01-04 2.0 2.0 2.0
2020-01-05 NaN NaN NaN
```
* Set column indices:
```pycon
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.colidx("a"): 2
... }))
a b c
2020-01-01 2.0 NaN NaN
2020-01-02 2.0 NaN NaN
2020-01-03 2.0 NaN NaN
2020-01-04 2.0 NaN NaN
2020-01-05 2.0 NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.colidx(("a", "b")): [1, 2]
... }))
a b c
2020-01-01 1.0 2.0 NaN
2020-01-02 1.0 2.0 NaN
2020-01-03 1.0 2.0 NaN
2020-01-04 1.0 2.0 NaN
2020-01-05 1.0 2.0 NaN
>>> multi_columns = pd.MultiIndex.from_arrays(
... [["a", "a", "b", "b"], [1, 2, 1, 2]],
... names=["c1", "c2"]
... )
>>> multi_wrapper = vbt.ArrayWrapper(index, multi_columns)
>>> multi_wrapper.fill_and_set(vbt.index_dict({
... vbt.colidx(("a", 2)): 2
... }))
c1 a b
c2 1 2 1 2
2020-01-01 NaN 2.0 NaN NaN
2020-01-02 NaN 2.0 NaN NaN
2020-01-03 NaN 2.0 NaN NaN
2020-01-04 NaN 2.0 NaN NaN
2020-01-05 NaN 2.0 NaN NaN
>>> multi_wrapper.fill_and_set(vbt.index_dict({
... vbt.colidx("b", level="c1"): [3, 4]
... }))
c1 a b
c2 1 2 1 2
2020-01-01 NaN NaN 3.0 4.0
2020-01-02 NaN NaN 3.0 4.0
2020-01-03 NaN NaN 3.0 4.0
2020-01-04 NaN NaN 3.0 4.0
2020-01-05 NaN NaN 3.0 4.0
```
* Set row and column indices:
```pycon
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.idx(2, 2): 2
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 NaN NaN NaN
2020-01-03 NaN NaN 2.0
2020-01-04 NaN NaN NaN
2020-01-05 NaN NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.idx(("2020-01-01", "2020-01-03"), 2): [1, 2]
... }))
a b c
2020-01-01 NaN NaN 1.0
2020-01-02 NaN NaN NaN
2020-01-03 NaN NaN 2.0
2020-01-04 NaN NaN NaN
2020-01-05 NaN NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.idx(("2020-01-01", "2020-01-03"), (0, 2)): [[1, 2], [3, 4]]
... }))
a b c
2020-01-01 1.0 NaN 2.0
2020-01-02 NaN NaN NaN
2020-01-03 3.0 NaN 4.0
2020-01-04 NaN NaN NaN
2020-01-05 NaN NaN NaN
>>> multi_wrapper.fill_and_set(vbt.index_dict({
... vbt.idx(
... vbt.pointidx(every="2d"),
... vbt.colidx(1, level="c2")
... ): [[1, 2]]
... }))
c1 a b
c2 1 2 1 2
2020-01-01 1.0 NaN 2.0 NaN
2020-01-02 NaN NaN NaN NaN
2020-01-03 1.0 NaN 2.0 NaN
2020-01-04 NaN NaN NaN NaN
2020-01-05 1.0 NaN 2.0 NaN
>>> multi_wrapper.fill_and_set(vbt.index_dict({
... vbt.idx(
... vbt.pointidx(every="2d"),
... vbt.colidx(1, level="c2")
... ): [[1], [2], [3]]
... }))
c1 a b
c2 1 2 1 2
2020-01-01 1.0 NaN 1.0 NaN
2020-01-02 NaN NaN NaN NaN
2020-01-03 2.0 NaN 2.0 NaN
2020-01-04 NaN NaN NaN NaN
2020-01-05 3.0 NaN 3.0 NaN
```
* Set rows using a template:
```pycon
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.RepEval("index.day % 2 == 0"): 2
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 2.0 2.0 2.0
2020-01-03 NaN NaN NaN
2020-01-04 2.0 2.0 2.0
2020-01-05 NaN NaN NaN
```
"""
if isinstance(idx_setter, index_dict):
idx_setter = IdxDict(idx_setter)
if isinstance(idx_setter, IdxSetterFactory):
idx_setter = idx_setter.get()
if not isinstance(idx_setter, IdxSetter):
raise ValueError("Index setter factory must return exactly one index setter")
checks.assert_instance_of(idx_setter, IdxSetter)
arr = idx_setter.fill_and_set(
self.shape,
keep_flex=keep_flex,
fill_value=fill_value,
index=self.index,
columns=self.columns,
freq=self.freq,
**kwargs,
)
if not keep_flex:
return self.wrap(arr, group_by=False)
return arr
WrappingT = tp.TypeVar("WrappingT", bound="Wrapping")
class Wrapping(Configured, HasWrapper, IndexApplier, AttrResolverMixin):
"""Class that uses `ArrayWrapper` globally."""
@classmethod
def resolve_row_stack_kwargs(cls, *wrappings: tp.MaybeTuple[WrappingT], **kwargs) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `Wrapping` after stacking along rows."""
return kwargs
@classmethod
def resolve_column_stack_kwargs(cls, *wrappings: tp.MaybeTuple[WrappingT], **kwargs) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `Wrapping` after stacking along columns."""
return kwargs
@classmethod
def resolve_stack_kwargs(cls, *wrappings: tp.MaybeTuple[WrappingT], **kwargs) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `Wrapping` after stacking.
Should be called after `Wrapping.resolve_row_stack_kwargs` or `Wrapping.resolve_column_stack_kwargs`."""
return cls.resolve_merge_kwargs(*[wrapping.config for wrapping in wrappings], **kwargs)
@hybrid_method
def row_stack(
cls_or_self: tp.MaybeType[WrappingT],
*objs: tp.MaybeTuple[WrappingT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> WrappingT:
"""Stack multiple `Wrapping` instances along rows.
Should use `ArrayWrapper.row_stack`."""
raise NotImplementedError
@hybrid_method
def column_stack(
cls_or_self: tp.MaybeType[WrappingT],
*objs: tp.MaybeTuple[WrappingT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> WrappingT:
"""Stack multiple `Wrapping` instances along columns.
Should use `ArrayWrapper.column_stack`."""
raise NotImplementedError
def __init__(self, wrapper: ArrayWrapper, **kwargs) -> None:
checks.assert_instance_of(wrapper, ArrayWrapper)
self._wrapper = wrapper
Configured.__init__(self, wrapper=wrapper, **kwargs)
HasWrapper.__init__(self)
AttrResolverMixin.__init__(self)
def indexing_func(self: WrappingT, *args, **kwargs) -> WrappingT:
"""Perform indexing on `Wrapping`."""
new_wrapper = self.wrapper.indexing_func(
*args,
column_only_select=self.column_only_select,
range_only_select=self.range_only_select,
group_select=self.group_select,
**kwargs,
)
return self.replace(wrapper=new_wrapper)
def resample(self: WrappingT, *args, **kwargs) -> WrappingT:
"""Perform resampling on `Wrapping`.
When overriding, make sure to create a resampler by passing `*args` and `**kwargs`
to `ArrayWrapper.get_resampler`."""
raise NotImplementedError
@property
def wrapper(self) -> ArrayWrapper:
return self._wrapper
def apply_to_index(
self: WrappingT,
apply_func: tp.Callable,
*args,
axis: tp.Optional[int] = None,
**kwargs,
) -> WrappingT:
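"""Apply a function to the index (axis 0) or columns (axis 1) of the wrapper and return a new instance."""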
if axis is None:
axis = 0 if self.wrapper.ndim == 1 else 1
if self.wrapper.ndim == 1 and axis == 1:
raise TypeError("Axis 1 is not supported for one dimension")
checks.assert_in(axis, (0, 1))
if axis == 1:
new_wrapper = self.wrapper.replace(columns=apply_func(self.wrapper.columns, *args, **kwargs))
else:
new_wrapper = self.wrapper.replace(index=apply_func(self.wrapper.index, *args, **kwargs))
return self.replace(wrapper=new_wrapper)
@property
def column_only_select(self) -> bool:
column_only_select = getattr(self, "_column_only_select", None)
if column_only_select is None:
return self.wrapper.column_only_select
return column_only_select
@property
def range_only_select(self) -> bool:
range_only_select = getattr(self, "_range_only_select", None)
if range_only_select is None:
return self.wrapper.range_only_select
return range_only_select
@property
def group_select(self) -> bool:
group_select = getattr(self, "_group_select", None)
if group_select is None:
return self.wrapper.group_select
return group_select
def regroup(self: WrappingT, group_by: tp.GroupByLike, **kwargs) -> WrappingT:
"""Regroup this instance.
Only creates a new instance if grouping has changed, otherwise returns itself.
`**kwargs` will be passed to `ArrayWrapper.regroup`."""
if self.wrapper.grouper.is_grouping_changed(group_by=group_by):
self.wrapper.grouper.check_group_by(group_by=group_by)
return self.replace(wrapper=self.wrapper.regroup(group_by, **kwargs))
return self # important for keeping cache
def resolve_self(
self: AttrResolverMixinT,
cond_kwargs: tp.KwargsLike = None,
custom_arg_names: tp.Optional[tp.Set[str]] = None,
impacts_caching: bool = True,
silence_warnings: tp.Optional[bool] = None,
) -> AttrResolverMixinT:
"""Resolve self.
Creates a copy of this instance if a different `freq` can be found in `cond_kwargs`."""
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
if cond_kwargs is None:
cond_kwargs = {}
if custom_arg_names is None:
custom_arg_names = set()
if silence_warnings is None:
silence_warnings = wrapping_cfg["silence_warnings"]
if "freq" in cond_kwargs:
wrapper_copy = self.wrapper.replace(freq=cond_kwargs["freq"])
if wrapper_copy.freq != self.wrapper.freq:
if not silence_warnings:
warn(
f"Changing the frequency will create a copy of this instance. "
f"Consider setting it upon instantiation to re-use existing cache."
)
self_copy = self.replace(wrapper=wrapper_copy)
for alias in self.self_aliases:
if alias not in custom_arg_names:
cond_kwargs[alias] = self_copy
cond_kwargs["freq"] = self_copy.wrapper.freq
if impacts_caching:
cond_kwargs["use_caching"] = False
return self_copy
return self
</file>
<file path="data/custom/__init__.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules with custom data classes."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.data.custom.alpaca import *
from vectorbtpro.data.custom.av import *
from vectorbtpro.data.custom.bento import *
from vectorbtpro.data.custom.binance import *
from vectorbtpro.data.custom.ccxt import *
from vectorbtpro.data.custom.csv import *
from vectorbtpro.data.custom.custom import *
from vectorbtpro.data.custom.db import *
from vectorbtpro.data.custom.duckdb import *
from vectorbtpro.data.custom.feather import *
from vectorbtpro.data.custom.file import *
from vectorbtpro.data.custom.finpy import *
from vectorbtpro.data.custom.gbm import *
from vectorbtpro.data.custom.gbm_ohlc import *
from vectorbtpro.data.custom.hdf import *
from vectorbtpro.data.custom.local import *
from vectorbtpro.data.custom.ndl import *
from vectorbtpro.data.custom.parquet import *
from vectorbtpro.data.custom.polygon import *
from vectorbtpro.data.custom.random import *
from vectorbtpro.data.custom.random_ohlc import *
from vectorbtpro.data.custom.remote import *
from vectorbtpro.data.custom.sql import *
from vectorbtpro.data.custom.synthetic import *
from vectorbtpro.data.custom.tv import *
from vectorbtpro.data.custom.yf import *
</file>
<file path="data/custom/alpaca.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `AlpacaData`."""
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.parsing import get_func_arg_names
try:
if not tp.TYPE_CHECKING:
raise ImportError
from alpaca.common.rest import RESTClient as AlpacaClientT
except ImportError:
AlpacaClientT = "AlpacaClient"
__all__ = [
"AlpacaData",
]
AlpacaDataT = tp.TypeVar("AlpacaDataT", bound="AlpacaData")
class AlpacaData(RemoteData):
"""Data class for fetching from Alpaca.
See https://github.com/alpacahq/alpaca-py for API.
See `AlpacaData.fetch_symbol` for arguments.
Usage:
* Set up the API key globally (optional for crypto):
```pycon
>>> from vectorbtpro import *
>>> vbt.AlpacaData.set_custom_settings(
... client_config=dict(
... api_key="YOUR_KEY",
... secret_key="YOUR_SECRET"
... )
... )
```
* Pull stock data:
```pycon
>>> data = vbt.AlpacaData.pull(
... "AAPL",
... start="2021-01-01",
... end="2022-01-01",
... timeframe="1 day"
... )
```
* Pull crypto data:
```pycon
>>> data = vbt.AlpacaData.pull(
... "BTC/USD",
... start="2021-01-01",
... end="2022-01-01",
... timeframe="1 day"
... )
```
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.alpaca")
@classmethod
def list_symbols(
cls,
pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
status: tp.Optional[str] = None,
asset_class: tp.Optional[str] = None,
exchange: tp.Optional[str] = None,
trading_client: tp.Optional[AlpacaClientT] = None,
client_config: tp.KwargsLike = None,
) -> tp.List[str]:
"""List all symbols.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`.
Arguments `status`, `asset_class`, and `exchange` can be strings, such as `asset_class="crypto"`.
For possible values, take a look into `alpaca.trading.enums`.
!!! note
If you get an authorization error, make sure that you either enable or disable
the `paper` flag in `client_config` depending upon the account whose credentials you used.
By default, the credentials are assumed to be of a live trading account (`paper=False`)."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("alpaca")
from alpaca.trading.client import TradingClient
from alpaca.trading.requests import GetAssetsRequest
from alpaca.trading.enums import AssetStatus, AssetClass, AssetExchange
if client_config is None:
client_config = {}
has_client_config = len(client_config) > 0
client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True)
if trading_client is None:
arg_names = get_func_arg_names(TradingClient.__init__)
client_config = {k: v for k, v in client_config.items() if k in arg_names}
trading_client = TradingClient(**client_config)
elif has_client_config:
raise ValueError("Cannot apply client_config to already initialized client")
if status is not None:
if isinstance(status, str):
status = getattr(AssetStatus, status.upper())
if asset_class is not None:
if isinstance(asset_class, str):
asset_class = getattr(AssetClass, asset_class.upper())
if exchange is not None:
if isinstance(exchange, str):
exchange = getattr(AssetExchange, exchange.upper())
search_params = GetAssetsRequest(status=status, asset_class=asset_class, exchange=exchange)
assets = trading_client.get_all_assets(search_params)
all_symbols = []
for asset in assets:
symbol = asset.symbol
if pattern is not None:
if not cls.key_match(symbol, pattern, use_regex=use_regex):
continue
all_symbols.append(symbol)
if sort:
return sorted(dict.fromkeys(all_symbols))
return list(dict.fromkeys(all_symbols))
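# Example (illustrative sketch; assumes glob-style patterns in `CustomData.key_match` and
# valid Alpaca credentials set up as shown in the class docstring):
# >>> AlpacaData.list_symbols("*/USD", asset_class="crypto")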
@classmethod
def resolve_client(
cls,
client: tp.Optional[AlpacaClientT] = None,
client_type: tp.Optional[str] = None,
**client_config,
) -> AlpacaClientT:
"""Resolve the client.
If provided, must be of the type `alpaca.data.historical.CryptoHistoricalDataClient`
for `client_type="crypto"` and `alpaca.data.historical.StockHistoricalDataClient` for
`client_type="stocks"`. Otherwise, will be created using `client_config`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("alpaca")
from alpaca.data.historical import CryptoHistoricalDataClient, StockHistoricalDataClient
client = cls.resolve_custom_setting(client, "client")
client_type = cls.resolve_custom_setting(client_type, "client_type")
if client_config is None:
client_config = {}
has_client_config = len(client_config) > 0
client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True)
if client is None:
if client_type == "crypto":
arg_names = get_func_arg_names(CryptoHistoricalDataClient.__init__)
client_config = {k: v for k, v in client_config.items() if k in arg_names}
client = CryptoHistoricalDataClient(**client_config)
elif client_type == "stocks":
arg_names = get_func_arg_names(StockHistoricalDataClient.__init__)
client_config = {k: v for k, v in client_config.items() if k in arg_names}
client = StockHistoricalDataClient(**client_config)
else:
raise ValueError(f"Invalid client type: '{client_type}'")
elif has_client_config:
raise ValueError("Cannot apply client_config to already initialized client")
return client
@classmethod
def fetch_symbol(
cls,
symbol: str,
client: tp.Optional[AlpacaClientT] = None,
client_type: tp.Optional[str] = None,
client_config: tp.KwargsLike = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
adjustment: tp.Optional[str] = None,
feed: tp.Optional[str] = None,
limit: tp.Optional[int] = None,
) -> tp.SymbolData:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Alpaca.
Args:
symbol (str): Symbol.
client (alpaca.common.rest.RESTClient): Client.
See `AlpacaData.resolve_client`.
client_type (str): Client type.
See `AlpacaData.resolve_client`.
Determined automatically based on the symbol. Crypto symbols contain "/".
client_config (dict): Client config.
See `AlpacaData.resolve_client`.
start (any): Start datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
end (any): End datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
timeframe (str): Timeframe.
Allows human-readable strings such as "15 minutes".
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
adjustment (str): Specifies the corporate action adjustment for the returned bars.
Options are: "raw", "split", "dividend" or "all". Default is "raw".
feed (str): The feed to pull market data from.
This is either "iex", "otc", or "sip". Feeds "sip" and "otc" are only available to
those with a subscription. Default is "iex" for free plans and "sip" for paid.
limit (int): The maximum number of returned items.
For defaults, see `custom.alpaca` in `vectorbtpro._settings.data`.
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("alpaca")
from alpaca.data.historical import CryptoHistoricalDataClient, StockHistoricalDataClient
from alpaca.data.requests import CryptoBarsRequest, StockBarsRequest
from alpaca.data.timeframe import TimeFrame, TimeFrameUnit
if client_type is None:
client_type = "crypto" if "/" in symbol else "stocks"
if client_config is None:
client_config = {}
client = cls.resolve_client(client=client, client_type=client_type, **client_config)
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
adjustment = cls.resolve_custom_setting(adjustment, "adjustment")
feed = cls.resolve_custom_setting(feed, "feed")
limit = cls.resolve_custom_setting(limit, "limit")
freq = timeframe
split = dt.split_freq_str(timeframe)
if split is None:
raise ValueError(f"Invalid timeframe: '{timeframe}'")
multiplier, unit = split
if unit == "m":
unit = TimeFrameUnit.Minute
elif unit == "h":
unit = TimeFrameUnit.Hour
elif unit == "D":
unit = TimeFrameUnit.Day
elif unit == "W":
unit = TimeFrameUnit.Week
elif unit == "M":
unit = TimeFrameUnit.Month
else:
raise ValueError(f"Invalid timeframe: '{timeframe}'")
timeframe = TimeFrame(multiplier, unit)
if start is not None:
start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")
start_str = start.replace(tzinfo=None).isoformat("T")
else:
start_str = None
if end is not None:
end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc")
end_str = end.replace(tzinfo=None).isoformat("T")
else:
end_str = None
if isinstance(client, CryptoHistoricalDataClient):
request = CryptoBarsRequest(
symbol_or_symbols=symbol,
timeframe=timeframe,
start=start_str,
end=end_str,
limit=limit,
)
df = client.get_crypto_bars(request).df
elif isinstance(client, StockHistoricalDataClient):
request = StockBarsRequest(
symbol_or_symbols=symbol,
timeframe=timeframe,
start=start_str,
end=end_str,
limit=limit,
adjustment=adjustment,
feed=feed,
)
df = client.get_stock_bars(request).df
else:
raise TypeError(f"Invalid client of type {type(client)}")
df = df.droplevel("symbol", axis=0)
df.index = df.index.rename("Open time")
df.rename(
columns={
"open": "Open",
"high": "High",
"low": "Low",
"close": "Close",
"volume": "Volume",
"trade_count": "Trade count",
"vwap": "VWAP",
},
inplace=True,
)
if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None:
df = df.tz_localize("utc")
if "Open" in df.columns:
df["Open"] = df["Open"].astype(float)
if "High" in df.columns:
df["High"] = df["High"].astype(float)
if "Low" in df.columns:
df["Low"] = df["Low"].astype(float)
if "Close" in df.columns:
df["Close"] = df["Close"].astype(float)
if "Volume" in df.columns:
df["Volume"] = df["Volume"].astype(float)
if "Trade count" in df.columns:
df["Trade count"] = df["Trade count"].astype(int, errors="ignore")
if "VWAP" in df.columns:
df["VWAP"] = df["VWAP"].astype(float)
if not df.empty:
if start is not None:
start = dt.to_timestamp(start, tz=df.index.tz)
if df.index[0] < start:
df = df[df.index >= start]
if end is not None:
end = dt.to_timestamp(end, tz=df.index.tz)
if df.index[-1] >= end:
df = df[df.index < end]
return df, dict(tz=tz, freq=freq)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
</file>
<file path="data/custom/av.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `AVData`."""
import re
import urllib.parse
from functools import lru_cache
import numpy as np
import pandas as pd
import requests
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.module_ import check_installed
from vectorbtpro.utils.parsing import get_func_arg_names
from vectorbtpro.utils.warnings_ import warn
try:
if not tp.TYPE_CHECKING:
raise ImportError
from alpha_vantage.alphavantage import AlphaVantage as AlphaVantageT
except ImportError:
AlphaVantageT = "AlphaVantage"
__all__ = [
"AVData",
]
__pdoc__ = {}
AVDataT = tp.TypeVar("AVDataT", bound="AVData")
class AVData(RemoteData):
"""Data class for fetching from Alpha Vantage.
See https://www.alphavantage.co/documentation/ for API.
Apart from using the https://github.com/RomelTorres/alpha_vantage package, this class can also
parse the API documentation with `AVData.parse_api_meta` using `BeautifulSoup4` and build
the API query based on this metadata (pass `use_parser=True`).
This approach is the most flexible one, since it can instantly react to changes in the
Alpha Vantage API. If the data provider changes its API documentation, you can always adapt
the parsing procedure by overriding `AVData.parse_api_meta`.
If the parser still fails, you can disable parsing entirely and specify all information
manually by setting `function` and disabling `match_params`.
See `AVData.fetch_symbol` for arguments.
Usage:
* Set up the API key globally (optional):
```pycon
>>> from vectorbtpro import *
>>> vbt.AVData.set_custom_settings(
... apikey="YOUR_KEY"
... )
```
* Pull data:
```pycon
>>> data = vbt.AVData.pull(
... "GOOGL",
... timeframe="1 day",
... )
>>> data = vbt.AVData.pull(
... "BTC_USD",
... timeframe="30 minutes", # premium?
... category="digital-currency",
... outputsize="full"
... )
>>> data = vbt.AVData.pull(
... "REAL_GDP",
... category="economic-indicators"
... )
>>> data = vbt.AVData.pull(
... "IBM",
... category="technical-indicators",
... function="STOCHRSI",
... params=dict(fastkperiod=14)
... )
```
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.av")
@classmethod
def list_symbols(cls, keywords: str, apikey: tp.Optional[str] = None, sort: bool = True) -> tp.List[str]:
"""List all symbols."""
apikey = cls.resolve_custom_setting(apikey, "apikey")
query = dict()
query["function"] = "SYMBOL_SEARCH"
query["keywords"] = keywords
query["datatype"] = "csv"
query["apikey"] = apikey
url = "https://www.alphavantage.co/query?" + urllib.parse.urlencode(query)
df = pd.read_csv(url)
if sort:
return sorted(dict.fromkeys(df["symbol"].tolist()))
return list(dict.fromkeys(df["symbol"].tolist()))
@classmethod
@lru_cache()
def parse_api_meta(cls) -> dict:
"""Parse API metadata from the documentation at https://www.alphavantage.co/documentation
Cached class method. To avoid re-parsing the same metadata in different runtimes, save it manually."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("bs4")
from bs4 import BeautifulSoup
page = requests.get("https://www.alphavantage.co/documentation")
soup = BeautifulSoup(page.content, "html.parser")
api_meta = {}
for section in soup.select("article section"):
category = {}
function = None
function_args = dict(req_args=set(), opt_args=set())
for tag in section.find_all(True):
if tag.name == "h6":
if function is not None and tag.select("b")[0].getText().strip() == "API Parameters":
category[function] = function_args
function = None
function_args = dict(req_args=set(), opt_args=set())
if tag.name == "b":
b_text = tag.getText().strip()
if b_text.startswith("❚ Required"):
arg = tag.select("code")[0].getText().strip()
function_args["req_args"].add(arg)
if tag.name == "p":
p_text = tag.getText().strip()
if p_text.startswith("❚ Optional"):
arg = tag.select("code")[0].getText().strip()
function_args["opt_args"].add(arg)
if tag.name == "code":
code_text = tag.getText().strip()
if code_text.startswith("function="):
function = code_text.replace("function=", "")
if function is not None:
category[function] = function_args
api_meta[section.select("h2")[0]["id"]] = category
return api_meta
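# Example (illustrative sketch; assumes `pull` forwards keyword arguments to `fetch_symbol`
# as in the class examples above): the parsed metadata is a plain dict of dicts and sets,
# so it can be pickled once and passed back via `api_meta` in a new runtime.
# >>> import pickle
# >>> with open("av_api_meta.pickle", "wb") as f:
# ...     pickle.dump(AVData.parse_api_meta(), f)
# >>> with open("av_api_meta.pickle", "rb") as f:
# ...     data = AVData.pull("IBM", api_meta=pickle.load(f))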
@classmethod
def fetch_symbol(
cls,
symbol: str,
use_parser: tp.Optional[bool] = None,
apikey: tp.Optional[str] = None,
api_meta: tp.Optional[dict] = None,
category: tp.Union[None, str, AlphaVantageT, tp.Type[AlphaVantageT]] = None,
function: tp.Union[None, str, tp.Callable] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
adjusted: tp.Optional[bool] = None,
extended: tp.Optional[bool] = None,
slice: tp.Optional[str] = None,
series_type: tp.Optional[str] = None,
time_period: tp.Optional[int] = None,
outputsize: tp.Optional[str] = None,
match_params: tp.Optional[bool] = None,
params: tp.KwargsLike = None,
read_csv_kwargs: tp.KwargsLike = None,
silence_warnings: tp.Optional[bool] = None,
) -> tp.SymbolData:
"""Fetch a symbol from Alpha Vantage.
If `use_parser` is False, or None while `alpha_vantage` is installed and no `api_meta` is provided,
uses the `alpha_vantage` package. Otherwise, parses the API documentation and pulls the data directly.
See https://www.alphavantage.co/documentation/ for API endpoints and their parameters.
!!! note
Supports the CSV format only.
Args:
symbol (str): Symbol.
May combine symbol/from_currency and market/to_currency using an underscore.
use_parser (bool): Whether to use the parser instead of the `alpha_vantage` package.
apikey (str): API key.
api_meta (dict): API meta.
If None, will use `AVData.parse_api_meta` if `function` is not provided
or `match_params` is True.
category (str or AlphaVantage): API category of your choice.
Used if `function` is not provided or `match_params` is True.
Supported are:
* `alpha_vantage.alphavantage.AlphaVantage` instance, class, or class name
* "time-series-data" or "time-series"
* "fundamental-data" or "fundamentals"
* "foreign-exchange", "forex", or "fx"
* "digital-currency", "cryptocurrencies", "cryptocurrency", or "crypto"
* "commodities"
* "economic-indicators"
* "technical-indicators" or "indicators"
function (str or callable): API function of your choice.
If None, will try to resolve it based on other arguments, such as `timeframe`,
`adjusted`, and `extended`. Required for technical indicators, economic indicators,
and fundamental data.
See the keys in sub-dictionaries returned by `AVData.parse_api_meta`.
timeframe (str): Timeframe.
Allows human-readable strings such as "15 minutes".
For time series, forex, and crypto, looks for interval type in the function's name.
Defaults to "60min" if extended, otherwise to "daily".
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
adjusted (bool): Whether to return time series adjusted by historical split and dividend events.
extended (bool): Whether to return historical intraday time series for the trailing 2 years.
slice (str): Slice of the trailing 2 years.
series_type (str): The desired price type in the time series.
time_period (int): Number of data points used to calculate each window value.
outputsize (str): Output size.
Supported are
* "compact" that returns only the latest 100 data points
* "full" that returns the full-length time series
match_params (bool): Whether to match parameters with the ones required by the endpoint.
Otherwise, uses only (resolved) `function`, `apikey`, `datatype="csv"`, and `params`.
params (dict): Additional keyword arguments passed as key/value pairs in the URL.
read_csv_kwargs (dict): Keyword arguments passed to `pd.read_csv`.
silence_warnings (bool): Whether to silence all warnings.
For defaults, see `custom.av` in `vectorbtpro._settings.data`.
"""
use_parser = cls.resolve_custom_setting(use_parser, "use_parser")
apikey = cls.resolve_custom_setting(apikey, "apikey")
api_meta = cls.resolve_custom_setting(api_meta, "api_meta")
category = cls.resolve_custom_setting(category, "category")
function = cls.resolve_custom_setting(function, "function")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
adjusted = cls.resolve_custom_setting(adjusted, "adjusted")
extended = cls.resolve_custom_setting(extended, "extended")
slice = cls.resolve_custom_setting(slice, "slice")
series_type = cls.resolve_custom_setting(series_type, "series_type")
time_period = cls.resolve_custom_setting(time_period, "time_period")
outputsize = cls.resolve_custom_setting(outputsize, "outputsize")
read_csv_kwargs = cls.resolve_custom_setting(read_csv_kwargs, "read_csv_kwargs", merge=True)
match_params = cls.resolve_custom_setting(match_params, "match_params")
params = cls.resolve_custom_setting(params, "params", merge=True)
silence_warnings = cls.resolve_custom_setting(silence_warnings, "silence_warnings")
if use_parser is None:
if api_meta is None and check_installed("alpha_vantage"):
use_parser = False
else:
use_parser = True
if not use_parser:
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("alpha_vantage")
if use_parser and api_meta is None and (function is None or match_params):
if not silence_warnings and cls.parse_api_meta.cache_info().misses == 0:
warn("Parsing API documentation...")
try:
api_meta = cls.parse_api_meta()
except Exception as e:
raise ValueError("Can't fetch/parse the API documentation. Specify function and disable match_params.") from e
freq = timeframe
interval = None
interval_type = None
if timeframe is not None:
if not isinstance(timeframe, str):
raise ValueError(f"Invalid timeframe: '{timeframe}'")
split = dt.split_freq_str(timeframe)
if split is None:
raise ValueError(f"Invalid timeframe: '{timeframe}'")
multiplier, unit = split
if unit == "m":
interval = str(multiplier) + "min"
interval_type = "intraday"
elif unit == "h":
interval = str(60 * multiplier) + "min"
interval_type = "intraday"
elif unit == "D":
interval = "daily"
interval_type = "daily"
elif unit == "W":
interval = "weekly"
interval_type = "weekly"
elif unit == "M":
interval = "monthly"
interval_type = "monthly"
elif unit == "Q":
interval = "quarterly"
interval_type = "quarterly"
elif unit == "Y":
interval = "annual"
interval_type = "annual"
if interval is None and multiplier > 1:
raise ValueError("Multipliers are supported only for intraday timeframes")
else:
if extended:
interval_type = "intraday"
interval = "60min"
else:
interval_type = "daily"
interval = "daily"
if category is not None:
if isinstance(category, str):
if category.lower() in ("time-series-data", "time-series", "timeseries"):
if use_parser:
category = "time-series-data"
else:
from alpha_vantage.timeseries import TimeSeries
category = TimeSeries
elif category.lower() in ("fundamentals", "fundamental-data", "fundamentaldata"):
if use_parser:
category = "fundamentals"
else:
from alpha_vantage.fundamentaldata import FundamentalData
category = FundamentalData
elif category.lower() in ("fx", "forex", "foreign-exchange", "foreignexchange"):
if use_parser:
category = "fx"
else:
from alpha_vantage.foreignexchange import ForeignExchange
category = ForeignExchange
elif category.lower() in ("digital-currency", "cryptocurrencies", "cryptocurrency", "crypto"):
if use_parser:
category = "digital-currency"
else:
from alpha_vantage.cryptocurrencies import CryptoCurrencies
category = CryptoCurrencies
elif category.lower() in ("commodities",):
if use_parser:
category = "commodities"
else:
raise NotImplementedError(f"Category '{category}' not supported by alpha_vantage. Use parser.")
elif category.lower() in ("economic-indicators",):
if use_parser:
category = "economic-indicators"
else:
raise NotImplementedError(f"Category '{category}' not supported by alpha_vantage. Use parser.")
elif category.lower() in ("technical-indicators", "techindicators", "indicators"):
if use_parser:
category = "technical-indicators"
else:
from alpha_vantage.techindicators import TechIndicators
category = TechIndicators
else:
raise ValueError(f"Invalid category: '{category}'")
else:
if use_parser:
raise TypeError("Category must be a string")
else:
from alpha_vantage.alphavantage import AlphaVantage
if isinstance(category, type):
if not issubclass(category, AlphaVantage):
raise TypeError("Category must be a subclass of AlphaVantage")
elif not isinstance(category, AlphaVantage):
raise TypeError("Category must be an instance of AlphaVantage")
if use_parser:
if function is None:
if category is not None:
if category in ("commodities", "economic-indicators"):
function = symbol
if function is None:
if category is None:
category = "time-series-data"
if category in ("fundamentals", "technical-indicators"):
raise ValueError("Function is required")
adjusted_in_functions = False
extended_in_functions = False
matched_functions = []
for k in api_meta[category]:
if interval_type is None or interval_type.upper() in k:
if "ADJUSTED" in k:
adjusted_in_functions = True
if "EXTENDED" in k:
extended_in_functions = True
matched_functions.append(k)
if adjusted_in_functions:
matched_functions = [
k
for k in matched_functions
if (adjusted and "ADJUSTED" in k) or (not adjusted and "ADJUSTED" not in k)
]
if extended_in_functions:
matched_functions = [
k
for k in matched_functions
if (extended and "EXTENDED" in k) or (not extended and "EXTENDED" not in k)
]
if len(matched_functions) == 0:
raise ValueError("No functions satisfy the requirements")
if len(matched_functions) > 1:
raise ValueError("More than one function satisfies the requirements")
function = matched_functions[0]
if match_params:
if function is not None and category is None:
category = None
for k, v in api_meta.items():
if function in v:
category = k
break
if category is None:
raise ValueError("Category is required")
req_args = api_meta[category][function]["req_args"]
opt_args = api_meta[category][function]["opt_args"]
args = set(req_args) | set(opt_args)
matched_params = dict()
matched_params["function"] = function
matched_params["datatype"] = "csv"
matched_params["apikey"] = apikey
if "symbol" in args and "market" in args:
matched_params["symbol"] = symbol.split("_")[0]
matched_params["market"] = symbol.split("_")[1]
elif "from_" in args and "to_currency" in args:
matched_params["from_currency"] = symbol.split("_")[0]
matched_params["to_currency"] = symbol.split("_")[1]
elif "from_currency" in args and "to_currency" in args:
matched_params["from_currency"] = symbol.split("_")[0]
matched_params["to_currency"] = symbol.split("_")[1]
elif "symbol" in args:
matched_params["symbol"] = symbol
if "interval" in args:
matched_params["interval"] = interval
if "adjusted" in args:
matched_params["adjusted"] = adjusted
if "extended" in args:
matched_params["extended"] = extended
if "extended_hours" in args:
matched_params["extended_hours"] = extended
if "slice" in args:
matched_params["slice"] = slice
if "series_type" in args:
matched_params["series_type"] = series_type
if "time_period" in args:
matched_params["time_period"] = time_period
if "outputsize" in args:
matched_params["outputsize"] = outputsize
for k, v in params.items():
if k in args:
matched_params[k] = v
else:
raise ValueError(f"Function '{function}' does not expect parameter '{k}'")
for arg in req_args:
if arg not in matched_params:
raise ValueError(f"Function '{function}' requires parameter '{arg}'")
else:
matched_params = dict(params)
matched_params["function"] = function
matched_params["apikey"] = apikey
matched_params["datatype"] = "csv"
url = "https://www.alphavantage.co/query?" + urllib.parse.urlencode(matched_params)
df = pd.read_csv(url, **read_csv_kwargs)
else:
from alpha_vantage.alphavantage import AlphaVantage
from alpha_vantage.timeseries import TimeSeries
from alpha_vantage.fundamentaldata import FundamentalData
from alpha_vantage.foreignexchange import ForeignExchange
from alpha_vantage.cryptocurrencies import CryptoCurrencies
from alpha_vantage.techindicators import TechIndicators
if isinstance(category, type) and issubclass(category, AlphaVantage):
category = category(key=apikey, output_format="pandas")
if function is None:
if category is None:
category = TimeSeries(key=apikey, output_format="pandas")
if isinstance(category, (TechIndicators, FundamentalData)):
raise ValueError("Function is required")
adjusted_in_methods = False
extended_in_methods = False
matched_methods = []
for k in dir(category):
if interval_type is None or interval_type in k:
if "adjusted" in k:
adjusted_in_methods = True
if "extended" in k:
extended_in_methods = True
matched_methods.append(k)
if adjusted_in_methods:
matched_methods = [
k
for k in matched_methods
if (adjusted and "adjusted" in k) or (not adjusted and "adjusted" not in k)
]
if extended_in_methods:
matched_methods = [
k
for k in matched_methods
if (extended and "extended" in k) or (not extended and "extended" not in k)
]
if len(matched_methods) == 0:
raise ValueError("No methods satisfy the requirements")
if len(matched_methods) > 1:
raise ValueError("More than one method satisfies the requirements")
function = matched_methods[0]
if isinstance(function, str):
function = function.lower()
if not function.startswith("get_"):
function = "get_" + function
if category is not None:
function = getattr(category, function)
else:
categories = [
TimeSeries,
FundamentalData,
ForeignExchange,
CryptoCurrencies,
TechIndicators,
]
matched_methods = []
for category in categories:
if function in dir(category):
matched_methods.append(getattr(category, function))
if len(matched_methods) == 0:
raise ValueError("No methods satisfy the requirements")
if len(matched_methods) > 1:
raise ValueError("More than one method satisfies the requirements")
function = matched_methods[0]
if match_params:
args = set(get_func_arg_names(function))
matched_params = dict()
if "symbol" in args and "market" in args:
matched_params["symbol"] = symbol.split("_")[0]
matched_params["market"] = symbol.split("_")[1]
elif "from_" in args and "to_currency" in args:
matched_params["from_currency"] = symbol.split("_")[0]
matched_params["to_currency"] = symbol.split("_")[1]
elif "from_currency" in args and "to_currency" in args:
matched_params["from_currency"] = symbol.split("_")[0]
matched_params["to_currency"] = symbol.split("_")[1]
elif "symbol" in args:
matched_params["symbol"] = symbol
if "interval" in args:
matched_params["interval"] = interval
if "adjusted" in args:
matched_params["adjusted"] = adjusted
if "extended" in args:
matched_params["extended"] = extended
if "extended_hours" in args:
matched_params["extended_hours"] = extended
if "slice" in args:
matched_params["slice"] = slice
if "series_type" in args:
matched_params["series_type"] = series_type
if "time_period" in args:
matched_params["time_period"] = time_period
if "outputsize" in args:
matched_params["outputsize"] = outputsize
else:
matched_params = dict(params)
df, df_metadata = function(**matched_params)
for k, v in df_metadata.items():
if "Time Zone" in k:
if tz is None:
if v.endswith(" Time"):
v = v[: -len(" Time")]
tz = v
df.index.name = None
new_columns = []
for c in df.columns:
new_c = re.sub(r"^\d+\w*\.\s*", "", c)
new_c = new_c[0].title() + new_c[1:]
if new_c.endswith(" (USD)"):
new_c = new_c[: -len(" (USD)")]
new_columns.append(new_c)
df = df.rename(columns=dict(zip(df.columns, new_columns)))
df = df.loc[:, ~df.columns.duplicated()]
for c in df.columns:
if df[c].dtype == "O":
df[c] = df[c].replace({".": np.nan})
df = df.apply(pd.to_numeric, errors="ignore")
if len(df.index) > 1 and df.index[0] > df.index[1]:  # flip a descending index to ascending
df = df.iloc[::-1]
if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None and tz is not None:
df = df.tz_localize(tz)
return df, dict(tz=tz, freq=freq)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
</file>
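A minimal usage sketch for the Alpha Vantage class defined above, assuming it is exported at the package root as `vbt.AVData` like the other custom data classes and that `apikey` is the relevant custom setting; the symbol and parameter values are illustrative:

```pycon
>>> from vectorbtpro import *

>>> vbt.AVData.set_custom_settings(apikey="YOUR_KEY")  # assumed setting name

>>> # "15 minutes" is mapped to interval="15min" and interval_type="intraday"
>>> # by the timeframe-parsing logic in fetch_symbol
>>> data = vbt.AVData.pull(
...     "GOOGL",
...     timeframe="15 minutes",
...     outputsize="full"
... )
```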
<file path="data/custom/bento.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `BentoData`."""
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.parsing import get_func_arg_names
try:
if not tp.TYPE_CHECKING:
raise ImportError
from databento import Historical as HistoricalT
except ImportError:
HistoricalT = "Historical"
__all__ = [
"BentoData",
]
class BentoData(RemoteData):
"""Data class for fetching from Databento.
See https://github.com/databento/databento-python for API.
See `BentoData.fetch_symbol` for arguments.
Usage:
* Set up the API key globally (optional):
```pycon
>>> from vectorbtpro import *
>>> vbt.BentoData.set_custom_settings(
... client_config=dict(
... key="YOUR_KEY"
... )
... )
```
* Pull data:
```pycon
>>> data = vbt.BentoData.pull(
... "AAPL",
... dataset="XNAS.ITCH"
... )
```
```pycon
>>> data = vbt.BentoData.pull(
... "AAPL",
... dataset="XNAS.ITCH",
... timeframe="hourly",
... start="one week ago"
... )
```
```pycon
>>> data = vbt.BentoData.pull(
... "ES.FUT",
... dataset="GLBX.MDP3",
... stype_in="parent",
... schema="mbo",
... start="2022-06-10T14:30",
... end="2022-06-11",
... limit=1000
... )
```
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.bento")
@classmethod
def resolve_client(cls, client: tp.Optional[HistoricalT] = None, **client_config) -> HistoricalT:
"""Resolve the client.
If provided, must be of the type `databento.historical.client.Historical`.
Otherwise, will be created using `client_config`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("databento")
from databento import Historical
client = cls.resolve_custom_setting(client, "client")
if client_config is None:
client_config = {}
has_client_config = len(client_config) > 0
client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True)
if client is None:
client = Historical(**client_config)
elif has_client_config:
raise ValueError("Cannot apply client_config to already initialized client")
return client
@classmethod
def get_cost(cls, symbols: tp.MaybeSymbols, **kwargs) -> float:
"""Get the cost of calling `BentoData.fetch_symbol` on one or more symbols."""
if isinstance(symbols, str):
symbols = [symbols]
costs = []
for symbol in symbols:
client, params = cls.fetch_symbol(symbol, **kwargs, return_params=True)
cost_arg_names = get_func_arg_names(client.metadata.get_cost)
for k in list(params.keys()):
if k not in cost_arg_names:
del params[k]
costs.append(client.metadata.get_cost(**params, mode="historical"))
return sum(costs)
@classmethod
def fetch_symbol(
cls,
symbol: str,
client: tp.Optional[HistoricalT] = None,
client_config: tp.KwargsLike = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
resolve_dates: tp.Optional[bool] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
dataset: tp.Optional[str] = None,
schema: tp.Optional[str] = None,
return_params: bool = False,
df_kwargs: tp.KwargsLike = None,
**params,
) -> tp.Union[float, tp.SymbolData]:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Databento.
Args:
symbol (str): Symbol.
Symbol can be in the `DATASET:SYMBOL` format if `dataset` is None.
client (databento.historical.client.Historical): Client.
See `BentoData.resolve_client`.
client_config (dict): Client config.
See `BentoData.resolve_client`.
start (any): Start datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
end (any): End datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
resolve_dates (bool): Whether to resolve `start` and `end`, or pass them as they are.
timeframe (str): Timeframe to create `schema` from.
Allows human-readable strings such as "1 minute".
If `timeframe` and `schema` are both not None, will raise an error.
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
dataset (str): See `databento.historical.client.Historical.get_range`.
schema (str): See `databento.historical.client.Historical.get_range`.
return_params (bool): Whether to return the client and (final) parameters instead of data.
Used by `BentoData.get_cost`.
df_kwargs (dict): Keyword arguments passed to `databento.common.dbnstore.DBNStore.to_df`.
**params: Keyword arguments passed to `databento.historical.client.Historical.get_range`.
For defaults, see `custom.bento` in `vectorbtpro._settings.data`.
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("databento")
if client_config is None:
client_config = {}
client = cls.resolve_client(client=client, **client_config)
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
resolve_dates = cls.resolve_custom_setting(resolve_dates, "resolve_dates")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
dataset = cls.resolve_custom_setting(dataset, "dataset")
schema = cls.resolve_custom_setting(schema, "schema")
params = cls.resolve_custom_setting(params, "params", merge=True)
df_kwargs = cls.resolve_custom_setting(df_kwargs, "df_kwargs", merge=True)
if dataset is None:
if ":" in symbol:
dataset, symbol = symbol.split(":")
if timeframe is None and schema is None:
schema = "ohlcv-1d"
freq = "1d"
elif timeframe is not None:
freq = timeframe
split = dt.split_freq_str(timeframe)
if split is not None:
multiplier, unit = split
timeframe = str(multiplier) + unit
if schema is None or schema.lower() == "ohlcv":
schema = f"ohlcv-{timeframe}"
else:
raise ValueError("Timeframe cannot be used together with schema")
else:
if schema.startswith("ohlcv-"):
freq = schema[len("ohlcv-") :]
else:
freq = None
if resolve_dates:
dataset_range = client.metadata.get_dataset_range(dataset)
if "start_date" in dataset_range:
start_date = dt.to_tzaware_timestamp(dataset_range["start_date"], naive_tz="utc", tz="utc")
else:
start_date = dt.to_tzaware_timestamp(dataset_range["start"], naive_tz="utc", tz="utc")
if "end_date" in dataset_range:
end_date = dt.to_tzaware_timestamp(dataset_range["end_date"], naive_tz="utc", tz="utc")
else:
end_date = dt.to_tzaware_timestamp(dataset_range["end"], naive_tz="utc", tz="utc")
if start is not None:
start = dt.to_tzaware_timestamp(start, naive_tz=tz, tz="utc")
if start < start_date:
start = start_date
else:
start = start_date
if end is not None:
end = dt.to_tzaware_timestamp(end, naive_tz=tz, tz="utc")
if end > end_date:
end = end_date
else:
end = end_date
if start.floor("d") == start:
start = start.date().isoformat()
else:
start = start.isoformat()
if end.floor("d") == end:
end = end.date().isoformat()
else:
end = end.isoformat()
params = merge_dicts(
dict(
dataset=dataset,
start=start,
end=end,
symbols=symbol,
schema=schema,
),
params,
)
if return_params:
return client, params
df = client.timeseries.get_range(**params).to_df(**df_kwargs)
return df, dict(tz=tz, freq=freq)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
</file>
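In addition to the pull examples in the docstring above, `BentoData.get_cost` accepts the same arguments as `BentoData.fetch_symbol` and returns the estimated historical cost before any data is downloaded. A small sketch with illustrative values, assuming an API key has been configured:

```pycon
>>> vbt.BentoData.get_cost(
...     "AAPL",
...     dataset="XNAS.ITCH",
...     timeframe="1 minute",
...     start="2024-01-01",
...     end="2024-02-01"
... )
```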
<file path="data/custom/binance.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `BinanceData`."""
import time
import traceback
from functools import partial
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts, Config, HybridConfig
from vectorbtpro.utils.enum_ import map_enum_fields
from vectorbtpro.utils.pbar import ProgressBar
from vectorbtpro.utils.warnings_ import warn
try:
if not tp.TYPE_CHECKING:
raise ImportError
from binance.client import Client as BinanceClientT
except ImportError:
BinanceClientT = "BinanceClient"
__all__ = [
"BinanceData",
]
__pdoc__ = {}
BinanceDataT = tp.TypeVar("BinanceDataT", bound="BinanceData")
class BinanceData(RemoteData):
"""Data class for fetching from Binance.
See https://github.com/sammchardy/python-binance for API.
See `BinanceData.fetch_symbol` for arguments.
!!! note
If you are using an exchange from the US, Japan, or another TLD, make sure to pass
`tld="us"` in `client_config` when creating the client.
Usage:
* Set up the API key globally (optional):
```pycon
>>> from vectorbtpro import *
>>> vbt.BinanceData.set_custom_settings(
... client_config=dict(
... api_key="YOUR_KEY",
... api_secret="YOUR_SECRET"
... )
... )
```
* Pull data:
```pycon
>>> data = vbt.BinanceData.pull(
... "BTCUSDT",
... start="2020-01-01",
... end="2021-01-01",
... timeframe="1 day"
... )
```
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.binance")
_feature_config: tp.ClassVar[Config] = HybridConfig(
{
"Quote volume": dict(
resample_func=lambda self, obj, resampler: obj.vbt.resample_apply(
resampler,
generic_nb.sum_reduce_nb,
)
),
"Taker base volume": dict(
resample_func=lambda self, obj, resampler: obj.vbt.resample_apply(
resampler,
generic_nb.sum_reduce_nb,
)
),
"Taker quote volume": dict(
resample_func=lambda self, obj, resampler: obj.vbt.resample_apply(
resampler,
generic_nb.sum_reduce_nb,
)
),
}
)
@property
def feature_config(self) -> Config:
return self._feature_config
@classmethod
def resolve_client(cls, client: tp.Optional[BinanceClientT] = None, **client_config) -> BinanceClientT:
"""Resolve the client.
If provided, must be of the type `binance.client.Client`.
Otherwise, will be created using `client_config`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("binance")
from binance.client import Client
client = cls.resolve_custom_setting(client, "client")
if client_config is None:
client_config = {}
has_client_config = len(client_config) > 0
client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True)
if client is None:
client = Client(**client_config)
elif has_client_config:
raise ValueError("Cannot apply client_config to already initialized client")
return client
@classmethod
def list_symbols(
cls,
pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
client: tp.Optional[BinanceClientT] = None,
client_config: tp.KwargsLike = None,
) -> tp.List[str]:
"""List all symbols.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`."""
if client_config is None:
client_config = {}
client = cls.resolve_client(client=client, **client_config)
all_symbols = []
for dct in client.get_exchange_info()["symbols"]:
symbol = dct["symbol"]
if pattern is not None:
if not cls.key_match(symbol, pattern, use_regex=use_regex):
continue
all_symbols.append(symbol)
if sort:
return sorted(dict.fromkeys(all_symbols))
return list(dict.fromkeys(all_symbols))
@classmethod
def fetch_symbol(
cls,
symbol: str,
client: tp.Optional[BinanceClientT] = None,
client_config: tp.KwargsLike = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
klines_type: tp.Union[None, int, str] = None,
limit: tp.Optional[int] = None,
delay: tp.Optional[float] = None,
show_progress: tp.Optional[bool] = None,
pbar_kwargs: tp.KwargsLike = None,
silence_warnings: tp.Optional[bool] = None,
**get_klines_kwargs,
) -> tp.SymbolData:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Binance.
Args:
symbol (str): Symbol.
client (binance.client.Client): Client.
See `BinanceData.resolve_client`.
client_config (dict): Client config.
See `BinanceData.resolve_client`.
start (any): Start datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
end (any): End datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
timeframe (str): Timeframe.
Allows human-readable strings such as "15 minutes".
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
klines_type (int or str): Kline type.
See `binance.enums.HistoricalKlinesType`. Supports strings.
limit (int): The maximum number of returned items.
delay (float): Time to sleep after each request (in seconds).
show_progress (bool): Whether to show the progress bar.
pbar_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.pbar.ProgressBar`.
silence_warnings (bool): Whether to silence all warnings.
**get_klines_kwargs: Keyword arguments passed to `binance.client.Client.get_klines`.
For defaults, see `custom.binance` in `vectorbtpro._settings.data`.
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("binance")
from binance.enums import HistoricalKlinesType
if client_config is None:
client_config = {}
client = cls.resolve_client(client=client, **client_config)
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
klines_type = cls.resolve_custom_setting(klines_type, "klines_type")
if isinstance(klines_type, str):
klines_type = map_enum_fields(klines_type, HistoricalKlinesType)
if isinstance(klines_type, int):
klines_type = {i.value: i for i in HistoricalKlinesType}[klines_type]
limit = cls.resolve_custom_setting(limit, "limit")
delay = cls.resolve_custom_setting(delay, "delay")
show_progress = cls.resolve_custom_setting(show_progress, "show_progress")
pbar_kwargs = cls.resolve_custom_setting(pbar_kwargs, "pbar_kwargs", merge=True)
if "bar_id" not in pbar_kwargs:
pbar_kwargs["bar_id"] = "binance"
silence_warnings = cls.resolve_custom_setting(silence_warnings, "silence_warnings")
get_klines_kwargs = cls.resolve_custom_setting(get_klines_kwargs, "get_klines_kwargs", merge=True)
# Prepare parameters
freq = timeframe
split = dt.split_freq_str(timeframe)
if split is not None:
multiplier, unit = split
if unit == "D":
unit = "d"
elif unit == "W":
unit = "w"
timeframe = str(multiplier) + unit
if start is not None:
start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc"))
first_valid_ts = client._get_earliest_valid_timestamp(symbol, timeframe, klines_type)
start_ts = max(start_ts, first_valid_ts)
else:
start_ts = None
prev_end_ts = None
if end is not None:
end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc"))
else:
end_ts = None
def _ts_to_str(ts: tp.Optional[int]) -> str:
if ts is None:
return "?"
return dt.readable_datetime(pd.Timestamp(ts, unit="ms", tz="utc"), freq=timeframe)
def _filter_func(d: tp.Sequence, _prev_end_ts: tp.Optional[int] = None) -> bool:
if start_ts is not None:
if d[0] < start_ts:
return False
if _prev_end_ts is not None:
if d[0] <= _prev_end_ts:
return False
if end_ts is not None:
if d[0] >= end_ts:
return False
return True
# Iteratively collect the data
data = []
try:
with ProgressBar(show_progress=show_progress, **pbar_kwargs) as pbar:
pbar.set_description("{} → ?".format(_ts_to_str(start_ts if prev_end_ts is None else prev_end_ts)))
while True:
# Fetch the klines for the next timeframe
next_data = client._klines(
symbol=symbol,
interval=timeframe,
limit=limit,
startTime=start_ts if prev_end_ts is None else prev_end_ts,
endTime=end_ts,
klines_type=klines_type,
**get_klines_kwargs,
)
next_data = list(filter(partial(_filter_func, _prev_end_ts=prev_end_ts), next_data))
# Update the timestamps and the progress bar
if not len(next_data):
break
data += next_data
if start_ts is None:
start_ts = next_data[0][0]
pbar.set_description("{} → {}".format(_ts_to_str(start_ts), _ts_to_str(next_data[-1][0])))
pbar.update()
prev_end_ts = next_data[-1][0]
if end_ts is not None and prev_end_ts >= end_ts:
break
if delay is not None:
time.sleep(delay) # be kind to api
except Exception as e:
if not silence_warnings:
warn(traceback.format_exc())
warn(
f"Symbol '{str(symbol)}' raised an exception. Returning incomplete data. "
"Use update() method to fetch missing data."
)
# Convert data to a DataFrame
df = pd.DataFrame(
data,
columns=[
"Open time",
"Open",
"High",
"Low",
"Close",
"Volume",
"Close time",
"Quote volume",
"Trade count",
"Taker base volume",
"Taker quote volume",
"Ignore",
],
)
df.index = pd.to_datetime(df["Open time"], unit="ms", utc=True)
df["Open"] = df["Open"].astype(float)
df["High"] = df["High"].astype(float)
df["Low"] = df["Low"].astype(float)
df["Close"] = df["Close"].astype(float)
df["Volume"] = df["Volume"].astype(float)
df["Quote volume"] = df["Quote volume"].astype(float)
df["Trade count"] = df["Trade count"].astype(int, errors="ignore")
df["Taker base volume"] = df["Taker base volume"].astype(float)
df["Taker quote volume"] = df["Taker quote volume"].astype(float)
del df["Open time"]
del df["Close time"]
del df["Ignore"]
return df, dict(tz=tz, freq=freq)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
BinanceData.override_feature_config_doc(__pdoc__)
</file>
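A short sketch of `BinanceData.list_symbols`, which filters the exchange's symbol list through `CustomData.key_match` using glob-style patterns by default (pass `use_regex=True` for regular expressions); the patterns are illustrative:

```pycon
>>> vbt.BinanceData.list_symbols("BTC*")
>>> vbt.BinanceData.list_symbols(r"^BTC\w+$", use_regex=True)
```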
<file path="data/custom/ccxt.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `CCXTData`."""
import time
import traceback
from functools import wraps, partial
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.pbar import ProgressBar
from vectorbtpro.utils.warnings_ import warn
try:
if not tp.TYPE_CHECKING:
raise ImportError
from ccxt.base.exchange import Exchange as CCXTExchangeT
except ImportError:
CCXTExchangeT = "CCXTExchange"
__all__ = [
"CCXTData",
]
__pdoc__ = {}
class CCXTData(RemoteData):
"""Data class for fetching using CCXT.
See https://github.com/ccxt/ccxt for API.
See `CCXTData.fetch_symbol` for arguments.
Usage:
* Set up the API key globally (optional):
```pycon
>>> from vectorbtpro import *
>>> vbt.CCXTData.set_exchange_settings(
... exchange_name="binance",
... populate_=True,
... exchange_config=dict(
... apiKey="YOUR_KEY",
... secret="YOUR_SECRET"
... )
... )
```
* Pull data:
```pycon
>>> data = vbt.CCXTData.pull(
... "BTCUSDT",
... exchange="binance",
... start="2020-01-01",
... end="2021-01-01",
... timeframe="1 day"
... )
```
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.ccxt")
@classmethod
def get_exchange_settings(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> dict:
"""`CCXTData.get_custom_settings` with `sub_path=exchange_name`."""
if exchange_name is not None:
sub_path = "exchanges." + exchange_name
else:
sub_path = None
return cls.get_custom_settings(*args, sub_path=sub_path, **kwargs)
@classmethod
def has_exchange_settings(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> bool:
"""`CCXTData.has_custom_settings` with `sub_path=exchange_name`."""
if exchange_name is not None:
sub_path = "exchanges." + exchange_name
else:
sub_path = None
return cls.has_custom_settings(*args, sub_path=sub_path, **kwargs)
@classmethod
def get_exchange_setting(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> tp.Any:
"""`CCXTData.get_custom_setting` with `sub_path=exchange_name`."""
if exchange_name is not None:
sub_path = "exchanges." + exchange_name
else:
sub_path = None
return cls.get_custom_setting(*args, sub_path=sub_path, **kwargs)
@classmethod
def has_exchange_setting(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> bool:
"""`CCXTData.has_custom_setting` with `sub_path=exchange_name`."""
if exchange_name is not None:
sub_path = "exchanges." + exchange_name
else:
sub_path = None
return cls.has_custom_setting(*args, sub_path=sub_path, **kwargs)
@classmethod
def resolve_exchange_setting(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> tp.Any:
"""`CCXTData.resolve_custom_setting` with `sub_path=exchange_name`."""
if exchange_name is not None:
sub_path = "exchanges." + exchange_name
else:
sub_path = None
return cls.resolve_custom_setting(*args, sub_path=sub_path, **kwargs)
@classmethod
def set_exchange_settings(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> None:
"""`CCXTData.set_custom_settings` with `sub_path=exchange_name`."""
if exchange_name is not None:
sub_path = "exchanges." + exchange_name
else:
sub_path = None
cls.set_custom_settings(*args, sub_path=sub_path, **kwargs)
@classmethod
def list_symbols(
cls,
pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
exchange: tp.Union[None, str, CCXTExchangeT] = None,
exchange_config: tp.Optional[tp.KwargsLike] = None,
) -> tp.List[str]:
"""List all symbols.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`."""
if exchange_config is None:
exchange_config = {}
exchange = cls.resolve_exchange(exchange=exchange, **exchange_config)
all_symbols = []
for symbol in exchange.load_markets():
if pattern is not None:
if not cls.key_match(symbol, pattern, use_regex=use_regex):
continue
all_symbols.append(symbol)
if sort:
return sorted(dict.fromkeys(all_symbols))
return list(dict.fromkeys(all_symbols))
@classmethod
def resolve_exchange(
cls,
exchange: tp.Union[None, str, CCXTExchangeT] = None,
**exchange_config,
) -> CCXTExchangeT:
"""Resolve the exchange.
If provided, must be of the type `ccxt.base.exchange.Exchange`.
Otherwise, will be created using `exchange_config`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("ccxt")
import ccxt
exchange = cls.resolve_exchange_setting(exchange, "exchange")
if exchange is None:
exchange = "binance"
if isinstance(exchange, str):
exchange = exchange.lower()
exchange_name = exchange
elif isinstance(exchange, ccxt.Exchange):
exchange_name = type(exchange).__name__
else:
raise ValueError(f"Unknown exchange of type {type(exchange)}")
if exchange_config is None:
exchange_config = {}
has_exchange_config = len(exchange_config) > 0
exchange_config = cls.resolve_exchange_setting(
exchange_config, "exchange_config", merge=True, exchange_name=exchange_name
)
if isinstance(exchange, str):
if not hasattr(ccxt, exchange):
raise ValueError(f"Exchange '{exchange}' not found in CCXT")
exchange = getattr(ccxt, exchange)(exchange_config)
else:
if has_exchange_config:
raise ValueError("Cannot apply config after instantiation of the exchange")
return exchange
@staticmethod
def _find_earliest_date(
fetch_func: tp.Callable,
start: tp.DatetimeLike = 0,
end: tp.DatetimeLike = "now",
tz: tp.TimezoneLike = None,
for_internal_use: bool = False,
) -> tp.Optional[pd.Timestamp]:
"""Find the earliest date using binary search."""
if start is not None:
start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc"))
fetched_data = fetch_func(start_ts, 1)
if for_internal_use and len(fetched_data) > 0:
return pd.Timestamp(start_ts, unit="ms", tz="utc")
else:
fetched_data = []
if len(fetched_data) == 0 and start != 0:
fetched_data = fetch_func(0, 1)
if for_internal_use and len(fetched_data) > 0:
return pd.Timestamp(0, unit="ms", tz="utc")
if len(fetched_data) == 0:
if start is not None:
start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc"))
else:
start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(0, naive_tz=tz, tz="utc"))
start_ts = start_ts - start_ts % 86400000
if end is not None:
end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc"))
else:
end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime("now", naive_tz=tz, tz="utc"))
end_ts = end_ts - end_ts % 86400000 + 86400000
start_time = start_ts
end_time = end_ts
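# Binary search over whole days (86400000 ms): probe the midpoint with a
# one-bar fetch; if it returns nothing, move start_time up, otherwise move
# end_time down, until the day interval can no longer be halved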
while True:
mid_time = (start_time + end_time) // 2
mid_time = mid_time - mid_time % 86400000
if mid_time == start_time:
break
_fetched_data = fetch_func(mid_time, 1)
if len(_fetched_data) == 0:
start_time = mid_time
else:
end_time = mid_time
fetched_data = _fetched_data
if len(fetched_data) > 0:
return pd.Timestamp(fetched_data[0][0], unit="ms", tz="utc")
return None
@classmethod
def find_earliest_date(cls, symbol: str, for_internal_use: bool = False, **kwargs) -> tp.Optional[pd.Timestamp]:
"""Find the earliest date using binary search.
See `CCXTData.fetch_symbol` for arguments."""
return cls._find_earliest_date(
**cls.fetch_symbol(symbol, return_fetch_method=True, **kwargs),
for_internal_use=for_internal_use,
)
@classmethod
def fetch_symbol(
cls,
symbol: str,
exchange: tp.Union[None, str, CCXTExchangeT] = None,
exchange_config: tp.Optional[tp.KwargsLike] = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
find_earliest_date: tp.Optional[bool] = None,
limit: tp.Optional[int] = None,
delay: tp.Optional[float] = None,
retries: tp.Optional[int] = None,
fetch_params: tp.Optional[tp.KwargsLike] = None,
show_progress: tp.Optional[bool] = None,
pbar_kwargs: tp.KwargsLike = None,
silence_warnings: tp.Optional[bool] = None,
return_fetch_method: bool = False,
) -> tp.Union[dict, tp.SymbolData]:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from CCXT.
Args:
symbol (str): Symbol.
Symbol can be in the `EXCHANGE:SYMBOL` format if `exchange` is None.
exchange (str or object): Exchange identifier or an exchange object.
See `CCXTData.resolve_exchange`.
exchange_config (dict): Exchange config.
See `CCXTData.resolve_exchange`.
start (any): Start datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
end (any): End datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
timeframe (str): Timeframe.
Allows human-readable strings such as "15 minutes".
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
find_earliest_date (bool): Whether to find the earliest date using `CCXTData.find_earliest_date`.
limit (int): The maximum number of returned items.
delay (float): Time to sleep after each request (in seconds).
!!! note
Use only if `enableRateLimit` is not set.
retries (int): The number of retries on failure to fetch data.
fetch_params (dict): Exchange-specific keyword arguments passed to `fetch_ohlcv`.
show_progress (bool): Whether to show the progress bar.
pbar_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.pbar.ProgressBar`.
silence_warnings (bool): Whether to silence all warnings.
return_fetch_method (bool): Required by `CCXTData.find_earliest_date`.
For defaults, see `custom.ccxt` in `vectorbtpro._settings.data`.
Global settings can be provided per exchange id using the `exchanges` dictionary.
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("ccxt")
import ccxt
exchange = cls.resolve_custom_setting(exchange, "exchange")
if exchange is None and ":" in symbol:
exchange, symbol = symbol.split(":")
if exchange_config is None:
exchange_config = {}
exchange = cls.resolve_exchange(exchange=exchange, **exchange_config)
exchange_name = type(exchange).__name__
start = cls.resolve_exchange_setting(start, "start", exchange_name=exchange_name)
end = cls.resolve_exchange_setting(end, "end", exchange_name=exchange_name)
timeframe = cls.resolve_exchange_setting(timeframe, "timeframe", exchange_name=exchange_name)
tz = cls.resolve_exchange_setting(tz, "tz", exchange_name=exchange_name)
find_earliest_date = cls.resolve_exchange_setting(
find_earliest_date, "find_earliest_date", exchange_name=exchange_name
)
limit = cls.resolve_exchange_setting(limit, "limit", exchange_name=exchange_name)
delay = cls.resolve_exchange_setting(delay, "delay", exchange_name=exchange_name)
retries = cls.resolve_exchange_setting(retries, "retries", exchange_name=exchange_name)
fetch_params = cls.resolve_exchange_setting(
fetch_params, "fetch_params", merge=True, exchange_name=exchange_name
)
show_progress = cls.resolve_exchange_setting(show_progress, "show_progress", exchange_name=exchange_name)
pbar_kwargs = cls.resolve_exchange_setting(pbar_kwargs, "pbar_kwargs", merge=True, exchange_name=exchange_name)
if "bar_id" not in pbar_kwargs:
pbar_kwargs["bar_id"] = "ccxt"
silence_warnings = cls.resolve_exchange_setting(
silence_warnings, "silence_warnings", exchange_name=exchange_name
)
if not exchange.has["fetchOHLCV"]:
raise ValueError(f"Exchange {exchange} does not support OHLCV")
if exchange.has["fetchOHLCV"] == "emulated":
if not silence_warnings:
warn("Using emulated OHLCV candles")
freq = timeframe
split = dt.split_freq_str(timeframe)
if split is not None:
multiplier, unit = split
if unit == "D":
unit = "d"
elif unit == "W":
unit = "w"
elif unit == "Y":
unit = "y"
timeframe = str(multiplier) + unit
if timeframe not in exchange.timeframes:
raise ValueError(f"Exchange {exchange} does not support {timeframe} timeframe")
def _retry(method):
@wraps(method)
def retry_method(*args, **kwargs):
for i in range(retries):
try:
return method(*args, **kwargs)
except ccxt.NetworkError as e:
if i == retries - 1:
raise e
if not silence_warnings:
warn(traceback.format_exc())
if delay is not None:
time.sleep(delay)
return retry_method
@_retry
def _fetch(_since, _limit):
return exchange.fetch_ohlcv(
symbol,
timeframe=timeframe,
since=_since,
limit=_limit,
params=fetch_params,
)
if return_fetch_method:
return dict(fetch_func=_fetch, start=start, end=end, tz=tz)
# Establish the timestamps
if find_earliest_date and start is not None:
start = cls._find_earliest_date(_fetch, start=start, end=end, tz=tz, for_internal_use=True)
if start is not None:
start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc"))
else:
start_ts = None
if end is not None:
end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(end, naive_tz=tz, tz="UTC"))
else:
end_ts = None
prev_end_ts = None
def _ts_to_str(ts: tp.Optional[int]) -> str:
if ts is None:
return "?"
return dt.readable_datetime(pd.Timestamp(ts, unit="ms", tz="utc"), freq=timeframe)
def _filter_func(d: tp.Sequence, _prev_end_ts: tp.Optional[int] = None) -> bool:
if start_ts is not None:
if d[0] < start_ts:
return False
if _prev_end_ts is not None:
if d[0] <= _prev_end_ts:
return False
if end_ts is not None:
if d[0] >= end_ts:
return False
return True
# Iteratively collect the data
data = []
try:
with ProgressBar(show_progress=show_progress, **pbar_kwargs) as pbar:
pbar.set_description("{} → ?".format(_ts_to_str(start_ts if prev_end_ts is None else prev_end_ts)))
while True:
# Fetch the klines for the next timeframe
next_data = _fetch(start_ts if prev_end_ts is None else prev_end_ts, limit)
next_data = list(filter(partial(_filter_func, _prev_end_ts=prev_end_ts), next_data))
# Update the timestamps and the progress bar
if not len(next_data):
break
data += next_data
if start_ts is None:
start_ts = next_data[0][0]
pbar.set_description("{} → {}".format(_ts_to_str(start_ts), _ts_to_str(next_data[-1][0])))
pbar.update()
prev_end_ts = next_data[-1][0]
if end_ts is not None and prev_end_ts >= end_ts:
break
if delay is not None:
time.sleep(delay) # be kind to api
except Exception as e:
if not silence_warnings:
warn(traceback.format_exc())
warn(
f"Symbol '{str(symbol)}' raised an exception. Returning incomplete data. "
"Use update() method to fetch missing data."
)
# Convert data to a DataFrame
df = pd.DataFrame(data, columns=["Open time", "Open", "High", "Low", "Close", "Volume"])
df.index = pd.to_datetime(df["Open time"], unit="ms", utc=True)
del df["Open time"]
if "Open" in df.columns:
df["Open"] = df["Open"].astype(float)
if "High" in df.columns:
df["High"] = df["High"].astype(float)
if "Low" in df.columns:
df["Low"] = df["Low"].astype(float)
if "Close" in df.columns:
df["Close"] = df["Close"].astype(float)
if "Volume" in df.columns:
df["Volume"] = df["Volume"].astype(float)
return df, dict(tz=tz, freq=freq)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
</file>
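Two small sketches for the CCXT class above: probing the earliest available candle via `CCXTData.find_earliest_date`, and pulling with the exchange encoded into the symbol, which (per `fetch_symbol`) is parsed only when no `exchange` argument is given; the symbols and dates are illustrative:

```pycon
>>> vbt.CCXTData.find_earliest_date("BTC/USDT", exchange="binance", timeframe="1 day")

>>> data = vbt.CCXTData.pull(
...     "BINANCE:BTC/USDT",
...     timeframe="1 hour",
...     start="one week ago"
... )
```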
<file path="data/custom/csv.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `CSVData`."""
from pathlib import Path
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.file import FileData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"CSVData",
]
__pdoc__ = {}
CSVDataT = tp.TypeVar("CSVDataT", bound="CSVData")
class CSVData(FileData):
"""Data class for fetching CSV data."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.csv")
@classmethod
def is_csv_file(cls, path: tp.PathLike) -> bool:
"""Return whether the path is a CSV/TSV file."""
if not isinstance(path, Path):
path = Path(path)
if path.exists() and path.is_file() and ".csv" in path.suffixes:
return True
if path.exists() and path.is_file() and ".tsv" in path.suffixes:
return True
return False
@classmethod
def is_file_match(cls, path: tp.PathLike) -> bool:
return cls.is_csv_file(path)
@classmethod
def resolve_keys_meta(
cls,
keys: tp.Union[None, dict, tp.MaybeKeys] = None,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[None, dict, tp.MaybeFeatures] = None,
symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
paths: tp.Any = None,
) -> tp.Kwargs:
keys_meta = FileData.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
)
if keys_meta["keys"] is None and paths is None:
keys_meta["keys"] = cls.list_paths()
return keys_meta
@classmethod
def fetch_key(
cls,
key: tp.Key,
path: tp.Any = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
tz: tp.TimezoneLike = None,
start_row: tp.Optional[int] = None,
end_row: tp.Optional[int] = None,
header: tp.Optional[tp.MaybeSequence[int]] = None,
index_col: tp.Optional[int] = None,
parse_dates: tp.Optional[bool] = None,
chunk_func: tp.Optional[tp.Callable] = None,
squeeze: tp.Optional[bool] = None,
**read_kwargs,
) -> tp.KeyData:
"""Fetch the CSV file of a feature or symbol.
Args:
key (hashable): Feature or symbol.
path (str): Path.
If `path` is None, uses `key` as the path to the CSV file.
start (any): Start datetime.
Will use the timezone of the object. See `vectorbtpro.utils.datetime_.to_timestamp`.
end (any): End datetime.
Will use the timezone of the object. See `vectorbtpro.utils.datetime_.to_timestamp`.
tz (any): Target timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
start_row (int): Start row (inclusive).
Must exclude header rows.
end_row (int): End row (exclusive).
Must exclude header rows.
header (int or sequence of int): See `pd.read_csv`.
index_col (int): See `pd.read_csv`.
If False, will pass None.
parse_dates (bool): See `pd.read_csv`.
chunk_func (callable): Function to select and concatenate chunks from `TextFileReader`.
Gets called only if `iterator` or `chunksize` are set.
squeeze (bool): Whether to squeeze a DataFrame with one column into a Series.
**read_kwargs: Other keyword arguments passed to `pd.read_csv`.
`skiprows` and `nrows` will be automatically calculated based on `start_row` and `end_row`.
When either `start` or `end` is provided, will fetch the entire data first and filter it thereafter.
See https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for other arguments.
For defaults, see `custom.csv` in `vectorbtpro._settings.data`."""
from pandas.io.parsers import TextFileReader
from pandas.api.types import is_object_dtype
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
tz = cls.resolve_custom_setting(tz, "tz")
start_row = cls.resolve_custom_setting(start_row, "start_row")
if start_row is None:
start_row = 0
end_row = cls.resolve_custom_setting(end_row, "end_row")
header = cls.resolve_custom_setting(header, "header")
index_col = cls.resolve_custom_setting(index_col, "index_col")
if index_col is False:
index_col = None
parse_dates = cls.resolve_custom_setting(parse_dates, "parse_dates")
chunk_func = cls.resolve_custom_setting(chunk_func, "chunk_func")
squeeze = cls.resolve_custom_setting(squeeze, "squeeze")
read_kwargs = cls.resolve_custom_setting(read_kwargs, "read_kwargs", merge=True)
if path is None:
path = key
if isinstance(header, int):
header = [header]
header_rows = header[-1] + 1
start_row += header_rows
if end_row is not None:
end_row += header_rows
skiprows = range(header_rows, start_row)
if end_row is not None:
nrows = end_row - start_row
else:
nrows = None
sep = read_kwargs.pop("sep", None)
if isinstance(path, (str, Path)):
try:
_path = Path(path)
if _path.suffix.lower() == ".csv":
if sep is None:
sep = ","
if _path.suffix.lower() == ".tsv":
if sep is None:
sep = "\t"
except Exception as e:
pass
if sep is None:
sep = ","
obj = pd.read_csv(
path,
sep=sep,
header=header,
index_col=index_col,
parse_dates=parse_dates,
skiprows=skiprows,
nrows=nrows,
**read_kwargs,
)
if isinstance(obj, TextFileReader):
if chunk_func is None:
obj = pd.concat(list(obj), axis=0)
else:
obj = chunk_func(obj)
if isinstance(obj, pd.DataFrame) and squeeze:
obj = obj.squeeze("columns")
if isinstance(obj, pd.Series) and obj.name == "0":
obj.name = None
if index_col is not None and parse_dates and is_object_dtype(obj.index.dtype):
obj.index = pd.to_datetime(obj.index, utc=True)
if tz is not None:
obj.index = obj.index.tz_convert(tz)
if isinstance(obj.index, pd.DatetimeIndex) and tz is None:
tz = obj.index.tz
if start is not None or end is not None:
if not isinstance(obj.index, pd.DatetimeIndex):
raise TypeError("Cannot filter index that is not DatetimeIndex")
if obj.index.tz is not None:
if start is not None:
start = dt.to_tzaware_timestamp(start, naive_tz=tz, tz=obj.index.tz)
if end is not None:
end = dt.to_tzaware_timestamp(end, naive_tz=tz, tz=obj.index.tz)
else:
if start is not None:
start = dt.to_naive_timestamp(start, tz=tz)
if end is not None:
end = dt.to_naive_timestamp(end, tz=tz)
mask = True
if start is not None:
mask &= obj.index >= start
if end is not None:
mask &= obj.index < end
mask_indices = np.flatnonzero(mask)
if len(mask_indices) == 0:
return None
obj = obj.iloc[mask_indices[0] : mask_indices[-1] + 1]
start_row += mask_indices[0]
return obj, dict(last_row=start_row - header_rows + len(obj.index) - 1, tz=tz)
@classmethod
def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Fetch the CSV file of a feature.
Uses `CSVData.fetch_key`."""
return cls.fetch_key(feature, **kwargs)
@classmethod
def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Fetch the CSV file of a symbol.
Uses `CSVData.fetch_key`."""
return cls.fetch_key(symbol, **kwargs)
def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
"""Update data of a feature or symbol."""
fetch_kwargs = self.select_fetch_kwargs(key)
returned_kwargs = self.select_returned_kwargs(key)
fetch_kwargs["start_row"] = returned_kwargs["last_row"]
kwargs = merge_dicts(fetch_kwargs, kwargs)
if key_is_feature:
return self.fetch_feature(key, **kwargs)
return self.fetch_symbol(key, **kwargs)
def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Update data of a feature.
Uses `CSVData.update_key` with `key_is_feature=True`."""
return self.update_key(feature, key_is_feature=True, **kwargs)
def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Update data for a symbol.
Uses `CSVData.update_key` with `key_is_feature=False`."""
return self.update_key(symbol, key_is_feature=False, **kwargs)
</file>
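A minimal sketch of pulling and updating a CSV file with the class above; the file path is hypothetical, and the `start`/`end` filters require the file to have a datetime index column as described in `CSVData.fetch_key`:

```pycon
>>> data = vbt.CSVData.pull("BTCUSDT.csv", start="2021-01-01", tz="utc")
>>> data = data.update()  # continues from the last fetched row (see CSVData.update_key)
```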
<file path="data/custom/custom.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `CustomData`."""
import fnmatch
import re
from vectorbtpro import _typing as tp
from vectorbtpro.data.base import Data
__all__ = [
"CustomData",
]
__pdoc__ = {}
class CustomData(Data):
"""Data class for fetching custom data."""
_settings_path: tp.SettingsPath = dict(custom=None)
@classmethod
def get_custom_settings(cls, *args, **kwargs) -> dict:
"""`CustomData.get_settings` with `path_id="custom"`."""
return cls.get_settings(*args, path_id="custom", **kwargs)
@classmethod
def has_custom_settings(cls, *args, **kwargs) -> bool:
"""`CustomData.has_settings` with `path_id="custom"`."""
return cls.has_settings(*args, path_id="custom", **kwargs)
@classmethod
def get_custom_setting(cls, *args, **kwargs) -> tp.Any:
"""`CustomData.get_setting` with `path_id="custom"`."""
return cls.get_setting(*args, path_id="custom", **kwargs)
@classmethod
def has_custom_setting(cls, *args, **kwargs) -> bool:
"""`CustomData.has_setting` with `path_id="custom"`."""
return cls.has_setting(*args, path_id="custom", **kwargs)
@classmethod
def resolve_custom_setting(cls, *args, **kwargs) -> tp.Any:
"""`CustomData.resolve_setting` with `path_id="custom"`."""
return cls.resolve_setting(*args, path_id="custom", **kwargs)
@classmethod
def set_custom_settings(cls, *args, **kwargs) -> None:
"""`CustomData.set_settings` with `path_id="custom"`."""
cls.set_settings(*args, path_id="custom", **kwargs)
@staticmethod
def key_match(key: str, pattern: str, use_regex: bool = False):
"""Return whether key matches pattern.
If `use_regex` is True, checks against a regular expression.
Otherwise, checks against a glob-style pattern."""
if use_regex:
return re.match(pattern, key)
return re.match(fnmatch.translate(pattern), key)
</file>
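A quick sketch of `CustomData.key_match`, which backs the `pattern` arguments of the `list_*` class methods in the subclasses above (assuming the class is exported at the package root like the other data classes); it returns a `re.Match` object (truthy) on success and `None` otherwise. Note that regex patterns are anchored at the start because `re.match` is used:

```pycon
>>> bool(vbt.CustomData.key_match("BTCUSDT", "BTC*"))
True
>>> bool(vbt.CustomData.key_match("BTCUSDT", r"USDT$", use_regex=True))
False
```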
<file path="data/custom/db.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `DBData`."""
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.local import LocalData
__all__ = [
"DBData",
]
__pdoc__ = {}
class DBData(LocalData):
"""Data class for fetching database data."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.db")
</file>
<file path="data/custom/duckdb.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `DuckDBData`."""
from pathlib import Path
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.base import key_dict
from vectorbtpro.data.custom.db import DBData
from vectorbtpro.data.custom.file import FileData
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
try:
if not tp.TYPE_CHECKING:
raise ImportError
from duckdb import DuckDBPyConnection as DuckDBPyConnectionT, DuckDBPyRelation as DuckDBPyRelationT
except ImportError:
DuckDBPyConnectionT = "DuckDBPyConnection"
DuckDBPyRelationT = "DuckDBPyRelation"
__all__ = [
"DuckDBData",
]
__pdoc__ = {}
DuckDBDataT = tp.TypeVar("DuckDBDataT", bound="DuckDBData")
class DuckDBData(DBData):
"""Data class for fetching data using DuckDB.
See `DuckDBData.pull` and `DuckDBData.fetch_key` for arguments.
Usage:
* Set up the connection settings globally (optional):
```pycon
>>> from vectorbtpro import *
>>> vbt.DuckDBData.set_custom_settings(connection="database.duckdb")
```
* Pull tables:
```pycon
>>> data = vbt.DuckDBData.pull(["TABLE1", "TABLE2"])
```
* Rename tables:
```pycon
>>> data = vbt.DuckDBData.pull(
... ["SYMBOL1", "SYMBOL2"],
... table=vbt.key_dict({
... "SYMBOL1": "TABLE1",
... "SYMBOL2": "TABLE2"
... })
... )
```
* Pull queries:
```pycon
>>> data = vbt.DuckDBData.pull(
... ["SYMBOL1", "SYMBOL2"],
... query=vbt.key_dict({
... "SYMBOL1": "SELECT * FROM TABLE1",
... "SYMBOL2": "SELECT * FROM TABLE2"
... })
... )
```
* Pull Parquet files:
```pycon
>>> data = vbt.DuckDBData.pull(
... ["SYMBOL1", "SYMBOL2"],
... read_path=vbt.key_dict({
... "SYMBOL1": "s1.parquet",
... "SYMBOL2": "s2.parquet"
... })
... )
```
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.duckdb")
@classmethod
def resolve_connection(
cls,
connection: tp.Union[None, str, tp.PathLike, DuckDBPyConnectionT] = None,
read_only: bool = True,
return_meta: bool = False,
**connection_config,
) -> tp.Union[DuckDBPyConnectionT, dict]:
"""Resolve the connection."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("duckdb")
from duckdb import connect, default_connection
connection_meta = {}
connection = cls.resolve_custom_setting(connection, "connection")
if connection_config is None:
connection_config = {}
has_connection_config = len(connection_config) > 0
connection_config["read_only"] = read_only
connection_config = cls.resolve_custom_setting(connection_config, "connection_config", merge=True)
read_only = connection_config.pop("read_only", read_only)
should_close = False
if connection is None:
if len(connection_config) == 0:
connection = default_connection
else:
database = connection_config.pop("database", None)
if "config" in connection_config or len(connection_config) == 0:
connection = connect(database, read_only=read_only, **connection_config)
else:
connection = connect(database, read_only=read_only, config=connection_config)
should_close = True
elif isinstance(connection, (str, Path)):
if "config" in connection_config or len(connection_config) == 0:
connection = connect(str(connection), read_only=read_only, **connection_config)
else:
connection = connect(str(connection), read_only=read_only, config=connection_config)
should_close = True
elif has_connection_config:
raise ValueError("Cannot apply connection_config to already initialized connection")
if return_meta:
connection_meta["connection"] = connection
connection_meta["should_close"] = should_close
return connection_meta
return connection
@classmethod
def list_catalogs(
cls,
pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
incl_system: bool = False,
connection: tp.Union[None, str, DuckDBPyConnectionT] = None,
connection_config: tp.KwargsLike = None,
) -> tp.List[str]:
"""List all catalogs.
Catalogs "system" and "temp" are skipped if `incl_system` is False.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each catalog against `pattern`."""
if connection_config is None:
connection_config = {}
connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config)
connection = connection_meta["connection"]
should_close = connection_meta["should_close"]
schemata_df = connection.sql("SELECT * FROM information_schema.schemata").df()
catalogs = []
for catalog in schemata_df["catalog_name"].tolist():
if pattern is not None:
if not cls.key_match(catalog, pattern, use_regex=use_regex):
continue
if not incl_system and catalog == "system":
continue
if not incl_system and catalog == "temp":
continue
catalogs.append(catalog)
if should_close:
connection.close()
if sort:
return sorted(dict.fromkeys(catalogs))
return list(dict.fromkeys(catalogs))
@classmethod
def list_schemas(
cls,
catalog_pattern: tp.Optional[str] = None,
schema_pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
catalog: tp.Optional[str] = None,
incl_system: bool = False,
connection: tp.Union[None, str, DuckDBPyConnectionT] = None,
connection_config: tp.KwargsLike = None,
) -> tp.List[str]:
"""List all schemas.
If `catalog` is None, searches for all catalog names in the database and prefixes each schema
with the respective catalog name (unless only a single catalog was matched). If `catalog` is
provided, returns the schemas corresponding to this catalog without a prefix. Schemas
"information_schema" and "pg_catalog" are skipped if `incl_system` is False.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each schema against
`schema_pattern` and each catalog against `catalog_pattern`."""
if connection_config is None:
connection_config = {}
connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config)
connection = connection_meta["connection"]
should_close = connection_meta["should_close"]
if catalog is None:
catalogs = cls.list_catalogs(
pattern=catalog_pattern,
use_regex=use_regex,
sort=sort,
incl_system=incl_system,
connection=connection,
connection_config=connection_config,
)
if len(catalogs) == 1:
prefix_catalog = False
else:
prefix_catalog = True
else:
catalogs = [catalog]
prefix_catalog = False
schemata_df = connection.sql("SELECT * FROM information_schema.schemata").df()
schemas = []
for catalog in catalogs:
all_schemas = schemata_df[schemata_df["catalog_name"] == catalog]["schema_name"].tolist()
for schema in all_schemas:
if schema_pattern is not None:
if not cls.key_match(schema, schema_pattern, use_regex=use_regex):
continue
if not incl_system and schema == "information_schema":
continue
if not incl_system and schema == "pg_catalog":
continue
if prefix_catalog:
schema = catalog + ":" + schema
schemas.append(schema)
if should_close:
connection.close()
if sort:
return sorted(dict.fromkeys(schemas))
return list(dict.fromkeys(schemas))
@classmethod
def get_current_schema(
cls,
connection: tp.Union[None, str, DuckDBPyConnectionT] = None,
connection_config: tp.KwargsLike = None,
) -> str:
"""Get the current schema."""
if connection_config is None:
connection_config = {}
connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config)
connection = connection_meta["connection"]
should_close = connection_meta["should_close"]
current_schema = connection.sql("SELECT current_schema()").fetchall()[0][0]
if should_close:
connection.close()
return current_schema
@classmethod
def list_tables(
cls,
*,
catalog_pattern: tp.Optional[str] = None,
schema_pattern: tp.Optional[str] = None,
table_pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
catalog: tp.Optional[str] = None,
schema: tp.Optional[str] = None,
incl_system: bool = False,
incl_temporary: bool = False,
incl_views: bool = True,
connection: tp.Union[None, str, DuckDBPyConnectionT] = None,
connection_config: tp.KwargsLike = None,
) -> tp.List[str]:
"""List all tables and views.
If `schema` is None, searches for all schema names in the database and prefixes each table
with the respective catalog and schema name (unless there's only one schema which is the current
schema or `schema` is `current_schema`). If `schema` is provided, returns the tables corresponding
to this schema without a prefix.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each schema against
`schema_pattern` and each table against `table_pattern`."""
if connection_config is None:
connection_config = {}
connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config)
connection = connection_meta["connection"]
should_close = connection_meta["should_close"]
if catalog is None:
catalogs = cls.list_catalogs(
pattern=catalog_pattern,
use_regex=use_regex,
sort=sort,
incl_system=incl_system,
connection=connection,
connection_config=connection_config,
)
if catalog_pattern is None and len(catalogs) == 1:
prefix_catalog = False
else:
prefix_catalog = True
else:
catalogs = [catalog]
prefix_catalog = False
current_schema = cls.get_current_schema(
connection=connection,
connection_config=connection_config,
)
if schema is None:
catalogs_schemas = []
for catalog in catalogs:
catalog_schemas = cls.list_schemas(
schema_pattern=schema_pattern,
use_regex=use_regex,
sort=sort,
catalog=catalog,
incl_system=incl_system,
connection=connection,
connection_config=connection_config,
)
for schema in catalog_schemas:
catalogs_schemas.append((catalog, schema))
if len(catalogs_schemas) == 1 and catalogs_schemas[0][1] == current_schema:
prefix_schema = False
else:
prefix_schema = True
else:
if schema == "current_schema":
schema = current_schema
catalogs_schemas = []
for catalog in catalogs:
catalogs_schemas.append((catalog, schema))
prefix_schema = prefix_catalog
tables_df = connection.sql("SELECT * FROM information_schema.tables").df()
tables = []
for catalog, schema in catalogs_schemas:
all_tables = []
all_tables.extend(
tables_df[
(tables_df["table_catalog"] == catalog)
& (tables_df["table_schema"] == schema)
& (tables_df["table_type"] == "BASE TABLE")
]["table_name"].tolist()
)
if incl_temporary:
all_tables.extend(
tables_df[
(tables_df["table_catalog"] == catalog)
& (tables_df["table_schema"] == schema)
& (tables_df["table_type"] == "LOCAL TEMPORARY")
]["table_name"].tolist()
)
if incl_views:
all_tables.extend(
tables_df[
(tables_df["table_catalog"] == catalog)
& (tables_df["table_schema"] == schema)
& (tables_df["table_type"] == "VIEW")
]["table_name"].tolist()
)
for table in all_tables:
if table_pattern is not None:
if not cls.key_match(table, table_pattern, use_regex=use_regex):
continue
if not prefix_catalog and prefix_schema:
table = schema + ":" + table
elif prefix_catalog or prefix_schema:
table = catalog + ":" + schema + ":" + table
tables.append(table)
if should_close:
connection.close()
if sort:
return sorted(dict.fromkeys(tables))
return list(dict.fromkeys(tables))
@classmethod
def resolve_keys_meta(
cls,
keys: tp.Union[None, dict, tp.MaybeKeys] = None,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[None, dict, tp.MaybeFeatures] = None,
symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
catalog: tp.Optional[str] = None,
schema: tp.Optional[str] = None,
list_tables_kwargs: tp.KwargsLike = None,
read_path: tp.Optional[tp.PathLike] = None,
read_format: tp.Optional[str] = None,
connection: tp.Union[None, str, DuckDBPyConnectionT] = None,
connection_config: tp.KwargsLike = None,
) -> tp.Kwargs:
keys_meta = DBData.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
)
if keys_meta["keys"] is None:
if cls.has_key_dict(catalog):
raise ValueError("Cannot populate keys if catalog is defined per key")
if cls.has_key_dict(schema):
raise ValueError("Cannot populate keys if schema is defined per key")
if cls.has_key_dict(list_tables_kwargs):
raise ValueError("Cannot populate keys if list_tables_kwargs is defined per key")
if cls.has_key_dict(connection):
raise ValueError("Cannot populate keys if connection is defined per key")
if cls.has_key_dict(connection_config):
raise ValueError("Cannot populate keys if connection_config is defined per key")
if cls.has_key_dict(read_path):
raise ValueError("Cannot populate keys if read_path is defined per key")
if cls.has_key_dict(read_format):
raise ValueError("Cannot populate keys if read_format is defined per key")
if read_path is not None or read_format is not None:
if read_path is None:
read_path = "."
if read_format is not None:
read_format = read_format.lower()
checks.assert_in(read_format, ["csv", "parquet", "json"], arg_name="read_format")
keys_meta["keys"] = FileData.list_paths(read_path, extension=read_format)
else:
if list_tables_kwargs is None:
list_tables_kwargs = {}
keys_meta["keys"] = cls.list_tables(
catalog=catalog,
schema=schema,
connection=connection,
connection_config=connection_config,
**list_tables_kwargs,
)
return keys_meta
@classmethod
def pull(
cls: tp.Type[DuckDBDataT],
keys: tp.Union[tp.MaybeKeys] = None,
*,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[tp.MaybeFeatures] = None,
symbols: tp.Union[tp.MaybeSymbols] = None,
catalog: tp.Optional[str] = None,
schema: tp.Optional[str] = None,
list_tables_kwargs: tp.KwargsLike = None,
read_path: tp.Optional[tp.PathLike] = None,
read_format: tp.Optional[str] = None,
connection: tp.Union[None, str, DuckDBPyConnectionT] = None,
connection_config: tp.KwargsLike = None,
share_connection: tp.Optional[bool] = None,
**kwargs,
) -> DuckDBDataT:
"""Override `vectorbtpro.data.base.Data.pull` to resolve and share the connection among the keys
and use the table names available in the database in case no keys were provided."""
if share_connection is None:
if not cls.has_key_dict(connection) and not cls.has_key_dict(connection_config):
share_connection = True
else:
share_connection = False
if share_connection:
if connection_config is None:
connection_config = {}
connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config)
connection = connection_meta["connection"]
should_close = connection_meta["should_close"]
else:
should_close = False
keys_meta = cls.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
catalog=catalog,
schema=schema,
list_tables_kwargs=list_tables_kwargs,
read_path=read_path,
read_format=read_format,
connection=connection,
connection_config=connection_config,
)
keys = keys_meta["keys"]
if isinstance(read_path, key_dict):
new_read_path = read_path.copy()
else:
new_read_path = key_dict()
if isinstance(keys, dict):
new_keys = {}
for k, v in keys.items():
if isinstance(k, Path):
new_k = FileData.path_to_key(k)
new_read_path[new_k] = k
k = new_k
new_keys[k] = v
keys = new_keys
elif cls.has_multiple_keys(keys):
new_keys = []
for k in keys:
if isinstance(k, Path):
new_k = FileData.path_to_key(k)
new_read_path[new_k] = k
k = new_k
new_keys.append(k)
keys = new_keys
else:
if isinstance(keys, Path):
new_keys = FileData.path_to_key(keys)
new_read_path[new_keys] = keys
keys = new_keys
if len(new_read_path) > 0:
read_path = new_read_path
keys_are_features = keys_meta["keys_are_features"]
outputs = super(DBData, cls).pull(
keys,
keys_are_features=keys_are_features,
catalog=catalog,
schema=schema,
read_path=read_path,
read_format=read_format,
connection=connection,
connection_config=connection_config,
**kwargs,
)
if should_close:
connection.close()
return outputs
@classmethod
def format_write_option(cls, option: tp.Any) -> str:
"""Format a write option."""
if isinstance(option, str):
return f"'{option}'"
if isinstance(option, (tuple, list)):
return "(" + ", ".join(map(str, option)) + ")"
if isinstance(option, dict):
return "{" + ", ".join(map(lambda y: f"{y[0]}: {cls.format_write_option(y[1])}", option.items())) + "}"
return f"{option}"
@classmethod
def format_write_options(cls, options: tp.Union[str, dict]) -> str:
"""Format write options."""
if isinstance(options, str):
return options
new_options = []
for k, v in options.items():
new_options.append(f"{k.upper()} {cls.format_write_option(v)}")
return ", ".join(new_options)
@classmethod
def format_read_option(cls, option: tp.Any) -> str:
"""Format a read option."""
if isinstance(option, str):
return f"'{option}'"
if isinstance(option, (tuple, list)):
return "[" + ", ".join(map(cls.format_read_option, option)) + "]"
if isinstance(option, dict):
return "{" + ", ".join(map(lambda y: f"'{y[0]}': {cls.format_read_option(y[1])}", option.items())) + "}"
return f"{option}"
@classmethod
def format_read_options(cls, options: tp.Union[str, dict]) -> str:
"""Format read options."""
if isinstance(options, str):
return options
new_options = []
for k, v in options.items():
new_options.append(f"{k.lower()}={cls.format_read_option(v)}")
return ", ".join(new_options)
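# Illustrative example (not part of the original source): how the formatters above
# render option dictionaries into strings.
#
# >>> DuckDBData.format_write_options({"format": "parquet", "partition_by": ("year", "month")})
# "FORMAT 'parquet', PARTITION_BY (year, month)"
# >>> DuckDBData.format_read_options({"header": True, "columns": {"ts": "TIMESTAMP"}})
# "header=True, columns={'ts': 'TIMESTAMP'}"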
@classmethod
def fetch_key(
cls,
key: str,
table: tp.Optional[str] = None,
schema: tp.Optional[str] = None,
catalog: tp.Optional[str] = None,
read_path: tp.Optional[tp.PathLike] = None,
read_format: tp.Optional[str] = None,
read_options: tp.Union[None, str, dict] = None,
query: tp.Union[None, str, DuckDBPyRelationT] = None,
connection: tp.Union[None, str, DuckDBPyConnectionT] = None,
connection_config: tp.KwargsLike = None,
start: tp.Optional[tp.Any] = None,
end: tp.Optional[tp.Any] = None,
align_dates: tp.Optional[bool] = None,
parse_dates: tp.Union[None, bool, tp.Sequence[str]] = None,
to_utc: tp.Union[None, bool, str, tp.Sequence[str]] = None,
tz: tp.TimezoneLike = None,
index_col: tp.Optional[tp.MaybeSequence[tp.IntStr]] = None,
squeeze: tp.Optional[bool] = None,
df_kwargs: tp.KwargsLike = None,
**sql_kwargs,
) -> tp.KeyData:
"""Fetch a feature or symbol from a DuckDB database.
Can use a table name (which defaults to the key) or a custom query.
Args:
key (str): Feature or symbol.
If `table`, `read_path`, `read_format`, and `query` are all None, becomes the table name.
Key can be in the `SCHEMA:TABLE` or `CATALOG:SCHEMA:TABLE` format; in that case, the
`schema` and `catalog` arguments will be ignored.
table (str): Table name.
Cannot be used together with `read_path` or `query`.
schema (str): Schema name.
Cannot be used together with `read_path` or `query`.
catalog (str): Catalog name.
Cannot be used together with `read_path` or `query`.
read_path (path_like): Path to a file to read.
Cannot be used together with `table`, `schema`, `catalog`, or `query`.
read_format (str): Format of the file to read.
Allowed values are "csv", "parquet", and "json".
Requires `read_path` to be set.
read_options (str or dict): Options used to read the file.
Requires `read_path` and `read_format` to be set.
Uses `DuckDBData.format_read_options` to transform a dictionary to a string.
query (str or DuckDBPyRelation): Custom query.
Cannot be used together with `catalog`, `schema`, `table`, or `read_path`.
connection (str or object): See `DuckDBData.resolve_connection`.
connection_config (dict): See `DuckDBData.resolve_connection`.
start (any): Start datetime (if datetime index) or any other start value.
Will parse with `vectorbtpro.utils.datetime_.to_timestamp` if `align_dates` is True
and the index is a datetime index. Otherwise, you must ensure the correct type is provided.
Cannot be used together with `query`; include the condition in the query itself.
end (any): End datetime (if datetime index) or any other end value.
Will parse with `vectorbtpro.utils.datetime_.to_timestamp` if `align_dates` is True
and the index is a datetime index. Otherwise, you must ensure the correct type is provided.
Cannot be used together with `query`; include the condition in the query itself.
align_dates (bool): Whether to align `start` and `end` to the timezone of the index.
Will describe the relation (using `DESCRIBE ... LIMIT 1`) to resolve the data type of the index column.
parse_dates (bool or sequence of str): See `DuckDBData.prepare_dt`.
to_utc (bool, str, or sequence of str): See `DuckDBData.prepare_dt`.
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
index_col (int, str, or list): One or more columns that should become the index.
squeeze (bool): Whether to squeeze a DataFrame with one column into a Series.
df_kwargs (dict): Keyword arguments passed to `relation.df` to convert a relation to a DataFrame.
**sql_kwargs: Other keyword arguments passed to `connection.execute` to run a SQL query.
For defaults, see `custom.duckdb` in `vectorbtpro._settings.data`."""
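# Illustrative example (not part of the original source; assumes a table "TABLE1" with a
# datetime column "Timestamp" in a local database file "database.duckdb"):
#
# >>> data = DuckDBData.pull(
# ...     "TABLE1",
# ...     connection="database.duckdb",
# ...     index_col="Timestamp",
# ...     start="2024-01-01",
# ...     end="2024-02-01",
# ... )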
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("duckdb")
from duckdb import DuckDBPyRelation
if connection_config is None:
connection_config = {}
connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config)
connection = connection_meta["connection"]
should_close = connection_meta["should_close"]
if catalog is not None and query is not None:
raise ValueError("Cannot use catalog and query together")
if schema is not None and query is not None:
raise ValueError("Cannot use schema and query together")
if table is not None and query is not None:
raise ValueError("Cannot use table and query together")
if read_path is not None and query is not None:
raise ValueError("Cannot use read_path and query together")
if read_path is not None and (catalog is not None or schema is not None or table is not None):
raise ValueError("Cannot use read_path and catalog/schema/table together")
if table is None and read_path is None and read_format is None and query is None:
if ":" in key:
key_parts = key.split(":")
if len(key_parts) == 2:
schema, table = key_parts
else:
catalog, schema, table = key_parts
else:
table = key
if read_format is not None:
read_format = read_format.lower()
checks.assert_in(read_format, ["csv", "parquet", "json"], arg_name="read_format")
if read_path is None:
read_path = (Path(".") / key).with_suffix("." + read_format)
else:
if read_path is not None:
if isinstance(read_path, str):
read_path = Path(read_path)
if read_path.suffix[1:] in ["csv", "parquet", "json"]:
read_format = read_path.suffix[1:]
if read_path is not None:
if isinstance(read_path, Path):
read_path = str(read_path)
read_path = cls.format_read_option(read_path)
if read_options is not None:
if read_format is None:
raise ValueError("Must provide read_format for read_options")
read_options = cls.format_read_options(read_options)
catalog = cls.resolve_custom_setting(catalog, "catalog")
schema = cls.resolve_custom_setting(schema, "schema")
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
align_dates = cls.resolve_custom_setting(align_dates, "align_dates")
parse_dates = cls.resolve_custom_setting(parse_dates, "parse_dates")
to_utc = cls.resolve_custom_setting(to_utc, "to_utc")
tz = cls.resolve_custom_setting(tz, "tz")
index_col = cls.resolve_custom_setting(index_col, "index_col")
squeeze = cls.resolve_custom_setting(squeeze, "squeeze")
df_kwargs = cls.resolve_custom_setting(df_kwargs, "df_kwargs", merge=True)
sql_kwargs = cls.resolve_custom_setting(sql_kwargs, "sql_kwargs", merge=True)
if query is None:
if read_path is not None:
if read_options is not None:
query = f"SELECT * FROM read_{read_format}({read_path}, {read_options})"
elif read_format is not None:
query = f"SELECT * FROM read_{read_format}({read_path})"
else:
query = f"SELECT * FROM {read_path}"
else:
if catalog is not None:
if schema is None:
schema = cls.get_current_schema(
connection=connection,
connection_config=connection_config,
)
query = f'SELECT * FROM "{catalog}"."{schema}"."{table}"'
elif schema is not None:
query = f'SELECT * FROM "{schema}"."{table}"'
else:
query = f'SELECT * FROM "{table}"'
if start is not None or end is not None:
if index_col is None:
raise ValueError("Must provide index column for filtering by start and end")
if not checks.is_int(index_col) and not isinstance(index_col, str):
raise ValueError("Index column must be integer or string for filtering by start and end")
if checks.is_int(index_col) or align_dates:
metadata_df = connection.sql("DESCRIBE " + query + " LIMIT 1").df()
else:
metadata_df = None
if checks.is_int(index_col):
index_name = metadata_df["column_name"].tolist()[0]
else:
index_name = index_col
if parse_dates:
index_column_type = metadata_df[metadata_df["column_name"] == index_name]["column_type"].item()
if index_column_type in (
"TIMESTAMP_NS",
"TIMESTAMP_MS",
"TIMESTAMP_S",
"TIMESTAMP",
"DATETIME",
):
if start is not None:
if (
to_utc is True
or (isinstance(to_utc, str) and to_utc.lower() == "index")
or (checks.is_sequence(to_utc) and index_name in to_utc)
):
start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")
start = dt.to_naive_datetime(start)
else:
start = dt.to_naive_datetime(start, tz=tz)
if end is not None:
if (
to_utc is True
or (isinstance(to_utc, str) and to_utc.lower() == "index")
or (checks.is_sequence(to_utc) and index_name in to_utc)
):
end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc")
end = dt.to_naive_datetime(end)
else:
end = dt.to_naive_datetime(end, tz=tz)
elif index_column_type in ("TIMESTAMPTZ", "TIMESTAMP WITH TIME ZONE"):
if start is not None:
if (
to_utc is True
or (isinstance(to_utc, str) and to_utc.lower() == "index")
or (checks.is_sequence(to_utc) and index_name in to_utc)
):
start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")
else:
start = dt.to_tzaware_datetime(start, naive_tz=tz)
if end is not None:
if (
to_utc is True
or (isinstance(to_utc, str) and to_utc.lower() == "index")
or (checks.is_sequence(to_utc) and index_name in to_utc)
):
end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc")
else:
end = dt.to_tzaware_datetime(end, naive_tz=tz)
if start is not None and end is not None:
query += f' WHERE "{index_name}" >= $start AND "{index_name}" < $end'
elif start is not None:
query += f' WHERE "{index_name}" >= $start'
elif end is not None:
query += f' WHERE "{index_name}" < $end'
params = sql_kwargs.get("params", None)
if params is None:
params = {}
else:
params = dict(params)
if not isinstance(params, dict):
raise ValueError("Parameters must be a dictionary for filtering by start and end")
if start is not None:
if "start" in params:
raise ValueError("Start is already in params")
params["start"] = start
if end is not None:
if "end" in params:
raise ValueError("End is already in params")
params["end"] = end
sql_kwargs["params"] = params
else:
if start is not None:
raise ValueError("Start cannot be applied to custom queries")
if end is not None:
raise ValueError("End cannot be applied to custom queries")
if not isinstance(query, DuckDBPyRelation):
relation = connection.sql(query, **sql_kwargs)
else:
relation = query
obj = relation.df(**df_kwargs)
if isinstance(obj, pd.DataFrame) and checks.is_default_index(obj.index):
if index_col is not None:
if checks.is_int(index_col):
keys = obj.columns[index_col]
elif isinstance(index_col, str):
keys = index_col
else:
keys = []
for col in index_col:
if checks.is_int(col):
keys.append(obj.columns[col])
else:
keys.append(col)
obj = obj.set_index(keys)
if not isinstance(obj.index, pd.MultiIndex):
if obj.index.name == "index":
obj.index.name = None
obj = cls.prepare_dt(obj, to_utc=to_utc, parse_dates=parse_dates)
if not isinstance(obj.index, pd.MultiIndex):
if obj.index.name == "index":
obj.index.name = None
if isinstance(obj.index, pd.DatetimeIndex) and tz is None:
tz = obj.index.tz
if isinstance(obj, pd.DataFrame) and squeeze:
obj = obj.squeeze("columns")
if isinstance(obj, pd.Series) and obj.name == "0":
obj.name = None
if should_close:
connection.close()
return obj, dict(tz=tz)
@classmethod
def fetch_feature(cls, feature: str, **kwargs) -> tp.FeatureData:
"""Fetch the table of a feature.
Uses `DuckDBData.fetch_key`."""
return cls.fetch_key(feature, **kwargs)
@classmethod
def fetch_symbol(cls, symbol: str, **kwargs) -> tp.SymbolData:
"""Fetch the table for a symbol.
Uses `DuckDBData.fetch_key`."""
return cls.fetch_key(symbol, **kwargs)
def update_key(self, key: str, from_last_index: tp.Optional[bool] = None, **kwargs) -> tp.KeyData:
"""Update data of a feature or symbol."""
fetch_kwargs = self.select_fetch_kwargs(key)
pre_kwargs = merge_dicts(fetch_kwargs, kwargs)
if from_last_index is None:
if pre_kwargs.get("query", None) is not None:
from_last_index = False
else:
from_last_index = True
if from_last_index:
fetch_kwargs["start"] = self.select_last_index(key)
kwargs = merge_dicts(fetch_kwargs, kwargs)
if self.feature_oriented:
return self.fetch_feature(key, **kwargs)
return self.fetch_symbol(key, **kwargs)
def update_feature(self, feature: str, **kwargs) -> tp.FeatureData:
"""Update data of a feature.
Uses `DuckDBData.update_key`."""
return self.update_key(feature, **kwargs)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
"""Update data for a symbol.
Uses `DuckDBData.update_key`."""
return self.update_key(symbol, **kwargs)
</file>
<file path="data/custom/feather.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `FeatherData`."""
from pathlib import Path
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.file import FileData
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"FeatherData",
]
__pdoc__ = {}
FeatherDataT = tp.TypeVar("FeatherDataT", bound="FeatherData")
class FeatherData(FileData):
"""Data class for fetching Feather data using PyArrow."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.feather")
@classmethod
def list_paths(cls, path: tp.PathLike = ".", **match_path_kwargs) -> tp.List[Path]:
if not isinstance(path, Path):
path = Path(path)
if path.exists() and path.is_dir():
path = path / "*.feather"
return cls.match_path(path, **match_path_kwargs)
@classmethod
def resolve_keys_meta(
cls,
keys: tp.Union[None, dict, tp.MaybeKeys] = None,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[None, dict, tp.MaybeFeatures] = None,
symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
paths: tp.Any = None,
) -> tp.Kwargs:
keys_meta = FileData.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
)
if keys_meta["keys"] is None and paths is None:
keys_meta["keys"] = "*.feather"
return keys_meta
@classmethod
def fetch_key(
cls,
key: tp.Key,
path: tp.Any = None,
tz: tp.TimezoneLike = None,
index_col: tp.Optional[tp.MaybeSequence[tp.IntStr]] = None,
squeeze: tp.Optional[bool] = None,
**read_kwargs,
) -> tp.KeyData:
"""Fetch the Feather file of a feature or symbol.
Args:
key (hashable): Feature or symbol.
path (str): Path.
If `path` is None, uses `key` as the path to the Feather file.
tz (any): Target timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
index_col (int, str, or sequence): Position(s) or name(s) of column(s) that should become the index.
Will only apply if the fetched object has a default index.
squeeze (bool): Whether to squeeze a DataFrame with one column into a Series.
**read_kwargs: Other keyword arguments passed to `pd.read_feather`.
See https://pandas.pydata.org/docs/reference/api/pandas.read_feather.html for other arguments.
For defaults, see `custom.feather` in `vectorbtpro._settings.data`."""
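# Illustrative example (not part of the original source): pull all Feather files in the
# current directory, or a specific (hypothetical) file for a specific symbol.
#
# >>> data = FeatherData.pull()
# >>> data = FeatherData.pull("BTC-USD", paths="data/BTC-USD.feather")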
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("pyarrow")
tz = cls.resolve_custom_setting(tz, "tz")
index_col = cls.resolve_custom_setting(index_col, "index_col")
if index_col is False:
index_col = None
squeeze = cls.resolve_custom_setting(squeeze, "squeeze")
read_kwargs = cls.resolve_custom_setting(read_kwargs, "read_kwargs", merge=True)
if path is None:
path = key
obj = pd.read_feather(path, **read_kwargs)
if isinstance(obj, pd.DataFrame) and checks.is_default_index(obj.index):
if index_col is not None:
if checks.is_int(index_col):
keys = obj.columns[index_col]
elif isinstance(index_col, str):
keys = index_col
else:
keys = []
for col in index_col:
if checks.is_int(col):
keys.append(obj.columns[col])
else:
keys.append(col)
obj = obj.set_index(keys)
if not isinstance(obj.index, pd.MultiIndex):
if obj.index.name == "index":
obj.index.name = None
if isinstance(obj.index, pd.DatetimeIndex) and tz is None:
tz = obj.index.tz
if isinstance(obj, pd.DataFrame) and squeeze:
obj = obj.squeeze("columns")
if isinstance(obj, pd.Series) and obj.name == "0":
obj.name = None
return obj, dict(tz=tz)
@classmethod
def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Fetch the Feather file of a feature.
Uses `FeatherData.fetch_key`."""
return cls.fetch_key(feature, **kwargs)
@classmethod
def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Fetch the Feather file of a symbol.
Uses `FeatherData.fetch_key`."""
return cls.fetch_key(symbol, **kwargs)
def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
"""Update data of a feature or symbol."""
fetch_kwargs = self.select_fetch_kwargs(key)
kwargs = merge_dicts(fetch_kwargs, kwargs)
if key_is_feature:
return self.fetch_feature(key, **kwargs)
return self.fetch_symbol(key, **kwargs)
def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Update data of a feature.
Uses `FeatherData.update_key` with `key_is_feature=True`."""
return self.update_key(feature, key_is_feature=True, **kwargs)
def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Update data for a symbol.
Uses `FeatherData.update_key` with `key_is_feature=False`."""
return self.update_key(symbol, key_is_feature=False, **kwargs)
</file>
<file path="data/custom/file.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `FileData`."""
import re
from glob import glob
from pathlib import Path
from vectorbtpro import _typing as tp
from vectorbtpro.data.base import key_dict
from vectorbtpro.data.custom.local import LocalData
from vectorbtpro.utils import checks
__all__ = [
"FileData",
]
__pdoc__ = {}
FileDataT = tp.TypeVar("FileDataT", bound="FileData")
class FileData(LocalData):
"""Data class for fetching file data."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.file")
@classmethod
def is_dir_match(cls, path: tp.PathLike) -> bool:
"""Return whether a directory is a valid match."""
return False
@classmethod
def is_file_match(cls, path: tp.PathLike) -> bool:
"""Return whether a file is a valid match."""
return True
@classmethod
def match_path(
cls,
path: tp.PathLike,
match_regex: tp.Optional[str] = None,
sort_paths: bool = True,
recursive: bool = True,
extension: tp.Optional[str] = None,
**kwargs,
) -> tp.List[Path]:
"""Get the list of all paths matching a path.
If `FileData.is_dir_match` returns True for a directory, it gets returned as-is.
Otherwise, iterates through all files in that directory and invokes `FileData.is_file_match`.
If a pattern was provided, these methods aren't invoked."""
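# Illustrative example (not part of the original source): match files under a
# hypothetical "data" directory either by extension or via a glob pattern.
#
# >>> FileData.match_path("data", extension="csv")
# >>> FileData.match_path("data/**/*.csv")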
if not isinstance(path, Path):
path = Path(path)
if path.exists():
if path.is_dir() and not cls.is_dir_match(path):
sub_paths = []
for p in path.iterdir():
if p.is_dir() and cls.is_dir_match(p):
sub_paths.append(p)
if p.is_file() and cls.is_file_match(p):
if extension is None or p.suffix == "." + extension:
sub_paths.append(p)
else:
sub_paths = [path]
else:
sub_paths = list([Path(p) for p in glob(str(path), recursive=recursive)])
if match_regex is not None:
sub_paths = [p for p in sub_paths if re.match(match_regex, str(p))]
if sort_paths:
sub_paths = sorted(sub_paths)
return sub_paths
@classmethod
def list_paths(cls, path: tp.PathLike = ".", **match_path_kwargs) -> tp.List[Path]:
"""List all features or symbols under a path."""
return cls.match_path(path, **match_path_kwargs)
@classmethod
def path_to_key(cls, path: tp.PathLike, **kwargs) -> str:
"""Convert a path into a feature or symbol."""
return Path(path).stem
@classmethod
def resolve_keys_meta(
cls,
keys: tp.Union[None, dict, tp.MaybeKeys] = None,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[None, dict, tp.MaybeFeatures] = None,
symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
paths: tp.Any = None,
) -> tp.Kwargs:
return LocalData.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
)
@classmethod
def pull(
cls: tp.Type[FileDataT],
keys: tp.Union[tp.MaybeKeys] = None,
*,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[tp.MaybeFeatures] = None,
symbols: tp.Union[tp.MaybeSymbols] = None,
paths: tp.Any = None,
match_paths: tp.Optional[bool] = None,
match_regex: tp.Optional[str] = None,
sort_paths: tp.Optional[bool] = None,
match_path_kwargs: tp.KwargsLike = None,
path_to_key_kwargs: tp.KwargsLike = None,
**kwargs,
) -> FileDataT:
"""Override `vectorbtpro.data.base.Data.pull` to take care of paths.
Use either features, symbols, or `paths` to specify the path to one or multiple files.
Allowed are paths in a string or `pathlib.Path` format, or string expressions accepted by `glob.glob`.
Set `match_paths` to False to not parse paths and behave like a regular
`vectorbtpro.data.base.Data` instance.
For defaults, see `custom.file` in `vectorbtpro._settings.data`.
"""
keys_meta = cls.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
paths=paths,
)
keys = keys_meta["keys"]
keys_are_features = keys_meta["keys_are_features"]
dict_type = keys_meta["dict_type"]
match_paths = cls.resolve_custom_setting(match_paths, "match_paths")
match_regex = cls.resolve_custom_setting(match_regex, "match_regex")
sort_paths = cls.resolve_custom_setting(sort_paths, "sort_paths")
if match_paths:
sync = False
if paths is None:
paths = keys
sync = True
elif keys is None:
sync = True
if paths is None:
if keys_are_features:
raise ValueError("At least features or paths must be set")
else:
raise ValueError("At least symbols or paths must be set")
if match_path_kwargs is None:
match_path_kwargs = {}
if path_to_key_kwargs is None:
path_to_key_kwargs = {}
single_key = False
if isinstance(keys, (str, Path)):
# Single key
keys = [keys]
single_key = True
single_path = False
if isinstance(paths, (str, Path)):
# Single path
paths = [paths]
single_path = True
if sync:
single_key = True
cls.check_dict_type(paths, "paths", dict_type=dict_type)
if isinstance(paths, key_dict):
# Dict of path per key
if sync:
keys = list(paths.keys())
elif len(keys) != len(paths):
if keys_are_features:
raise ValueError("The number of features must be equal to the number of matched paths")
else:
raise ValueError("The number of symbols must be equal to the number of matched paths")
elif checks.is_iterable(paths) or checks.is_sequence(paths):
# Multiple paths
matched_paths = [
p
for sub_path in paths
for p in cls.match_path(
sub_path,
match_regex=match_regex,
sort_paths=sort_paths,
**match_path_kwargs,
)
]
if len(matched_paths) == 0:
raise FileNotFoundError(f"No paths could be matched with {paths}")
if sync:
keys = []
paths = key_dict()
for p in matched_paths:
s = cls.path_to_key(p, **path_to_key_kwargs)
keys.append(s)
paths[s] = p
elif len(keys) != len(matched_paths):
if keys_are_features:
raise ValueError("The number of features must be equal to the number of matched paths")
else:
raise ValueError("The number of symbols must be equal to the number of matched paths")
else:
paths = key_dict({s: matched_paths[i] for i, s in enumerate(keys)})
if len(matched_paths) == 1 and single_path:
paths = matched_paths[0]
else:
raise TypeError(f"Path '{paths}' is not supported")
if len(keys) == 1 and single_key:
keys = keys[0]
return super(FileData, cls).pull(
keys,
keys_are_features=keys_are_features,
path=paths,
**kwargs,
)
</file>
<file path="data/custom/finpy.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `FinPyData`."""
from itertools import product
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
try:
if not tp.TYPE_CHECKING:
raise ImportError
from findatapy.market import Market as MarketT
from findatapy.util import ConfigManager as ConfigManagerT
except ImportError:
MarketT = "Market"
ConfigManagerT = "ConfigManager"
__all__ = [
"FinPyData",
]
FinPyDataT = tp.TypeVar("FinPyDataT", bound="FinPyData")
class FinPyData(RemoteData):
"""Data class for fetching using findatapy.
See https://github.com/cuemacro/findatapy for API.
See `FinPyData.fetch_symbol` for arguments.
Usage:
* Pull data (keyword argument format):
```pycon
>>> data = vbt.FinPyData.pull(
... "EURUSD",
... start="14 June 2016",
... end="15 June 2016",
... timeframe="tick",
... category="fx",
... fields=["bid", "ask"],
... data_source="dukascopy"
... )
```
* Pull data (string format):
```pycon
>>> data = vbt.FinPyData.pull(
... "fx.dukascopy.tick.NYC.EURUSD.bid,ask",
... start="14 June 2016",
... end="15 June 2016",
... )
```
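* List the symbols defined in the local findatapy configuration (an illustrative
sketch; the result depends on your findatapy setup):
```pycon
>>> vbt.FinPyData.list_symbols(category="fx", data_source="dukascopy")
```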
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.finpy")
@classmethod
def resolve_market(
cls,
market: tp.Optional[MarketT] = None,
**market_config,
) -> MarketT:
"""Resolve the market.
If provided, must be of the type `findatapy.market.market.Market`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("findatapy")
from findatapy.market import Market, MarketDataGenerator
market = cls.resolve_custom_setting(market, "market")
if market_config is None:
market_config = {}
has_market_config = len(market_config) > 0
market_config = cls.resolve_custom_setting(market_config, "market_config", merge=True)
if "market_data_generator" not in market_config:
market_config["market_data_generator"] = MarketDataGenerator()
if market is None:
market = Market(**market_config)
elif has_market_config:
raise ValueError("Cannot apply market_config to already initialized market")
return market
@classmethod
def resolve_config_manager(
cls,
config_manager: tp.Optional[ConfigManagerT] = None,
**config_manager_config,
) -> ConfigManagerT:
"""Resolve the config manager.
If provided, must be of the type `findatapy.util.ConfigManager`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("findatapy")
from findatapy.util import ConfigManager
config_manager = cls.resolve_custom_setting(config_manager, "config_manager")
if config_manager_config is None:
config_manager_config = {}
has_config_manager_config = len(config_manager_config) > 0
config_manager_config = cls.resolve_custom_setting(config_manager_config, "config_manager_config", merge=True)
if config_manager is None:
config_manager = ConfigManager().get_instance(**config_manager_config)
elif has_config_manager_config:
raise ValueError("Cannot apply config_manager_config to already initialized config_manager")
return config_manager
@classmethod
def list_symbols(
cls,
pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
config_manager: tp.Optional[ConfigManagerT] = None,
config_manager_config: tp.KwargsLike = None,
category: tp.Optional[tp.MaybeList[str]] = None,
data_source: tp.Optional[tp.MaybeList[str]] = None,
freq: tp.Optional[tp.MaybeList[str]] = None,
cut: tp.Optional[tp.MaybeList[str]] = None,
tickers: tp.Optional[tp.MaybeList[str]] = None,
dict_filter: tp.DictLike = None,
smart_group: bool = False,
return_fields: tp.Optional[tp.MaybeList[str]] = None,
combine_parts: bool = True,
) -> tp.List[str]:
"""List all symbols.
Passes most arguments to `findatapy.util.ConfigManager.free_form_tickers_regex_query`.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`."""
if config_manager_config is None:
config_manager_config = {}
config_manager = cls.resolve_config_manager(config_manager=config_manager, **config_manager_config)
if dict_filter is None:
dict_filter = {}
def_ret_fields = ["category", "data_source", "freq", "cut", "tickers"]
if return_fields is None:
ret_fields = def_ret_fields
elif isinstance(return_fields, str):
if return_fields.lower() == "all":
ret_fields = def_ret_fields + ["fields"]
else:
ret_fields = [return_fields]
else:
ret_fields = return_fields
df = config_manager.free_form_tickers_regex_query(
category=category,
data_source=data_source,
freq=freq,
cut=cut,
tickers=tickers,
dict_filter=dict_filter,
smart_group=smart_group,
ret_fields=ret_fields,
)
all_symbols = []
for _, row in df.iterrows():
parts = []
if "category" in row.index:
parts.append(row.loc["category"])
if "data_source" in row.index:
parts.append(row.loc["data_source"])
if "freq" in row.index:
parts.append(row.loc["freq"])
if "cut" in row.index:
parts.append(row.loc["cut"])
if "tickers" in row.index:
parts.append(row.loc["tickers"])
if "fields" in row.index:
parts.append(row.loc["fields"])
if combine_parts:
split_parts = [part.split(",") for part in parts]
combinations = list(product(*split_parts))
else:
combinations = [parts]
for symbol in [".".join(combination) for combination in combinations]:
if pattern is not None:
if not cls.key_match(symbol, pattern, use_regex=use_regex):
continue
all_symbols.append(symbol)
if sort:
return sorted(dict.fromkeys(all_symbols))
return list(dict.fromkeys(all_symbols))
@classmethod
def fetch_symbol(
cls,
symbol: str,
market: tp.Optional[MarketT] = None,
market_config: tp.KwargsLike = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
**request_kwargs,
) -> tp.SymbolData:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from findatapy.
Args:
symbol (str): Symbol.
Also accepts a string in the format "fx.bloomberg.daily.NYC.EURUSD.close".
The fields `freq`, `cut`, `tickers`, and `fields` here are optional.
market (findatapy.market.market.Market): Market.
See `FinPyData.resolve_market`.
market_config (dict): Market config.
See `FinPyData.resolve_market`.
start (any): Start datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
end (any): End datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
timeframe (str): Timeframe.
Allows human-readable strings such as "15 minutes".
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
**request_kwargs: Other keyword arguments passed to `findatapy.market.marketdatarequest.MarketDataRequest`.
For defaults, see `custom.finpy` in `vectorbtpro._settings.data`.
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("findatapy")
from findatapy.market import MarketDataRequest
if market_config is None:
market_config = {}
market = cls.resolve_market(market=market, **market_config)
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
request_kwargs = cls.resolve_custom_setting(request_kwargs, "request_kwargs", merge=True)
split = dt.split_freq_str(timeframe)
if split is None:
raise ValueError(f"Invalid timeframe: '{timeframe}'")
multiplier, unit = split
if unit == "s":
unit = "second"
freq = timeframe
elif unit == "m":
unit = "minute"
freq = timeframe
elif unit == "h":
unit = "hourly"
freq = timeframe
elif unit == "D":
unit = "daily"
freq = timeframe
elif unit == "W":
unit = "weekly"
freq = timeframe
elif unit == "M":
unit = "monthly"
freq = timeframe
elif unit == "Q":
unit = "quarterly"
freq = timeframe
elif unit == "Y":
unit = "annually"
freq = timeframe
else:
freq = None
if "resample" in request_kwargs:
freq = request_kwargs["resample"]
if start is not None:
start = dt.to_naive_datetime(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc"))
if end is not None:
end = dt.to_naive_datetime(dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc"))
if "md_request" in request_kwargs:
md_request = request_kwargs["md_request"]
elif "md_request_df" in request_kwargs:
md_request = market.create_md_request_from_dataframe(
md_request_df=request_kwargs["md_request_df"],
start_date=start,
finish_date=end,
freq_mult=multiplier,
freq=unit,
**request_kwargs,
)
elif "md_request_str" in request_kwargs:
md_request = market.create_md_request_from_str(
md_request_str=request_kwargs["md_request_str"],
start_date=start,
finish_date=end,
freq_mult=multiplier,
freq=unit,
**request_kwargs,
)
elif "md_request_dict" in request_kwargs:
md_request = market.create_md_request_from_dict(
md_request_dict=request_kwargs["md_request_dict"],
start_date=start,
finish_date=end,
freq_mult=multiplier,
freq=unit,
**request_kwargs,
)
elif symbol.count(".") >= 2:
md_request = market.create_md_request_from_str(
md_request_str=symbol,
start_date=start,
finish_date=end,
freq_mult=multiplier,
freq=unit,
**request_kwargs,
)
else:
md_request = MarketDataRequest(
tickers=symbol,
start_date=start,
finish_date=end,
freq_mult=multiplier,
freq=unit,
**request_kwargs,
)
df = market.fetch_market(md_request=md_request)
if df is None:
return None
if isinstance(md_request.tickers, str):
ticker = md_request.tickers
elif len(md_request.tickers) == 1:
ticker = md_request.tickers[0]
else:
ticker = None
if ticker is not None:
df.columns = df.columns.map(lambda x: x.replace(ticker + ".", ""))
if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None:
df = df.tz_localize("utc")
return df, dict(tz=tz, freq=freq)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
</file>
<file path="data/custom/gbm_ohlc.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `GBMOHLCData`."""
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import broadcast_array_to
from vectorbtpro.data import nb
from vectorbtpro.data.custom.synthetic import SyntheticData
from vectorbtpro.ohlcv import nb as ohlcv_nb
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.random_ import set_seed
from vectorbtpro.utils.template import substitute_templates
__all__ = [
"GBMOHLCData",
]
__pdoc__ = {}
class GBMOHLCData(SyntheticData):
"""`SyntheticData` for data generated using `vectorbtpro.data.nb.generate_gbm_data_1d_nb`
and then resampled using `vectorbtpro.ohlcv.nb.ohlc_every_1d_nb`."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.gbm_ohlc")
@classmethod
def generate_symbol(
cls,
symbol: tp.Symbol,
index: tp.Index,
n_ticks: tp.Optional[tp.ArrayLike] = None,
start_value: tp.Optional[float] = None,
mean: tp.Optional[float] = None,
std: tp.Optional[float] = None,
dt: tp.Optional[float] = None,
seed: tp.Optional[int] = None,
jitted: tp.JittedOption = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.SymbolData:
"""Generate a symbol.
Args:
symbol (hashable): Symbol.
index (pd.Index): Pandas index.
n_ticks (int or array_like): Number of ticks per bar.
Flexible argument. Can be a template with a context containing `symbol` and `index`.
start_value (float): Value at time 0.
Does not appear as the first value in the output data.
mean (float): Drift, or mean of the percentage change.
std (float): Standard deviation of the percentage change.
dt (float): Time change (one period of time).
seed (int): Seed to make output deterministic.
jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`.
template_context (dict): Context used to substitute templates.
For defaults, see `custom.gbm_ohlc` in `vectorbtpro._settings.data`.
!!! note
When setting a seed, remember to pass a seed per symbol using `vectorbtpro.data.base.symbol_dict`.
"""
n_ticks = cls.resolve_custom_setting(n_ticks, "n_ticks")
template_context = merge_dicts(dict(symbol=symbol, index=index), template_context)
n_ticks = substitute_templates(n_ticks, template_context, eval_id="n_ticks")
n_ticks = broadcast_array_to(n_ticks, len(index))
start_value = cls.resolve_custom_setting(start_value, "start_value")
mean = cls.resolve_custom_setting(mean, "mean")
std = cls.resolve_custom_setting(std, "std")
dt = cls.resolve_custom_setting(dt, "dt")
seed = cls.resolve_custom_setting(seed, "seed")
if seed is not None:
set_seed(seed)
func = jit_reg.resolve_option(nb.generate_gbm_data_1d_nb, jitted)
ticks = func(
np.sum(n_ticks),
start_value=start_value,
mean=mean,
std=std,
dt=dt,
)
func = jit_reg.resolve_option(ohlcv_nb.ohlc_every_1d_nb, jitted)
out = func(ticks, n_ticks)
return pd.DataFrame(out, index=index, columns=["Open", "High", "Low", "Close"])
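# Illustrative example (not part of the original source; assumes that `start`, `end`,
# and the remaining keyword arguments are forwarded to this method by `SyntheticData`):
#
# >>> data = GBMOHLCData.pull(
# ...     "GBM",
# ...     start="2024-01-01",
# ...     end="2024-02-01",
# ...     n_ticks=10,
# ...     seed=42,
# ... )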
def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
_ = fetch_kwargs.pop("start_value", None)
start_value = self.data[symbol]["Open"].iloc[-1]
fetch_kwargs["seed"] = None
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, start_value=start_value, **kwargs)
</file>
<file path="data/custom/gbm.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `GBMData`."""
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import to_1d_array
from vectorbtpro.data import nb
from vectorbtpro.data.custom.synthetic import SyntheticData
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.random_ import set_seed
__all__ = [
"GBMData",
]
__pdoc__ = {}
class GBMData(SyntheticData):
"""`SyntheticData` for data generated using `vectorbtpro.data.nb.generate_gbm_data_nb`."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.gbm")
@classmethod
def generate_key(
cls,
key: tp.Key,
index: tp.Index,
columns: tp.Union[tp.Hashable, tp.IndexLike] = None,
start_value: tp.Optional[float] = None,
mean: tp.Optional[float] = None,
std: tp.Optional[float] = None,
dt: tp.Optional[float] = None,
seed: tp.Optional[int] = None,
jitted: tp.JittedOption = None,
**kwargs,
) -> tp.KeyData:
"""Generate a feature or symbol.
Args:
key (hashable): Feature or symbol.
index (pd.Index): Pandas index.
columns (hashable or index_like): Column names.
Provide a single value (hashable) to make a Series.
start_value (float): Value at time 0.
Does not appear as the first value in the output data.
mean (float): Drift, or mean of the percentage change.
std (float): Standard deviation of the percentage change.
dt (float): Time change (one period of time).
seed (int): Seed to make output deterministic.
jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`.
For defaults, see `custom.gbm` in `vectorbtpro._settings.data`.
!!! note
When setting a seed, remember to pass a seed per feature/symbol using
`vectorbtpro.data.base.feature_dict`/`vectorbtpro.data.base.symbol_dict` or generally
`vectorbtpro.data.base.key_dict`.
"""
if checks.is_hashable(columns):
columns = [columns]
make_series = True
else:
make_series = False
if not isinstance(columns, pd.Index):
columns = pd.Index(columns)
start_value = cls.resolve_custom_setting(start_value, "start_value")
mean = cls.resolve_custom_setting(mean, "mean")
std = cls.resolve_custom_setting(std, "std")
dt = cls.resolve_custom_setting(dt, "dt")
seed = cls.resolve_custom_setting(seed, "seed")
if seed is not None:
set_seed(seed)
func = jit_reg.resolve_option(nb.generate_gbm_data_nb, jitted)
out = func(
(len(index), len(columns)),
start_value=to_1d_array(start_value),
mean=to_1d_array(mean),
std=to_1d_array(std),
dt=to_1d_array(dt),
)
if make_series:
return pd.Series(out[:, 0], index=index, name=columns[0])
return pd.DataFrame(out, index=index, columns=columns)
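# Illustrative example (not part of the original source; assumes that `start`, `end`,
# and the remaining keyword arguments are forwarded to this method by `SyntheticData`):
#
# >>> data = GBMData.pull(
# ...     "GBM",
# ...     start="2024-01-01",
# ...     end="2024-02-01",
# ...     columns=["a", "b"],
# ...     seed=42,
# ... )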
def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
fetch_kwargs = self.select_fetch_kwargs(key)
fetch_kwargs["start"] = self.select_last_index(key)
_ = fetch_kwargs.pop("start_value", None)
start_value = self.data[key].iloc[-2]
fetch_kwargs["seed"] = None
kwargs = merge_dicts(fetch_kwargs, kwargs)
if key_is_feature:
return self.fetch_feature(key, start_value=start_value, **kwargs)
return self.fetch_symbol(key, start_value=start_value, **kwargs)
</file>
<file path="data/custom/hdf.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `HDFData`."""
import re
from glob import glob
from pathlib import Path, PurePath
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.file import FileData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.parsing import get_func_arg_names
__all__ = [
"HDFData",
]
__pdoc__ = {}
class HDFPathNotFoundError(Exception):
"""Gets raised if the path to an HDF file could not be found."""
pass
class HDFKeyNotFoundError(Exception):
"""Gets raised if the key to an HDF object could not be found."""
pass
HDFDataT = tp.TypeVar("HDFDataT", bound="HDFData")
class HDFData(FileData):
"""Data class for fetching HDF data using PyTables."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.hdf")
@classmethod
def is_hdf_file(cls, path: tp.PathLike) -> bool:
"""Return whether the path is an HDF file."""
if not isinstance(path, Path):
path = Path(path)
if path.exists() and path.is_file() and ".hdf" in path.suffixes:
return True
if path.exists() and path.is_file() and ".hdf5" in path.suffixes:
return True
if path.exists() and path.is_file() and ".h5" in path.suffixes:
return True
return False
@classmethod
def is_file_match(cls, path: tp.PathLike) -> bool:
return cls.is_hdf_file(path)
@classmethod
def split_hdf_path(
cls,
path: tp.PathLike,
key: tp.Optional[str] = None,
_full_path: tp.Optional[Path] = None,
) -> tp.Tuple[Path, tp.Optional[str]]:
"""Split the path to an HDF object into the path to the file and the key."""
path = Path(path)
if _full_path is None:
_full_path = path
if path.exists():
if path.is_dir():
raise HDFPathNotFoundError(f"No HDF files could be matched with {_full_path}")
return path, key
new_path = path.parent
if key is None:
new_key = path.name
else:
new_key = str(Path(path.name) / key)
return cls.split_hdf_path(new_path, new_key, _full_path=_full_path)
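# Illustrative example (not part of the original source; assumes a file "data.h5" exists
# in the current directory and a POSIX path representation):
#
# >>> HDFData.split_hdf_path("data.h5/df1")
# (PosixPath('data.h5'), 'df1')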
@classmethod
def match_path(
cls,
path: tp.PathLike,
match_regex: tp.Optional[str] = None,
sort_paths: bool = True,
recursive: bool = True,
**kwargs,
) -> tp.List[Path]:
"""Override `FileData.match_path` to return a list of HDF paths
(path to file + key) matching a path."""
path = Path(path)
if path.exists():
if path.is_dir() and not cls.is_dir_match(path):
sub_paths = []
for p in path.iterdir():
if p.is_dir() and cls.is_dir_match(p):
sub_paths.append(p)
if p.is_file() and cls.is_file_match(p):
sub_paths.append(p)
key_paths = [p for sub_path in sub_paths for p in cls.match_path(sub_path, sort_paths=False, **kwargs)]
else:
with pd.HDFStore(str(path), mode="r") as store:
keys = [k[1:] for k in store.keys()]
key_paths = [path / k for k in keys]
else:
try:
file_path, key = cls.split_hdf_path(path)
with pd.HDFStore(str(file_path), mode="r") as store:
keys = [k[1:] for k in store.keys()]
if key is None:
key_paths = [file_path / k for k in keys]
elif key in keys:
key_paths = [file_path / key]
else:
matching_keys = []
for k in keys:
if k.startswith(key) or PurePath("/" + str(k)).match("/" + str(key)):
matching_keys.append(k)
if len(matching_keys) == 0:
raise HDFKeyNotFoundError(f"No HDF keys could be matched with {key}")
key_paths = [file_path / k for k in matching_keys]
except HDFPathNotFoundError:
sub_paths = list([Path(p) for p in glob(str(path), recursive=recursive)])
if len(sub_paths) == 0 and re.match(r".+\..+", str(path)):
base_path = None
base_ended = False
key_path = None
for part in path.parts:
part = Path(part)
if base_ended:
if key_path is None:
key_path = part
else:
key_path /= part
else:
if re.match(r".+\..+", str(part)):
base_ended = True
if base_path is None:
base_path = part
else:
base_path /= part
sub_paths = [Path(p) for p in glob(str(base_path), recursive=recursive)]
if key_path is not None:
sub_paths = [p / key_path for p in sub_paths]
key_paths = [p for sub_path in sub_paths for p in cls.match_path(sub_path, sort_paths=False, **kwargs)]
if match_regex is not None:
key_paths = [p for p in key_paths if re.match(match_regex, str(p))]
if sort_paths:
key_paths = sorted(key_paths)
return key_paths
@classmethod
def path_to_key(cls, path: tp.PathLike, **kwargs) -> str:
return Path(path).name
@classmethod
def resolve_keys_meta(
cls,
keys: tp.Union[None, dict, tp.MaybeKeys] = None,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[None, dict, tp.MaybeFeatures] = None,
symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
paths: tp.Any = None,
) -> tp.Kwargs:
keys_meta = FileData.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
)
if keys_meta["keys"] is None and paths is None:
keys_meta["keys"] = cls.list_paths()
return keys_meta
@classmethod
def fetch_key(
cls,
key: tp.Key,
path: tp.Any = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
tz: tp.TimezoneLike = None,
start_row: tp.Optional[int] = None,
end_row: tp.Optional[int] = None,
chunk_func: tp.Optional[tp.Callable] = None,
**read_kwargs,
) -> tp.KeyData:
"""Fetch the HDF object of a feature or symbol.
Args:
key (hashable): Feature or symbol.
path (str): Path.
Will be resolved with `HDFData.split_hdf_path`.
If `path` is None, uses `key` as the path to the HDF file.
start (any): Start datetime.
Will extract the object's index and compare the index to the date.
Will use the timezone of the object. See `vectorbtpro.utils.datetime_.to_timestamp`.
!!! note
Can only be used if the object was saved in the table format!
end (any): End datetime.
Will extract the object's index and compare the index to the date.
Will use the timezone of the object. See `vectorbtpro.utils.datetime_.to_timestamp`.
!!! note
Can only be used if the object was saved in the table format!
tz (any): Target timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
start_row (int): Start row (inclusive).
Also used when querying the index.
end_row (int): End row (exclusive).
Also used when querying the index.
chunk_func (callable): Function to select and concatenate chunks from `TableIterator`.
Gets called only if `iterator` or `chunksize` are set.
**read_kwargs: Other keyword arguments passed to `pd.read_hdf`.
See https://pandas.pydata.org/docs/reference/api/pandas.read_hdf.html for other arguments.
For defaults, see `custom.hdf` in `vectorbtpro._settings.data`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("tables")
from pandas.io.pytables import TableIterator
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
tz = cls.resolve_custom_setting(tz, "tz")
start_row = cls.resolve_custom_setting(start_row, "start_row")
if start_row is None:
start_row = 0
end_row = cls.resolve_custom_setting(end_row, "end_row")
read_kwargs = cls.resolve_custom_setting(read_kwargs, "read_kwargs", merge=True)
if path is None:
path = key
path = Path(path)
file_path, file_key = cls.split_hdf_path(path)
if file_key is not None:
key = file_key
if start is not None or end is not None:
hdf_store_arg_names = get_func_arg_names(pd.HDFStore.__init__)
hdf_store_kwargs = dict()
for k, v in read_kwargs.items():
if k in hdf_store_arg_names:
hdf_store_kwargs[k] = v
with pd.HDFStore(str(file_path), mode="r", **hdf_store_kwargs) as store:
index = store.select_column(key, "index", start=start_row, stop=end_row)
if not isinstance(index, pd.Index):
index = pd.Index(index)
if not isinstance(index, pd.DatetimeIndex):
raise TypeError("Cannot filter index that is not DatetimeIndex")
if tz is None:
tz = index.tz
if index.tz is not None:
if start is not None:
start = dt.to_tzaware_timestamp(start, naive_tz=tz, tz=index.tz)
if end is not None:
end = dt.to_tzaware_timestamp(end, naive_tz=tz, tz=index.tz)
else:
if start is not None:
start = dt.to_naive_timestamp(start, tz=tz)
if end is not None:
end = dt.to_naive_timestamp(end, tz=tz)
mask = True
if start is not None:
mask &= index >= start
if end is not None:
mask &= index < end
mask_indices = np.flatnonzero(mask)
if len(mask_indices) == 0:
return None
start_row += mask_indices[0]
end_row = start_row + mask_indices[-1] - mask_indices[0] + 1
obj = pd.read_hdf(file_path, key=key, start=start_row, stop=end_row, **read_kwargs)
if isinstance(obj, TableIterator):
if chunk_func is None:
obj = pd.concat(list(obj), axis=0)
else:
obj = chunk_func(obj)
if isinstance(obj.index, pd.DatetimeIndex) and tz is None:
tz = obj.index.tz
return obj, dict(last_row=start_row + len(obj.index) - 1, tz=tz)
@classmethod
def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Fetch the HDF object of a feature.
Uses `HDFData.fetch_key`."""
return cls.fetch_key(feature, **kwargs)
@classmethod
def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Load the HDF object for a symbol.
Uses `HDFData.fetch_key`."""
return cls.fetch_key(symbol, **kwargs)
def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
"""Update data of a feature or symbol."""
fetch_kwargs = self.select_fetch_kwargs(key)
returned_kwargs = self.select_returned_kwargs(key)
fetch_kwargs["start_row"] = returned_kwargs["last_row"]
kwargs = merge_dicts(fetch_kwargs, kwargs)
if key_is_feature:
return self.fetch_feature(key, **kwargs)
return self.fetch_symbol(key, **kwargs)
def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Update data of a feature.
Uses `HDFData.update_key` with `key_is_feature=True`."""
return self.update_key(feature, key_is_feature=True, **kwargs)
def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Update data for a symbol.
Uses `HDFData.update_key` with `key_is_feature=False`."""
return self.update_key(symbol, key_is_feature=False, **kwargs)
</file>
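For reference, a minimal usage sketch for `HDFData` in the style of the other data classes. The file name `data.h5` and the key `AAPL` are hypothetical; date-based slicing via `start`/`end` only works if the object was saved in the table format (see the docstring above):

```pycon
>>> from vectorbtpro import *

>>> # Pull one HDF object; the key doubles as "path/to/file + key"
>>> data = vbt.HDFData.pull("data.h5/AAPL")

>>> # Slice by rows (any format) or by dates (table format only)
>>> data = vbt.HDFData.pull("data.h5/AAPL", start_row=100, end_row=200)
>>> data = vbt.HDFData.pull("data.h5/AAPL", start="2021-01-01", end="2022-01-01")
```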
<file path="data/custom/local.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `LocalData`."""
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.custom import CustomData
__all__ = [
"LocalData",
]
__pdoc__ = {}
class LocalData(CustomData):
"""Data class for fetching local data."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.local")
</file>
<file path="data/custom/ndl.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `NDLData`."""
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"NDLData",
]
__pdoc__ = {}
NDLDataT = tp.TypeVar("NDLDataT", bound="NDLData")
class NDLData(RemoteData):
"""Data class for fetching from Nasdaq Data Link.
See https://github.com/Nasdaq/data-link-python for API.
See `NDLData.fetch_symbol` for arguments.
Usage:
* Set up the API key globally (optional):
```pycon
>>> from vectorbtpro import *
>>> vbt.NDLData.set_custom_settings(
... api_key="YOUR_KEY"
... )
```
* Pull a dataset:
```pycon
>>> data = vbt.NDLData.pull(
... "FRED/GDP",
... start="2001-12-31",
... end="2005-12-31"
... )
```
* Pull a datatable:
```pycon
>>> data = vbt.NDLData.pull(
... "MER/F1",
... data_format="datatable",
... compnumber="39102",
... paginate=True
... )
```
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.ndl")
@classmethod
def fetch_symbol(
cls,
symbol: str,
api_key: tp.Optional[str] = None,
data_format: tp.Optional[str] = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
tz: tp.TimezoneLike = None,
column_indices: tp.Optional[tp.MaybeIterable[int]] = None,
**params,
) -> tp.SymbolData:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Nasdaq Data Link.
Args:
symbol (str): Symbol.
api_key (str): API key.
data_format (str): Data format.
Supported are "dataset" and "datatable".
start (any): Retrieve data rows on and after the specified start date.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
end (any): Retrieve data rows up to and including the specified end date.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
column_indices (int or iterable): Request one or more specific columns.
Column 0 is the date column and is always returned. Data begins at column 1.
**params: Keyword arguments sent as field/value params to Nasdaq Data Link, passed through as-is.
For defaults, see `custom.ndl` in `vectorbtpro._settings.data`.
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("nasdaqdatalink")
import nasdaqdatalink
api_key = cls.resolve_custom_setting(api_key, "api_key")
data_format = cls.resolve_custom_setting(data_format, "data_format")
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
tz = cls.resolve_custom_setting(tz, "tz")
column_indices = cls.resolve_custom_setting(column_indices, "column_indices")
if column_indices is not None:
if isinstance(column_indices, int):
dataset = symbol + "." + str(column_indices)
else:
dataset = [symbol + "." + str(index) for index in column_indices]
else:
dataset = symbol
params = cls.resolve_custom_setting(params, "params", merge=True)
# Establish the timestamps
if start is not None:
start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")
start_date = pd.Timestamp(start).isoformat()
if "start_date" not in params:
params["start_date"] = start_date
else:
start_date = None
if end is not None:
end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc")
end_date = pd.Timestamp(end).isoformat()
if "end_date" not in params:
params["end_date"] = end_date
else:
end_date = None
# Collect and format the data
if data_format.lower() == "dataset":
df = nasdaqdatalink.get(
dataset,
api_key=api_key,
**params,
)
else:
df = nasdaqdatalink.get_table(
dataset,
api_key=api_key,
**params,
)
new_columns = []
for c in df.columns:
new_c = c
if isinstance(symbol, str):
new_c = new_c.replace(symbol + " - ", "")
if new_c == "Last":
new_c = "Close"
new_columns.append(new_c)
df = df.rename(columns=dict(zip(df.columns, new_columns)))
if df.index.name == "None":
df.index.name = None
if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None:
df = df.tz_localize("utc")
if isinstance(df.index, pd.DatetimeIndex) and not df.empty:
if start is not None:
start = dt.to_timestamp(start, tz=df.index.tz)
if df.index[0] < start:
df = df[df.index >= start]
if end is not None:
end = dt.to_timestamp(end, tz=df.index.tz)
if df.index[-1] >= end:
df = df[df.index < end]
return df, dict(tz=tz)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
</file>
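As a small addition to the usage examples above: `column_indices` narrows the request to specific dataset columns (column 0, the date column, is always returned). The column index used here is only illustrative:

```pycon
>>> data = vbt.NDLData.pull(
...     "FRED/GDP",
...     start="2001-12-31",
...     end="2005-12-31",
...     column_indices=1
... )
```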
<file path="data/custom/parquet.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `ParquetData`."""
import re
from pathlib import Path
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.file import FileData
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"ParquetData",
]
__pdoc__ = {}
ParquetDataT = tp.TypeVar("ParquetDataT", bound="ParquetData")
class ParquetData(FileData):
"""Data class for fetching Parquet data using PyArrow or FastParquet."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.parquet")
@classmethod
def is_parquet_file(cls, path: tp.PathLike) -> bool:
"""Return whether the path is a Parquet file."""
if not isinstance(path, Path):
path = Path(path)
if path.exists() and path.is_file() and ".parquet" in path.suffixes:
return True
return False
@classmethod
def is_parquet_group_dir(cls, path: tp.PathLike) -> bool:
"""Return whether the path is a directory that is a group of Parquet partitions.
!!! note
Assumes the Hive partitioning scheme."""
if not isinstance(path, Path):
path = Path(path)
if path.exists() and path.is_dir():
partition_regex = r"^(.+)=(.+)"
if re.match(partition_regex, path.name):
for p in path.iterdir():
if cls.is_parquet_group_dir(p) or cls.is_parquet_file(p):
return True
return False
@classmethod
def is_parquet_dir(cls, path: tp.PathLike) -> bool:
"""Return whether the path is a directory that is a group itself or
contains groups of Parquet partitions."""
if cls.is_parquet_group_dir(path):
return True
if not isinstance(path, Path):
path = Path(path)
if path.exists() and path.is_dir():
for p in path.iterdir():
if cls.is_parquet_group_dir(p):
return True
return False
@classmethod
def is_dir_match(cls, path: tp.PathLike) -> bool:
return cls.is_parquet_dir(path)
@classmethod
def is_file_match(cls, path: tp.PathLike) -> bool:
return cls.is_parquet_file(path)
@classmethod
def list_partition_cols(cls, path: tp.PathLike) -> tp.List[str]:
"""List partitioning columns under a path.
!!! note
Assumes the Hive partitioning scheme."""
if not isinstance(path, Path):
path = Path(path)
partition_cols = []
found_last_level = False
while not found_last_level:
found_new_level = False
for p in path.iterdir():
if cls.is_parquet_group_dir(p):
partition_cols.append(p.name.split("=")[0])
path = p
found_new_level = True
break
if not found_new_level:
found_last_level = True
return partition_cols
@classmethod
def is_default_partition_col(cls, level: str) -> bool:
"""Return whether a partitioning column is a default partitioning column."""
return re.match(r"^(\bgroup\b)|(group_\d+)", level) is not None
@classmethod
def resolve_keys_meta(
cls,
keys: tp.Union[None, dict, tp.MaybeKeys] = None,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[None, dict, tp.MaybeFeatures] = None,
symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
paths: tp.Any = None,
) -> tp.Kwargs:
keys_meta = FileData.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
)
if keys_meta["keys"] is None and paths is None:
keys_meta["keys"] = cls.list_paths()
return keys_meta
@classmethod
def fetch_key(
cls,
key: tp.Key,
path: tp.Any = None,
tz: tp.TimezoneLike = None,
squeeze: tp.Optional[bool] = None,
keep_partition_cols: tp.Optional[bool] = None,
engine: tp.Optional[str] = None,
**read_kwargs,
) -> tp.KeyData:
"""Fetch the Parquet file of a feature or symbol.
Args:
key (hashable): Feature or symbol.
path (str): Path.
If `path` is None, uses `key` as the path to the Parquet file.
tz (any): Target timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
squeeze (bool): Whether to squeeze a DataFrame with one column into a Series.
keep_partition_cols (bool): Whether to return partitioning columns (if any).
If None, will remove any partitioning column that is "group" or "group_{index}".
Retrieves the list of partitioning columns with `ParquetData.list_partition_cols`.
engine (str): See `pd.read_parquet`.
**read_kwargs: Other keyword arguments passed to `pd.read_parquet`.
See https://pandas.pydata.org/docs/reference/api/pandas.read_parquet.html for other arguments.
For defaults, see `custom.parquet` in `vectorbtpro._settings.data`."""
from vectorbtpro.utils.module_ import assert_can_import, assert_can_import_any
tz = cls.resolve_custom_setting(tz, "tz")
squeeze = cls.resolve_custom_setting(squeeze, "squeeze")
keep_partition_cols = cls.resolve_custom_setting(keep_partition_cols, "keep_partition_cols")
engine = cls.resolve_custom_setting(engine, "engine")
read_kwargs = cls.resolve_custom_setting(read_kwargs, "read_kwargs", merge=True)
if engine == "pyarrow":
assert_can_import("pyarrow")
elif engine == "fastparquet":
assert_can_import("fastparquet")
elif engine == "auto":
assert_can_import_any("pyarrow", "fastparquet")
else:
raise ValueError(f"Invalid engine: '{engine}'")
if path is None:
path = key
obj = pd.read_parquet(path, engine=engine, **read_kwargs)
if keep_partition_cols in (None, False):
if cls.is_parquet_dir(path):
drop_columns = []
partition_cols = cls.list_partition_cols(path)
for col in obj.columns:
if col in partition_cols:
if keep_partition_cols is False or cls.is_default_partition_col(col):
drop_columns.append(col)
obj = obj.drop(drop_columns, axis=1)
if isinstance(obj.index, pd.DatetimeIndex) and tz is None:
tz = obj.index.tz
if isinstance(obj, pd.DataFrame) and squeeze:
obj = obj.squeeze("columns")
if isinstance(obj, pd.Series) and obj.name == "0":
obj.name = None
return obj, dict(tz=tz)
@classmethod
def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Fetch the Parquet file of a feature.
Uses `ParquetData.fetch_key`."""
return cls.fetch_key(feature, **kwargs)
@classmethod
def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Fetch the Parquet file of a symbol.
Uses `ParquetData.fetch_key`."""
return cls.fetch_key(symbol, **kwargs)
def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
"""Update data of a feature or symbol."""
fetch_kwargs = self.select_fetch_kwargs(key)
kwargs = merge_dicts(fetch_kwargs, kwargs)
if key_is_feature:
return self.fetch_feature(key, **kwargs)
return self.fetch_symbol(key, **kwargs)
def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Update data of a feature.
Uses `ParquetData.update_key` with `key_is_feature=True`."""
return self.update_key(feature, key_is_feature=True, **kwargs)
def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Update data for a symbol.
Uses `ParquetData.update_key` with `key_is_feature=False`."""
return self.update_key(symbol, key_is_feature=False, **kwargs)
</file>
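For reference, a minimal usage sketch for `ParquetData`; the paths are hypothetical. A single file and a Hive-partitioned directory are both pulled the same way, with the key doubling as the path:

```pycon
>>> from vectorbtpro import *

>>> # Pull a single Parquet file
>>> data = vbt.ParquetData.pull("AAPL.parquet")

>>> # Pull a partitioned directory and drop all partitioning columns
>>> data = vbt.ParquetData.pull("AAPL", keep_partition_cols=False)
```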
<file path="data/custom/polygon.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `PolygonData`."""
import time
import traceback
from functools import wraps, partial
import pandas as pd
import requests
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.pbar import ProgressBar
from vectorbtpro.utils.warnings_ import warn
try:
if not tp.TYPE_CHECKING:
raise ImportError
from polygon import RESTClient as PolygonClientT
except ImportError:
PolygonClientT = "PolygonClient"
__all__ = [
"PolygonData",
]
PolygonDataT = tp.TypeVar("PolygonDataT", bound="PolygonData")
class PolygonData(RemoteData):
"""Data class for fetching from Polygon.
See https://github.com/polygon-io/client-python for API.
See `PolygonData.fetch_symbol` for arguments.
Usage:
* Set up the API key globally:
```pycon
>>> from vectorbtpro import *
>>> vbt.PolygonData.set_custom_settings(
... client_config=dict(
... api_key="YOUR_KEY"
... )
... )
```
* Pull stock data:
```pycon
>>> data = vbt.PolygonData.pull(
... "AAPL",
... start="2021-01-01",
... end="2022-01-01",
... timeframe="1 day"
... )
```
* Pull crypto data:
```pycon
>>> data = vbt.PolygonData.pull(
... "X:BTCUSD",
... start="2021-01-01",
... end="2022-01-01",
... timeframe="1 day"
... )
```
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.polygon")
@classmethod
def list_symbols(
cls,
pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
client: tp.Optional[PolygonClientT] = None,
client_config: tp.DictLike = None,
**list_tickers_kwargs,
) -> tp.List[str]:
"""List all symbols.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`.
For supported keyword arguments, see `polygon.RESTClient.list_tickers`."""
if client_config is None:
client_config = {}
client = cls.resolve_client(client=client, **client_config)
all_symbols = []
for ticker in client.list_tickers(**list_tickers_kwargs):
symbol = ticker.ticker
if pattern is not None:
if not cls.key_match(symbol, pattern, use_regex=use_regex):
continue
all_symbols.append(symbol)
if sort:
return sorted(dict.fromkeys(all_symbols))
return list(dict.fromkeys(all_symbols))
@classmethod
def resolve_client(cls, client: tp.Optional[PolygonClientT] = None, **client_config) -> PolygonClientT:
"""Resolve the client.
If provided, must be of the type `polygon.rest.RESTClient`.
Otherwise, will be created using `client_config`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("polygon")
from polygon import RESTClient
client = cls.resolve_custom_setting(client, "client")
if client_config is None:
client_config = {}
has_client_config = len(client_config) > 0
client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True)
if client is None:
client = RESTClient(**client_config)
elif has_client_config:
raise ValueError("Cannot apply client_config to already initialized client")
return client
@classmethod
def fetch_symbol(
cls,
symbol: str,
client: tp.Optional[PolygonClientT] = None,
client_config: tp.DictLike = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
adjusted: tp.Optional[bool] = None,
limit: tp.Optional[int] = None,
params: tp.KwargsLike = None,
delay: tp.Optional[float] = None,
retries: tp.Optional[int] = None,
show_progress: tp.Optional[bool] = None,
pbar_kwargs: tp.KwargsLike = None,
silence_warnings: tp.Optional[bool] = None,
) -> tp.SymbolData:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Polygon.
Args:
symbol (str): Symbol.
Supports the following APIs:
* Stocks and equities
* Currencies - symbol must have the prefix `C:`
* Crypto - symbol must have the prefix `X:`
client (polygon.rest.RESTClient): Client.
See `PolygonData.resolve_client`.
client_config (dict): Client config.
See `PolygonData.resolve_client`.
start (any): The start of the aggregate time window.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
end (any): The end of the aggregate time window.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
timeframe (str): Timeframe.
Allows human-readable strings such as "15 minutes".
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
adjusted (bool): Whether the results are adjusted for splits.
By default, results are adjusted.
Set this to False to get results that are NOT adjusted for splits.
limit (int): Limits the number of base aggregates queried to create the aggregate results.
The maximum is 50000 and the default is 5000.
params (dict): Any additional query params.
delay (float): Time to sleep after each request (in seconds).
retries (int): The number of retries on failure to fetch data.
show_progress (bool): Whether to show the progress bar.
pbar_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.pbar.ProgressBar`.
silence_warnings (bool): Whether to silence all warnings.
For defaults, see `custom.polygon` in `vectorbtpro._settings.data`.
!!! note
If you're using a free plan that has an API rate limit of several requests per minute,
make sure to set `delay` to a higher number, such as 12 (which makes 5 requests per minute).
"""
if client_config is None:
client_config = {}
client = cls.resolve_client(client=client, **client_config)
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
adjusted = cls.resolve_custom_setting(adjusted, "adjusted")
limit = cls.resolve_custom_setting(limit, "limit")
params = cls.resolve_custom_setting(params, "params", merge=True)
delay = cls.resolve_custom_setting(delay, "delay")
retries = cls.resolve_custom_setting(retries, "retries")
show_progress = cls.resolve_custom_setting(show_progress, "show_progress")
pbar_kwargs = cls.resolve_custom_setting(pbar_kwargs, "pbar_kwargs", merge=True)
if "bar_id" not in pbar_kwargs:
pbar_kwargs["bar_id"] = "polygon"
silence_warnings = cls.resolve_custom_setting(silence_warnings, "silence_warnings")
# Resolve the timeframe
if not isinstance(timeframe, str):
raise ValueError(f"Invalid timeframe: '{timeframe}'")
split = dt.split_freq_str(timeframe)
if split is None:
raise ValueError(f"Invalid timeframe: '{timeframe}'")
multiplier, unit = split
if unit == "m":
unit = "minute"
elif unit == "h":
unit = "hour"
elif unit == "D":
unit = "day"
elif unit == "W":
unit = "week"
elif unit == "M":
unit = "month"
elif unit == "Q":
unit = "quarter"
elif unit == "Y":
unit = "year"
# Establish the timestamps
if start is not None:
start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc"))
else:
start_ts = None
if end is not None:
end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc"))
else:
end_ts = None
prev_end_ts = None
def _retry(method):
@wraps(method)
def retry_method(*args, **kwargs):
for i in range(retries):
try:
return method(*args, **kwargs)
except requests.exceptions.HTTPError as e:
if isinstance(e, requests.exceptions.HTTPError) and e.response.status_code == 429:
if not silence_warnings:
warn(traceback.format_exc())
# Polygon.io API rate limit is per minute
warn("Waiting 1 minute...")
time.sleep(60)
else:
raise e
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
if i == retries - 1:
raise e
if not silence_warnings:
warn(traceback.format_exc())
if delay is not None:
time.sleep(delay)
return retry_method
def _postprocess(agg):
return dict(
o=agg.open,
h=agg.high,
l=agg.low,
c=agg.close,
v=agg.volume,
vw=agg.vwap,
t=agg.timestamp,
n=agg.transactions,
)
@_retry
def _fetch(_start_ts, _limit):
return list(
map(
_postprocess,
client.get_aggs(
ticker=symbol,
multiplier=multiplier,
timespan=unit,
from_=_start_ts,
to=end_ts,
adjusted=adjusted,
sort="asc",
limit=_limit,
params=params,
raw=False,
),
)
)
def _ts_to_str(ts: tp.Optional[int]) -> str:
if ts is None:
return "?"
return dt.readable_datetime(pd.Timestamp(ts, unit="ms", tz="utc"), freq=timeframe)
def _filter_func(d: tp.Dict, _prev_end_ts: tp.Optional[int] = None) -> bool:
if start_ts is not None:
if d["t"] < start_ts:
return False
if _prev_end_ts is not None:
if d["t"] <= _prev_end_ts:
return False
if end_ts is not None:
if d["t"] >= end_ts:
return False
return True
# Iteratively collect the data
data = []
try:
with ProgressBar(show_progress=show_progress, **pbar_kwargs) as pbar:
pbar.set_description("{} → ?".format(_ts_to_str(start_ts if prev_end_ts is None else prev_end_ts)))
while True:
# Fetch the klines for the next timeframe
next_data = _fetch(start_ts if prev_end_ts is None else prev_end_ts, limit)
next_data = list(filter(partial(_filter_func, _prev_end_ts=prev_end_ts), next_data))
# Update the timestamps and the progress bar
if not len(next_data):
break
data += next_data
if start_ts is None:
start_ts = next_data[0]["t"]
pbar.set_description("{} → {}".format(_ts_to_str(start_ts), _ts_to_str(next_data[-1]["t"])))
pbar.update()
prev_end_ts = next_data[-1]["t"]
if end_ts is not None and prev_end_ts >= end_ts:
break
if delay is not None:
time.sleep(delay) # be kind to api
except Exception as e:
if not silence_warnings:
warn(traceback.format_exc())
warn(
f"Symbol '{str(symbol)}' raised an exception. Returning incomplete data. "
"Use update() method to fetch missing data."
)
df = pd.DataFrame(data)
df = df[["t", "o", "h", "l", "c", "v", "n", "vw"]]
df = df.rename(
columns={
"t": "Open time",
"o": "Open",
"h": "High",
"l": "Low",
"c": "Close",
"v": "Volume",
"n": "Trade count",
"vw": "VWAP",
}
)
df.index = pd.to_datetime(df["Open time"], unit="ms", utc=True)
del df["Open time"]
if "Open" in df.columns:
df["Open"] = df["Open"].astype(float)
if "High" in df.columns:
df["High"] = df["High"].astype(float)
if "Low" in df.columns:
df["Low"] = df["Low"].astype(float)
if "Close" in df.columns:
df["Close"] = df["Close"].astype(float)
if "Volume" in df.columns:
df["Volume"] = df["Volume"].astype(float)
if "Trade count" in df.columns:
df["Trade count"] = df["Trade count"].astype(int, errors="ignore")
if "VWAP" in df.columns:
df["VWAP"] = df["VWAP"].astype(float)
return df, dict(tz=tz, freq=timeframe)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
</file>
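Building on the rate-limit note in `PolygonData.fetch_symbol`, a hedged sketch of pulling intraday data on a free plan; the `delay` value follows the docstring's suggestion of roughly 5 requests per minute:

```pycon
>>> data = vbt.PolygonData.pull(
...     "X:BTCUSD",
...     start="2021-01-01",
...     end="2021-02-01",
...     timeframe="15 minutes",
...     delay=12  # ~5 requests per minute
... )
```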
<file path="data/custom/random_ohlc.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `RandomOHLCData`."""
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import broadcast_array_to
from vectorbtpro.data import nb
from vectorbtpro.data.custom.synthetic import SyntheticData
from vectorbtpro.ohlcv import nb as ohlcv_nb
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.random_ import set_seed
from vectorbtpro.utils.template import substitute_templates
__all__ = [
"RandomOHLCData",
]
__pdoc__ = {}
class RandomOHLCData(SyntheticData):
"""`SyntheticData` for data generated using `vectorbtpro.data.nb.generate_random_data_1d_nb`
and then resampled using `vectorbtpro.ohlcv.nb.ohlc_every_1d_nb`."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.random_ohlc")
@classmethod
def generate_symbol(
cls,
symbol: tp.Symbol,
index: tp.Index,
n_ticks: tp.Optional[tp.ArrayLike] = None,
start_value: tp.Optional[float] = None,
mean: tp.Optional[float] = None,
std: tp.Optional[float] = None,
symmetric: tp.Optional[bool] = None,
seed: tp.Optional[int] = None,
jitted: tp.JittedOption = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.SymbolData:
"""Generate a symbol.
Args:
symbol (hashable): Symbol.
index (pd.Index): Pandas index.
n_ticks (int or array_like): Number of ticks per bar.
Flexible argument. Can be a template with a context containing `symbol` and `index`.
start_value (float): Value at time 0.
Does not appear as the first value in the output data.
mean (float): Drift, or mean of the percentage change.
std (float): Standard deviation of the percentage change.
symmetric (bool): Whether to diminish negative returns and make them symmetric to positive ones.
seed (int): Seed to make output deterministic.
jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`.
template_context (dict): Template context.
For defaults, see `custom.random_ohlc` in `vectorbtpro._settings.data`.
!!! note
When setting a seed, remember to pass a seed per symbol using `vectorbtpro.data.base.symbol_dict`.
"""
n_ticks = cls.resolve_custom_setting(n_ticks, "n_ticks")
template_context = merge_dicts(dict(symbol=symbol, index=index), template_context)
n_ticks = substitute_templates(n_ticks, template_context, eval_id="n_ticks")
n_ticks = broadcast_array_to(n_ticks, len(index))
start_value = cls.resolve_custom_setting(start_value, "start_value")
mean = cls.resolve_custom_setting(mean, "mean")
std = cls.resolve_custom_setting(std, "std")
symmetric = cls.resolve_custom_setting(symmetric, "symmetric")
seed = cls.resolve_custom_setting(seed, "seed")
if seed is not None:
set_seed(seed)
func = jit_reg.resolve_option(nb.generate_random_data_1d_nb, jitted)
ticks = func(np.sum(n_ticks), start_value=start_value, mean=mean, std=std, symmetric=symmetric)
func = jit_reg.resolve_option(ohlcv_nb.ohlc_every_1d_nb, jitted)
out = func(ticks, n_ticks)
return pd.DataFrame(out, index=index, columns=["Open", "High", "Low", "Close"])
def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
_ = fetch_kwargs.pop("start_value", None)
start_value = self.data[symbol]["Open"].iloc[-1]
fetch_kwargs["seed"] = None
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, start_value=start_value, **kwargs)
</file>
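A hedged sketch of generating synthetic OHLC data. The symbol name is arbitrary, and the index-building arguments (`start`, `end`) are assumed to be handled by `SyntheticData`, which is not shown in this section; `n_ticks` and `seed` are the arguments documented above:

```pycon
>>> data = vbt.RandomOHLCData.pull(
...     "RAND",
...     start="2021-01-01",
...     end="2021-07-01",
...     n_ticks=10,
...     seed=42
... )
```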
<file path="data/custom/random.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `RandomData`."""
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import to_1d_array
from vectorbtpro.data import nb
from vectorbtpro.data.custom.synthetic import SyntheticData
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.random_ import set_seed
__all__ = [
"RandomData",
]
__pdoc__ = {}
class RandomData(SyntheticData):
"""`SyntheticData` for data generated using `vectorbtpro.data.nb.generate_random_data_nb`."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.random")
@classmethod
def generate_key(
cls,
key: tp.Key,
index: tp.Index,
columns: tp.Union[tp.Hashable, tp.IndexLike] = None,
start_value: tp.Optional[float] = None,
mean: tp.Optional[float] = None,
std: tp.Optional[float] = None,
symmetric: tp.Optional[bool] = None,
seed: tp.Optional[int] = None,
jitted: tp.JittedOption = None,
**kwargs,
) -> tp.KeyData:
"""Generate a feature or symbol.
Args:
key (hashable): Feature or symbol.
index (pd.Index): Pandas index.
columns (hashable or index_like): Column names.
Provide a single value (hashable) to make a Series.
start_value (float): Value at time 0.
Does not appear as the first value in the output data.
mean (float): Drift, or mean of the percentage change.
std (float): Standard deviation of the percentage change.
symmetric (bool): Whether to diminish negative returns and make them symmetric to positive ones.
seed (int): Seed to make output deterministic.
jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`.
For defaults, see `custom.random` in `vectorbtpro._settings.data`.
!!! note
When setting a seed, remember to pass a seed per feature/symbol using
`vectorbtpro.data.base.feature_dict`/`vectorbtpro.data.base.symbol_dict` or generally
`vectorbtpro.data.base.key_dict`.
"""
if checks.is_hashable(columns):
columns = [columns]
make_series = True
else:
make_series = False
if not isinstance(columns, pd.Index):
columns = pd.Index(columns)
start_value = cls.resolve_custom_setting(start_value, "start_value")
mean = cls.resolve_custom_setting(mean, "mean")
std = cls.resolve_custom_setting(std, "std")
symmetric = cls.resolve_custom_setting(symmetric, "symmetric")
seed = cls.resolve_custom_setting(seed, "seed")
if seed is not None:
set_seed(seed)
func = jit_reg.resolve_option(nb.generate_random_data_nb, jitted)
out = func(
(len(index), len(columns)),
start_value=to_1d_array(start_value),
mean=to_1d_array(mean),
std=to_1d_array(std),
symmetric=to_1d_array(symmetric),
)
if make_series:
return pd.Series(out[:, 0], index=index, name=columns[0])
return pd.DataFrame(out, index=index, columns=columns)
def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
fetch_kwargs = self.select_fetch_kwargs(key)
fetch_kwargs["start"] = self.select_last_index(key)
_ = fetch_kwargs.pop("start_value", None)
start_value = self.data[key].iloc[-2]
fetch_kwargs["seed"] = None
kwargs = merge_dicts(fetch_kwargs, kwargs)
if key_is_feature:
return self.fetch_feature(key, start_value=start_value, **kwargs)
return self.fetch_symbol(key, start_value=start_value, **kwargs)
</file>
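Similarly for `RandomData`, a hedged sketch under the same assumption about `SyntheticData` building the index from `start` and `end`. Per the note above, a per-symbol seed can be passed via `symbol_dict` (assumed here to be exposed as `vbt.symbol_dict`):

```pycon
>>> data = vbt.RandomData.pull(
...     ["RAND1", "RAND2"],
...     start="2021-01-01",
...     end="2021-07-01",
...     seed=vbt.symbol_dict(RAND1=42, RAND2=43)
... )
```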
<file path="data/custom/remote.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `RemoteData`."""
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.custom import CustomData
__all__ = [
"RemoteData",
]
__pdoc__ = {}
class RemoteData(CustomData):
"""Data class for fetching remote data."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.remote")
</file>
<file path="data/custom/sql.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `SQLData`."""
from typing import Iterator
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.db import DBData
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
try:
if not tp.TYPE_CHECKING:
raise ImportError
from sqlalchemy import Engine as EngineT, Selectable as SelectableT, Table as TableT
except ImportError:
EngineT = "Engine"
SelectableT = "Selectable"
TableT = "Table"
__all__ = [
"SQLData",
]
__pdoc__ = {}
SQLDataT = tp.TypeVar("SQLDataT", bound="SQLData")
class SQLData(DBData):
"""Data class for fetching data from a database using SQLAlchemy.
See https://www.sqlalchemy.org/ for the SQLAlchemy's API.
See https://pandas.pydata.org/docs/reference/api/pandas.read_sql_query.html for the read method.
See `SQLData.pull` and `SQLData.fetch_key` for arguments.
Usage:
* Set up the engine settings globally (optional):
```pycon
>>> from vectorbtpro import *
>>> vbt.SQLData.set_engine_settings(
... engine_name="postgresql",
... populate_=True,
... engine="postgresql+psycopg2://...",
... engine_config=dict(),
... schema="public"
... )
```
* Pull tables:
```pycon
>>> data = vbt.SQLData.pull(
... ["TABLE1", "TABLE2"],
... engine="postgresql",
... start="2020-01-01",
... end="2021-01-01"
... )
```
* Pull queries:
```pycon
>>> data = vbt.SQLData.pull(
... ["SYMBOL1", "SYMBOL2"],
... query=vbt.key_dict({
... "SYMBOL1": "SELECT * FROM TABLE1",
... "SYMBOL2": "SELECT * FROM TABLE2"
... }),
... engine="postgresql"
... )
```
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.sql")
@classmethod
def get_engine_settings(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> dict:
"""`SQLData.get_custom_settings` with `sub_path=engine_name`."""
if engine_name is not None:
sub_path = "engines." + engine_name
else:
sub_path = None
return cls.get_custom_settings(*args, sub_path=sub_path, **kwargs)
@classmethod
def has_engine_settings(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> bool:
"""`SQLData.has_custom_settings` with `sub_path=engine_name`."""
if engine_name is not None:
sub_path = "engines." + engine_name
else:
sub_path = None
return cls.has_custom_settings(*args, sub_path=sub_path, **kwargs)
@classmethod
def get_engine_setting(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> tp.Any:
"""`SQLData.get_custom_setting` with `sub_path=engine_name`."""
if engine_name is not None:
sub_path = "engines." + engine_name
else:
sub_path = None
return cls.get_custom_setting(*args, sub_path=sub_path, **kwargs)
@classmethod
def has_engine_setting(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> bool:
"""`SQLData.has_custom_setting` with `sub_path=engine_name`."""
if engine_name is not None:
sub_path = "engines." + engine_name
else:
sub_path = None
return cls.has_custom_setting(*args, sub_path=sub_path, **kwargs)
@classmethod
def resolve_engine_setting(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> tp.Any:
"""`SQLData.resolve_custom_setting` with `sub_path=engine_name`."""
if engine_name is not None:
sub_path = "engines." + engine_name
else:
sub_path = None
return cls.resolve_custom_setting(*args, sub_path=sub_path, **kwargs)
@classmethod
def set_engine_settings(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> None:
"""`SQLData.set_custom_settings` with `sub_path=engine_name`."""
if engine_name is not None:
sub_path = "engines." + engine_name
else:
sub_path = None
cls.set_custom_settings(*args, sub_path=sub_path, **kwargs)
@classmethod
def resolve_engine(
cls,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
return_meta: bool = False,
**engine_config,
) -> tp.Union[EngineT, dict]:
"""Resolve the engine.
Argument `engine` can be
1) an object of the type `sqlalchemy.engine.base.Engine`,
2) a URL of the engine as a string, which will be used to create an engine with
`sqlalchemy.engine.create.create_engine` and `engine_config` passed as keyword arguments
(you should not include `url` in the `engine_config`), or
3) an engine name, which is the name of a sub-config with engine settings under `custom.sql.engines`
in `vectorbtpro._settings.data`. Such a sub-config can then contain the actual engine as an object or a URL.
Argument `engine_name` can be provided instead of `engine`, or also together with `engine`
to pull other settings from a sub-config. URLs can also be used as engine names, but not the
other way around."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("sqlalchemy")
from sqlalchemy import create_engine
if engine is None and engine_name is None:
engine_name = cls.resolve_engine_setting(engine_name, "engine_name")
if engine_name is not None:
engine = cls.resolve_engine_setting(engine, "engine", engine_name=engine_name)
if engine is None:
raise ValueError("Must provide engine or URL (via engine argument)")
else:
engine = cls.resolve_engine_setting(engine, "engine")
if engine is None:
raise ValueError("Must provide engine or URL (via engine argument)")
if isinstance(engine, str):
engine_name = engine
else:
engine_name = None
if engine_name is not None:
if cls.has_engine_setting("engine", engine_name=engine_name, sub_path_only=True):
engine = cls.get_engine_setting("engine", engine_name=engine_name, sub_path_only=True)
has_engine_config = len(engine_config) > 0
engine_config = cls.resolve_engine_setting(engine_config, "engine_config", merge=True, engine_name=engine_name)
if isinstance(engine, str):
if engine.startswith("duckdb:"):
assert_can_import("duckdb_engine")
engine = create_engine(engine, **engine_config)
should_dispose = True
else:
if has_engine_config:
raise ValueError("Cannot apply engine_config to initialized created engine")
should_dispose = False
if return_meta:
return dict(
engine=engine,
engine_name=engine_name,
should_dispose=should_dispose,
)
return engine
@classmethod
def list_schemas(
cls,
pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
dispose_engine: tp.Optional[bool] = None,
**kwargs,
) -> tp.List[str]:
"""List all schemas.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`.
Keyword arguments `**kwargs` are passed to `inspector.get_schema_names`.
If `dispose_engine` is None, disposes the engine if it wasn't provided."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("sqlalchemy")
from sqlalchemy import inspect
if engine_config is None:
engine_config = {}
engine_meta = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
return_meta=True,
**engine_config,
)
engine = engine_meta["engine"]
should_dispose = engine_meta["should_dispose"]
if dispose_engine is None:
dispose_engine = should_dispose
inspector = inspect(engine)
all_schemas = inspector.get_schema_names(**kwargs)
schemas = []
for schema in all_schemas:
if pattern is not None:
if not cls.key_match(schema, pattern, use_regex=use_regex):
continue
if schema == "information_schema":
continue
schemas.append(schema)
if dispose_engine:
engine.dispose()
if sort:
return sorted(dict.fromkeys(schemas))
return list(dict.fromkeys(schemas))
@classmethod
def list_tables(
cls,
*,
schema_pattern: tp.Optional[str] = None,
table_pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
schema: tp.Optional[str] = None,
incl_views: bool = True,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
dispose_engine: tp.Optional[bool] = None,
**kwargs,
) -> tp.List[str]:
"""List all tables and views.
If `schema` is None, searches for all schema names in the database and prefixes each table
with the respective schema name (unless there's only one schema "main"). If `schema` is False,
sets the schema to None. If `schema` is provided, returns the tables corresponding to this
schema without a prefix.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each schema against
`schema_pattern` and each table against `table_pattern`.
Keyword arguments `**kwargs` are passed to `inspector.get_table_names`.
If `dispose_engine` is None, disposes the engine if it wasn't provided."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("sqlalchemy")
from sqlalchemy import inspect
if engine_config is None:
engine_config = {}
engine_meta = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
return_meta=True,
**engine_config,
)
engine = engine_meta["engine"]
engine_name = engine_meta["engine_name"]
should_dispose = engine_meta["should_dispose"]
if dispose_engine is None:
dispose_engine = should_dispose
schema = cls.resolve_engine_setting(schema, "schema", engine_name=engine_name)
if schema is None:
schemas = cls.list_schemas(
pattern=schema_pattern,
use_regex=use_regex,
sort=sort,
engine=engine,
engine_name=engine_name,
**kwargs,
)
if len(schemas) == 0:
schemas = [None]
prefix_schema = False
elif len(schemas) == 1 and schemas[0] == "main":
prefix_schema = False
else:
prefix_schema = True
elif schema is False:
schemas = [None]
prefix_schema = False
else:
schemas = [schema]
prefix_schema = False
inspector = inspect(engine)
tables = []
for schema in schemas:
all_tables = inspector.get_table_names(schema, **kwargs)
if incl_views:
try:
all_tables += inspector.get_view_names(schema, **kwargs)
except NotImplementedError:
pass
try:
all_tables += inspector.get_materialized_view_names(schema, **kwargs)
except NotImplementedError:
pass
for table in all_tables:
if table_pattern is not None:
if not cls.key_match(table, table_pattern, use_regex=use_regex):
continue
if prefix_schema and schema is not None:
table = str(schema) + ":" + table
tables.append(table)
if dispose_engine:
engine.dispose()
if sort:
return sorted(dict.fromkeys(tables))
return list(dict.fromkeys(tables))
@classmethod
def has_schema(
cls,
schema: str,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
) -> bool:
"""Check whether the database has a schema."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("sqlalchemy")
from sqlalchemy import inspect
if engine_config is None:
engine_config = {}
engine = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
**engine_config,
)
return inspect(engine).has_schema(schema)
@classmethod
def create_schema(
cls,
schema: str,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
) -> None:
"""Create a schema if it doesn't exist yet."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("sqlalchemy")
from sqlalchemy.schema import CreateSchema
if engine_config is None:
engine_config = {}
engine = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
**engine_config,
)
if not cls.has_schema(schema, engine=engine, engine_name=engine_name):
with engine.connect() as connection:
connection.execute(CreateSchema(schema))
connection.commit()
@classmethod
def has_table(
cls,
table: str,
schema: tp.Optional[str] = None,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
) -> bool:
"""Check whether the database has a table."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("sqlalchemy")
from sqlalchemy import inspect
if engine_config is None:
engine_config = {}
engine = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
**engine_config,
)
return inspect(engine).has_table(table, schema=schema)
@classmethod
def get_table_relation(
cls,
table: str,
schema: tp.Optional[str] = None,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
) -> TableT:
"""Get table relation."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("sqlalchemy")
from sqlalchemy import MetaData
if engine_config is None:
engine_config = {}
engine = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
**engine_config,
)
schema = cls.resolve_engine_setting(schema, "schema", engine_name=engine_name)
metadata_obj = MetaData()
metadata_obj.reflect(bind=engine, schema=schema, only=[table], views=True)
if schema is not None and schema + "." + table in metadata_obj.tables:
return metadata_obj.tables[schema + "." + table]
return metadata_obj.tables[table]
@classmethod
def get_last_row_number(
cls,
table: str,
schema: tp.Optional[str] = None,
row_number_column: tp.Optional[str] = None,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
) -> int:
"""Get the last row number."""
if engine_config is None:
engine_config = {}
engine_meta = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
return_meta=True,
**engine_config,
)
engine = engine_meta["engine"]
engine_name = engine_meta["engine_name"]
row_number_column = cls.resolve_engine_setting(
row_number_column,
"row_number_column",
engine_name=engine_name,
)
table_relation = cls.get_table_relation(table, schema=schema, engine=engine, engine_name=engine_name)
table_column_names = []
for column in table_relation.columns:
table_column_names.append(column.name)
if row_number_column not in table_column_names:
raise ValueError(f"Row number column '{row_number_column}' not found")
query = (
table_relation.select()
.with_only_columns(table_relation.columns.get(row_number_column))
.order_by(table_relation.columns.get(row_number_column).desc())
.limit(1)
)
with engine.connect() as connection:
results = connection.execute(query)
last_row_number = results.first()[0]
connection.commit()
return last_row_number
@classmethod
def resolve_keys_meta(
cls,
keys: tp.Union[None, dict, tp.MaybeKeys] = None,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[None, dict, tp.MaybeFeatures] = None,
symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
schema: tp.Optional[str] = None,
list_tables_kwargs: tp.KwargsLike = None,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
) -> tp.Kwargs:
keys_meta = DBData.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
)
if keys_meta["keys"] is None:
if cls.has_key_dict(schema):
raise ValueError("Cannot populate keys if schema is defined per key")
if cls.has_key_dict(list_tables_kwargs):
raise ValueError("Cannot populate keys if list_tables_kwargs is defined per key")
if cls.has_key_dict(engine):
raise ValueError("Cannot populate keys if engine is defined per key")
if cls.has_key_dict(engine_name):
raise ValueError("Cannot populate keys if engine_name is defined per key")
if cls.has_key_dict(engine_config):
raise ValueError("Cannot populate keys if engine_config is defined per key")
if list_tables_kwargs is None:
list_tables_kwargs = {}
keys_meta["keys"] = cls.list_tables(
schema=schema,
engine=engine,
engine_name=engine_name,
engine_config=engine_config,
**list_tables_kwargs,
)
return keys_meta
@classmethod
def pull(
cls: tp.Type[SQLDataT],
keys: tp.Union[tp.MaybeKeys] = None,
*,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[tp.MaybeFeatures] = None,
symbols: tp.Union[tp.MaybeSymbols] = None,
schema: tp.Optional[str] = None,
list_tables_kwargs: tp.KwargsLike = None,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
dispose_engine: tp.Optional[bool] = None,
share_engine: tp.Optional[bool] = None,
**kwargs,
) -> SQLDataT:
"""Override `vectorbtpro.data.base.Data.pull` to resolve and share the engine among the keys
and use the table names available in the database in case no keys were provided."""
if share_engine is None:
if (
not cls.has_key_dict(engine)
and not cls.has_key_dict(engine_name)
and not cls.has_key_dict(engine_config)
):
share_engine = True
else:
share_engine = False
if share_engine:
if engine_config is None:
engine_config = {}
engine_meta = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
return_meta=True,
**engine_config,
)
engine = engine_meta["engine"]
engine_name = engine_meta["engine_name"]
should_dispose = engine_meta["should_dispose"]
if dispose_engine is None:
dispose_engine = should_dispose
else:
engine_name = None
keys_meta = cls.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
schema=schema,
list_tables_kwargs=list_tables_kwargs,
engine=engine,
engine_name=engine_name,
engine_config=engine_config,
)
keys = keys_meta["keys"]
keys_are_features = keys_meta["keys_are_features"]
outputs = super(DBData, cls).pull(
keys,
keys_are_features=keys_are_features,
schema=schema,
engine=engine,
engine_name=engine_name,
engine_config=engine_config,
dispose_engine=False if share_engine else dispose_engine,
**kwargs,
)
if share_engine and dispose_engine:
engine.dispose()
return outputs
@classmethod
def fetch_key(
cls,
key: str,
table: tp.Union[None, str, TableT] = None,
schema: tp.Optional[str] = None,
query: tp.Union[None, str, SelectableT] = None,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
dispose_engine: tp.Optional[bool] = None,
start: tp.Optional[tp.Any] = None,
end: tp.Optional[tp.Any] = None,
align_dates: tp.Optional[bool] = None,
parse_dates: tp.Union[None, bool, tp.List[tp.IntStr], tp.Dict[tp.IntStr, tp.Any]] = None,
to_utc: tp.Union[None, bool, str, tp.Sequence[str]] = None,
tz: tp.TimezoneLike = None,
start_row: tp.Optional[int] = None,
end_row: tp.Optional[int] = None,
keep_row_number: tp.Optional[bool] = None,
row_number_column: tp.Optional[str] = None,
index_col: tp.Union[None, bool, tp.MaybeList[tp.IntStr]] = None,
columns: tp.Optional[tp.MaybeList[tp.IntStr]] = None,
dtype: tp.Union[None, tp.DTypeLike, tp.Dict[tp.IntStr, tp.DTypeLike]] = None,
chunksize: tp.Optional[int] = None,
chunk_func: tp.Optional[tp.Callable] = None,
squeeze: tp.Optional[bool] = None,
**read_sql_kwargs,
) -> tp.KeyData:
"""Fetch a feature or symbol from a SQL database.
Can use a table name (which defaults to the key) or a custom query.
Args:
key (str): Feature or symbol.
If `table` and `query` are both None, becomes the table name.
Key can be in the `SCHEMA:TABLE` format, in which case the `schema` argument is ignored.
table (str or Table): Table name or actual object.
Cannot be used together with `query`.
schema (str): Schema.
Cannot be used together with `query`.
query (str or Selectable): Custom query.
Cannot be used together with `table` and `schema`.
engine (str or object): See `SQLData.resolve_engine`.
engine_name (str): See `SQLData.resolve_engine`.
engine_config (dict): See `SQLData.resolve_engine`.
dispose_engine (bool): See `SQLData.resolve_engine`.
start (any): Start datetime (if datetime index) or any other start value.
Will parse with `vectorbtpro.utils.datetime_.to_timestamp` if `align_dates` is True
and the index is a datetime index. Otherwise, you must ensure the correct type is provided.
If the index is a multi-index, start value must be a tuple.
Cannot be used together with `query`; include the condition in the query itself.
end (any): End datetime (if datetime index) or any other end value.
Will parse with `vectorbtpro.utils.datetime_.to_timestamp` if `align_dates` is True
and the index is a datetime index. Otherwise, you must ensure the correct type is provided.
If the index is a multi-index, end value must be a tuple.
Cannot be used together with `query`; include the condition in the query itself.
align_dates (bool): Whether to align `start` and `end` to the timezone of the index.
Will pull one row (using `LIMIT 1`) and use `SQLData.prepare_dt` to get the index.
parse_dates (bool, list, or dict): Whether to parse dates and how to do it.
If `query` is not used, will be mapped to column names. Otherwise,
integers are not allowed and column names must be used directly.
If enabled, will also try to parse, once the object has been fetched, any datetime
columns that Pandas couldn't parse.
For dict format, see `pd.read_sql_query`.
to_utc (bool, str, or sequence of str): See `SQLData.prepare_dt`.
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
start_row (int): Start row.
Table must contain the column defined in `row_number_column`.
Cannot be used together with `query`; include the condition in the query itself.
end_row (int): End row.
Table must contain the column defined in `row_number_column`.
Cannot be used together with `query`; include the condition in the query itself.
keep_row_number (bool): Whether to return the column defined in `row_number_column`.
row_number_column (str): Name of the column with row numbers.
index_col (int, str, or list): One or more columns that should become the index.
If `query` is not used, will be mapped to column names. Otherwise,
integers are not allowed and column names must be used directly.
columns (int, str, or list): One or more columns to select.
Will get mapped into column names. Cannot be used together with `query`.
dtype (dtype_like or dict): Data type of each column.
If `query` is not used, will be mapped to column names. Otherwise,
integers are not allowed and column names must be used directly.
For dict format, see `pd.read_sql_query`.
chunksize (int): See `pd.read_sql_query`.
chunk_func (callable): Function to select and concatenate chunks from `Iterator`.
Gets called only if `chunksize` is set.
squeeze (bool): Whether to squeeze a DataFrame with one column into a Series.
**read_sql_kwargs: Other keyword arguments passed to `pd.read_sql_query`.
See https://pandas.pydata.org/docs/reference/api/pandas.read_sql_query.html for other arguments.
For defaults, see `custom.sql` in `vectorbtpro._settings.data`.
Global settings can be provided per engine name using the `engines` dictionary.
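Usage:
* A minimal sketch of pulling one table (the SQLite URL, table name, and index column
below are placeholders; adjust them to your own database). This assumes that `engine`
accepts a SQLAlchemy connection URL and that `pull` forwards these keyword arguments
to `SQLData.fetch_key`:
```pycon
>>> from vectorbtpro import *
>>> data = vbt.SQLData.pull(
...     "ohlcv",  # placeholder table name
...     engine="sqlite:///my_database.db",  # placeholder SQLAlchemy URL
...     index_col="datetime"  # placeholder index column
... )
```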
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("sqlalchemy")
from sqlalchemy import Selectable, Select, FromClause, and_, text
if engine_config is None:
engine_config = {}
engine_meta = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
return_meta=True,
**engine_config,
)
engine = engine_meta["engine"]
engine_name = engine_meta["engine_name"]
should_dispose = engine_meta["should_dispose"]
if dispose_engine is None:
dispose_engine = should_dispose
if table is not None and query is not None:
raise ValueError("Must provide either table name or query, not both")
if schema is not None and query is not None:
raise ValueError("Schema cannot be applied to custom queries")
if table is None and query is None:
if ":" in key:
schema, table = key.split(":")
else:
table = key
start = cls.resolve_engine_setting(start, "start", engine_name=engine_name)
end = cls.resolve_engine_setting(end, "end", engine_name=engine_name)
align_dates = cls.resolve_engine_setting(align_dates, "align_dates", engine_name=engine_name)
parse_dates = cls.resolve_engine_setting(parse_dates, "parse_dates", engine_name=engine_name)
to_utc = cls.resolve_engine_setting(to_utc, "to_utc", engine_name=engine_name)
tz = cls.resolve_engine_setting(tz, "tz", engine_name=engine_name)
start_row = cls.resolve_engine_setting(start_row, "start_row", engine_name=engine_name)
end_row = cls.resolve_engine_setting(end_row, "end_row", engine_name=engine_name)
keep_row_number = cls.resolve_engine_setting(keep_row_number, "keep_row_number", engine_name=engine_name)
row_number_column = cls.resolve_engine_setting(row_number_column, "row_number_column", engine_name=engine_name)
index_col = cls.resolve_engine_setting(index_col, "index_col", engine_name=engine_name)
columns = cls.resolve_engine_setting(columns, "columns", engine_name=engine_name)
dtype = cls.resolve_engine_setting(dtype, "dtype", engine_name=engine_name)
chunksize = cls.resolve_engine_setting(chunksize, "chunksize", engine_name=engine_name)
chunk_func = cls.resolve_engine_setting(chunk_func, "chunk_func", engine_name=engine_name)
squeeze = cls.resolve_engine_setting(squeeze, "squeeze", engine_name=engine_name)
read_sql_kwargs = cls.resolve_engine_setting(
read_sql_kwargs, "read_sql_kwargs", merge=True, engine_name=engine_name
)
if query is None or isinstance(query, (Selectable, FromClause)):
if query is None:
if isinstance(table, str):
table = cls.get_table_relation(table, schema=schema, engine=engine, engine_name=engine_name)
else:
table = query
table_column_names = []
for column in table.columns:
table_column_names.append(column.name)
def _resolve_columns(c):
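# Map integer positions to column names and resolve case-insensitive matches against the table's columns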
if checks.is_int(c):
c = table_column_names[int(c)]
elif not isinstance(c, str):
new_c = []
for _c in c:
if checks.is_int(_c):
new_c.append(table_column_names[int(_c)])
else:
if _c not in table_column_names:
for __c in table_column_names:
if _c.lower() == __c.lower():
_c = __c
break
new_c.append(_c)
c = new_c
else:
if c not in table_column_names:
for _c in table_column_names:
if c.lower() == _c.lower():
return _c
return c
if index_col is False:
index_col = None
if index_col is not None:
index_col = _resolve_columns(index_col)
if isinstance(index_col, str):
index_col = [index_col]
if columns is not None:
columns = _resolve_columns(columns)
if isinstance(columns, str):
columns = [columns]
if parse_dates is not None:
if not isinstance(parse_dates, bool):
if isinstance(parse_dates, dict):
parse_dates = dict(zip(_resolve_columns(parse_dates.keys()), parse_dates.values()))
else:
parse_dates = _resolve_columns(parse_dates)
if isinstance(parse_dates, str):
parse_dates = [parse_dates]
if dtype is not None:
if isinstance(dtype, dict):
dtype = dict(zip(_resolve_columns(dtype.keys()), dtype.values()))
if not isinstance(table, Select):
query = table.select()
else:
query = table
if index_col is not None:
for col in index_col:
query = query.order_by(col)
if index_col is not None and columns is not None:
pre_columns = []
for col in index_col:
if col not in columns:
pre_columns.append(col)
columns = pre_columns + columns
if keep_row_number and columns is not None:
if row_number_column in table_column_names and row_number_column not in columns:
columns = [row_number_column] + columns
elif not keep_row_number and columns is None:
if row_number_column in table_column_names:
columns = [col for col in table_column_names if col != row_number_column]
if columns is not None:
query = query.with_only_columns(*[table.columns.get(c) for c in columns])
def _to_native_type(x):
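# Convert NumPy scalars to native Python types so they can be bound as SQL parameters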
if checks.is_np_scalar(x):
return x.item()
return x
if start_row is not None or end_row is not None:
if start is not None or end is not None:
raise ValueError("Can either filter by row numbers or by index, not both")
_row_number_column = table.columns.get(row_number_column)
if _row_number_column is None:
raise ValueError(f"Row number column '{row_number_column}' not found")
and_list = []
if start_row is not None:
and_list.append(_row_number_column >= _to_native_type(start_row))
if end_row is not None:
and_list.append(_row_number_column < _to_native_type(end_row))
query = query.where(and_(*and_list))
if start is not None or end is not None:
if index_col is None:
raise ValueError("Must provide index column for filtering by start and end")
if align_dates:
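# Pull a single row to infer the index type and timezone before building the start/end filter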
first_obj = pd.read_sql_query(
query.limit(1),
engine.connect(),
index_col=index_col,
parse_dates=None if isinstance(parse_dates, bool) else parse_dates, # bool not accepted
dtype=dtype,
chunksize=None,
**read_sql_kwargs,
)
first_obj = cls.prepare_dt(
first_obj,
parse_dates=list(parse_dates) if isinstance(parse_dates, dict) else parse_dates,
to_utc=False,
)
if isinstance(first_obj.index, pd.DatetimeIndex):
if tz is None:
tz = first_obj.index.tz
if first_obj.index.tz is not None:
if start is not None:
start = dt.to_tzaware_datetime(start, naive_tz=tz, tz=first_obj.index.tz)
if end is not None:
end = dt.to_tzaware_datetime(end, naive_tz=tz, tz=first_obj.index.tz)
else:
if start is not None:
if (
to_utc is True
or (isinstance(to_utc, str) and to_utc.lower() == "index")
or (checks.is_sequence(to_utc) and first_obj.index.name in to_utc)
):
start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")
start = dt.to_naive_datetime(start)
else:
start = dt.to_naive_datetime(start, tz=tz)
if end is not None:
if (
to_utc is True
or (isinstance(to_utc, str) and to_utc.lower() == "index")
or (checks.is_sequence(to_utc) and first_obj.index.name in to_utc)
):
end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc")
end = dt.to_naive_datetime(end)
else:
end = dt.to_naive_datetime(end, tz=tz)
and_list = []
if start is not None:
if len(index_col) > 1:
if not isinstance(start, tuple):
raise TypeError("Start must be a tuple if the index is a multi-index")
if len(start) != len(index_col):
raise ValueError("Start tuple must match the number of levels in the multi-index")
for i in range(len(index_col)):
index_column = table.columns.get(index_col[i])
and_list.append(index_column >= _to_native_type(start[i]))
else:
index_column = table.columns.get(index_col[0])
and_list.append(index_column >= _to_native_type(start))
if end is not None:
if len(index_col) > 1:
if not isinstance(end, tuple):
raise TypeError("End must be a tuple if the index is a multi-index")
if len(end) != len(index_col):
raise ValueError("End tuple must match the number of levels in the multi-index")
for i in range(len(index_col)):
index_column = table.columns.get(index_col[i])
and_list.append(index_column < _to_native_type(end[i]))
else:
index_column = table.columns.get(index_col[0])
and_list.append(index_column < _to_native_type(end))
query = query.where(and_(*and_list))
else:
def _check_columns(c, arg_name):
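# Custom queries have no table to map positions against, so columns must be given as strings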
if checks.is_int(c):
raise ValueError(f"Must provide column as a string for '{arg_name}'")
elif not isinstance(c, str):
for _c in c:
if checks.is_int(_c):
raise ValueError(f"Must provide each column as a string for '{arg_name}'")
if start is not None:
raise ValueError("Start cannot be applied to custom queries")
if end is not None:
raise ValueError("End cannot be applied to custom queries")
if start_row is not None:
raise ValueError("Start row cannot be applied to custom queries")
if end_row is not None:
raise ValueError("End row cannot be applied to custom queries")
if index_col is False:
index_col = None
if index_col is not None:
_check_columns(index_col, "index_col")
if isinstance(index_col, str):
index_col = [index_col]
if columns is not None:
raise ValueError("Columns cannot be applied to custom queries")
if parse_dates is not None:
if not isinstance(parse_dates, bool):
if isinstance(parse_dates, dict):
_check_columns(parse_dates.keys(), "parse_dates")
else:
_check_columns(parse_dates, "parse_dates")
if isinstance(parse_dates, str):
parse_dates = [parse_dates]
if dtype is not None:
_check_columns(dtype.keys(), "dtype")
if isinstance(query, str):
query = text(query)
obj = pd.read_sql_query(
query,
engine.connect(),
index_col=index_col,
parse_dates=None if isinstance(parse_dates, bool) else parse_dates, # bool not accepted
dtype=dtype,
chunksize=chunksize,
**read_sql_kwargs,
)
if isinstance(obj, Iterator):
if chunk_func is None:
obj = pd.concat(list(obj), axis=0)
else:
obj = chunk_func(obj)
obj = cls.prepare_dt(
obj,
parse_dates=list(parse_dates) if isinstance(parse_dates, dict) else parse_dates,
to_utc=to_utc,
)
if not isinstance(obj.index, pd.MultiIndex):
if obj.index.name == "index":
obj.index.name = None
if isinstance(obj.index, pd.DatetimeIndex) and tz is None:
tz = obj.index.tz
if isinstance(obj, pd.DataFrame) and squeeze:
obj = obj.squeeze("columns")
if isinstance(obj, pd.Series) and obj.name == "0":
obj.name = None
if dispose_engine:
engine.dispose()
if keep_row_number:
return obj, dict(tz=tz, row_number_column=row_number_column)
return obj, dict(tz=tz)
@classmethod
def fetch_feature(cls, feature: str, **kwargs) -> tp.FeatureData:
"""Fetch the table of a feature.
Uses `SQLData.fetch_key`."""
return cls.fetch_key(feature, **kwargs)
@classmethod
def fetch_symbol(cls, symbol: str, **kwargs) -> tp.SymbolData:
"""Fetch the table for a symbol.
Uses `SQLData.fetch_key`."""
return cls.fetch_key(symbol, **kwargs)
def update_key(
self,
key: str,
from_last_row: tp.Optional[bool] = None,
from_last_index: tp.Optional[bool] = None,
**kwargs,
) -> tp.KeyData:
"""Update data of a feature or symbol."""
fetch_kwargs = self.select_fetch_kwargs(key)
returned_kwargs = self.select_returned_kwargs(key)
pre_kwargs = merge_dicts(fetch_kwargs, kwargs)
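# Decide whether to resume from the last stored row number or from the last index value, depending on what information is available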
if from_last_row is None:
if pre_kwargs.get("query", None) is not None:
from_last_row = False
elif from_last_index is True:
from_last_row = False
elif pre_kwargs.get("start", None) is not None or pre_kwargs.get("end", None) is not None:
from_last_row = False
elif "row_number_column" not in returned_kwargs:
from_last_row = False
elif returned_kwargs["row_number_column"] not in self.wrapper.columns:
from_last_row = False
else:
from_last_row = True
if from_last_index is None:
if pre_kwargs.get("query", None) is not None:
from_last_index = False
elif from_last_row is True:
from_last_index = False
elif pre_kwargs.get("start_row", None) is not None or pre_kwargs.get("end_row", None) is not None:
from_last_index = False
else:
from_last_index = True
if from_last_row:
if "row_number_column" not in returned_kwargs:
raise ValueError("Argument row_number_column must be in returned_kwargs for from_last_row")
row_number_column = returned_kwargs["row_number_column"]
fetch_kwargs["start_row"] = self.data[key][row_number_column].iloc[-1]
if from_last_index:
fetch_kwargs["start"] = self.select_last_index(key)
kwargs = merge_dicts(fetch_kwargs, kwargs)
if self.feature_oriented:
return self.fetch_feature(key, **kwargs)
return self.fetch_symbol(key, **kwargs)
def update_feature(self, feature: str, **kwargs) -> tp.FeatureData:
"""Update data of a feature.
Uses `SQLData.update_key`."""
return self.update_key(feature, **kwargs)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
"""Update data for a symbol.
Uses `SQLData.update_key`."""
return self.update_key(symbol, **kwargs)
</file>
<file path="data/custom/synthetic.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `SyntheticData`."""
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.custom import CustomData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"SyntheticData",
]
__pdoc__ = {}
class SyntheticData(CustomData):
"""Data class for fetching synthetic data.
Exposes an abstract class method `SyntheticData.generate_symbol`.
Everything else is taken care of."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.synthetic")
@classmethod
def generate_key(cls, key: tp.Key, index: tp.Index, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
"""Abstract method to generate data of a feature or symbol."""
raise NotImplementedError
@classmethod
def generate_feature(cls, feature: tp.Feature, index: tp.Index, **kwargs) -> tp.FeatureData:
"""Abstract method to generate data of a feature.
Uses `SyntheticData.generate_key` with `key_is_feature=True`."""
return cls.generate_key(feature, index, key_is_feature=True, **kwargs)
@classmethod
def generate_symbol(cls, symbol: tp.Symbol, index: tp.Index, **kwargs) -> tp.SymbolData:
"""Abstract method to generate data for a symbol.
Uses `SyntheticData.generate_key` with `key_is_feature=False`."""
return cls.generate_key(symbol, index, key_is_feature=False, **kwargs)
@classmethod
def fetch_key(
cls,
key: tp.Symbol,
key_is_feature: bool = False,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
periods: tp.Optional[int] = None,
timeframe: tp.Optional[tp.FrequencyLike] = None,
tz: tp.TimezoneLike = None,
normalize: tp.Optional[bool] = None,
inclusive: tp.Optional[str] = None,
**kwargs,
) -> tp.KeyData:
"""Generate data of a feature or symbol.
Generates a datetime index using `vectorbtpro.utils.datetime_.date_range` and passes it to
`SyntheticData.generate_key` to fill the Series/DataFrame with generated data.
For defaults, see `custom.synthetic` in `vectorbtpro._settings.data`."""
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
normalize = cls.resolve_custom_setting(normalize, "normalize")
inclusive = cls.resolve_custom_setting(inclusive, "inclusive")
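# Build the datetime index first; the subclass fills it with generated values in generate_key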
index = dt.date_range(
start=start,
end=end,
periods=periods,
freq=timeframe,
normalize=normalize,
inclusive=inclusive,
)
if tz is None:
tz = index.tz
if len(index) == 0:
raise ValueError("Date range is empty")
if key_is_feature:
return cls.generate_feature(key, index, **kwargs), dict(tz=tz, freq=timeframe)
return cls.generate_symbol(key, index, **kwargs), dict(tz=tz, freq=timeframe)
@classmethod
def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Generate data of a feature.
Uses `SyntheticData.fetch_key` with `key_is_feature=True`."""
return cls.fetch_key(feature, key_is_feature=True, **kwargs)
@classmethod
def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Generate data for a symbol.
Uses `SyntheticData.fetch_key` with `key_is_feature=False`."""
return cls.fetch_key(symbol, key_is_feature=False, **kwargs)
def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
"""Update data of a feature or symbol."""
fetch_kwargs = self.select_fetch_kwargs(key)
fetch_kwargs["start"] = self.select_last_index(key)
kwargs = merge_dicts(fetch_kwargs, kwargs)
if key_is_feature:
return self.fetch_feature(key, **kwargs)
return self.fetch_symbol(key, **kwargs)
def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Update data of a feature.
Uses `SyntheticData.update_key` with `key_is_feature=True`."""
return self.update_key(feature, key_is_feature=True, **kwargs)
def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Update data for a symbol.
Uses `SyntheticData.update_key` with `key_is_feature=False`."""
return self.update_key(symbol, key_is_feature=False, **kwargs)
</file>
<file path="data/custom/tv.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `TVData`."""
import datetime
import json
import math
import random
import re
import string
import time
import pandas as pd
import requests
from websocket import WebSocket
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts, Configured
from vectorbtpro.utils.pbar import ProgressBar
from vectorbtpro.utils.template import CustomTemplate
__all__ = [
"TVClient",
"TVData",
]
SIGNIN_URL = "https://www.tradingview.com/accounts/signin/"
"""Sign-in URL."""
SEARCH_URL = (
"https://symbol-search.tradingview.com/symbol_search/v3/?"
"text={text}&"
"start={start}&"
"hl=1&"
"exchange={exchange}&"
"lang=en&"
"search_type=undefined&"
"domain=production&"
"sort_by_country=US"
)
"""Symbol search URL."""
SCAN_URL = "https://scanner.tradingview.com/{market}/scan"
"""Market scanner URL."""
ORIGIN_URL = "https://data.tradingview.com"
"""Origin URL."""
REFERER_URL = "https://www.tradingview.com"
"""Referer URL."""
WS_URL = "wss://data.tradingview.com/socket.io/websocket"
"""Websocket URL."""
PRO_WS_URL = "wss://prodata.tradingview.com/socket.io/websocket"
"""Websocket URL (Pro)."""
WS_TIMEOUT = 5
"""Websocket timeout."""
MARKET_LIST = [
"america",
"argentina",
"australia",
"austria",
"bahrain",
"bangladesh",
"belgium",
"brazil",
"canada",
"chile",
"china",
"colombia",
"cyprus",
"czech",
"denmark",
"egypt",
"estonia",
"euronext",
"finland",
"france",
"germany",
"greece",
"hongkong",
"hungary",
"iceland",
"india",
"indonesia",
"israel",
"italy",
"japan",
"kenya",
"korea",
"ksa",
"kuwait",
"latvia",
"lithuania",
"luxembourg",
"malaysia",
"mexico",
"morocco",
"netherlands",
"newzealand",
"nigeria",
"norway",
"pakistan",
"peru",
"philippines",
"poland",
"portugal",
"qatar",
"romania",
"rsa",
"russia",
"serbia",
"singapore",
"slovakia",
"spain",
"srilanka",
"sweden",
"switzerland",
"taiwan",
"thailand",
"tunisia",
"turkey",
"uae",
"uk",
"venezuela",
"vietnam",
]
"""List of markets supported by the market scanner (list may be incomplete)."""
FIELD_LIST = [
"name",
"description",
"logoid",
"update_mode",
"type",
"typespecs",
"close",
"pricescale",
"minmov",
"fractional",
"minmove2",
"currency",
"change",
"change_abs",
"Recommend.All",
"volume",
"Value.Traded",
"market_cap_basic",
"fundamental_currency_code",
"Perf.1Y.MarketCap",
"price_earnings_ttm",
"earnings_per_share_basic_ttm",
"number_of_employees_fy",
"sector",
"market",
]
"""List of fields supported by the market scanner (list may be incomplete)."""
USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36"
"""User agent."""
class TVClient(Configured):
"""Client for TradingView."""
def __init__(
self,
username: tp.Optional[str] = None,
password: tp.Optional[str] = None,
auth_token: tp.Optional[str] = None,
**kwargs,
) -> None:
"""Client for TradingView."""
Configured.__init__(
self,
username=username,
password=password,
auth_token=auth_token,
**kwargs,
)
if auth_token is None:
auth_token = self.auth(username, password)
elif username is not None or password is not None:
raise ValueError("Must provide either username and password, or auth_token")
self._auth_token = auth_token
self._ws = None
self._session = self.generate_session()
self._chart_session = self.generate_chart_session()
@property
def auth_token(self) -> str:
"""Authentication token."""
return self._auth_token
@property
def ws(self) -> WebSocket:
"""Instance of `websocket.Websocket`."""
return self._ws
@property
def session(self) -> str:
"""Session."""
return self._session
@property
def chart_session(self) -> str:
"""Chart session."""
return self._chart_session
@classmethod
def auth(
cls,
username: tp.Optional[str] = None,
password: tp.Optional[str] = None,
) -> str:
"""Authenticate."""
if username is not None and password is not None:
data = {"username": username, "password": password, "remember": "on"}
headers = {"Referer": REFERER_URL, "User-Agent": USER_AGENT}
response = requests.post(url=SIGNIN_URL, data=data, headers=headers)
response.raise_for_status()
response_json = response.json()
if "user" not in response_json or "auth_token" not in response_json["user"]:
raise ValueError(response_json)
return response_json["user"]["auth_token"]
if username is not None or password is not None:
raise ValueError("Must provide both username and password")
return "unauthorized_user_token"
@classmethod
def generate_session(cls) -> str:
"""Generate session."""
string_length = 12
letters = string.ascii_lowercase
random_string = "".join(random.choice(letters) for _ in range(string_length))
return "qs_" + random_string
@classmethod
def generate_chart_session(cls) -> str:
"""Generate chart session."""
string_length = 12
letters = string.ascii_lowercase
random_string = "".join(random.choice(letters) for _ in range(string_length))
return "cs_" + random_string
def create_connection(self, pro_data: bool = True) -> None:
"""Create a websocket connection."""
from websocket import create_connection
if pro_data:
self._ws = create_connection(
PRO_WS_URL,
headers=json.dumps({"Origin": ORIGIN_URL}),
timeout=WS_TIMEOUT,
)
else:
self._ws = create_connection(
WS_URL,
headers=json.dumps({"Origin": ORIGIN_URL}),
timeout=WS_TIMEOUT,
)
@classmethod
def filter_raw_message(cls, text) -> tp.Tuple[str, str]:
"""Filter raw message."""
found = re.search('"m":"(.+?)",', text).group(1)
found2 = re.search('"p":(.+?"}"])}', text).group(1)
return found, found2
@classmethod
def prepend_header(cls, st: str) -> str:
"""Prepend a header."""
return "~m~" + str(len(st)) + "~m~" + st
@classmethod
def construct_message(cls, func: str, param_list: tp.List[str]) -> str:
"""Construct a message."""
return json.dumps({"m": func, "p": param_list}, separators=(",", ":"))
def create_message(self, func: str, param_list: tp.List[str]) -> str:
"""Create a message."""
return self.prepend_header(self.construct_message(func, param_list))
def send_message(self, func: str, param_list: tp.List[str]) -> None:
"""Send a message."""
m = self.create_message(func, param_list)
self.ws.send(m)
@classmethod
def convert_raw_data(cls, raw_data: str, symbol: str) -> pd.DataFrame:
"""Process raw data into a DataFrame."""
search_result = re.search(r'"s":\[(.+?)\}\]', raw_data)
if search_result is None:
raise ValueError("Couldn't parse data returned by TradingView: {}".format(raw_data))
out = search_result.group(1)
x = out.split(',{"')
data = list()
volume_data = True
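# Each chunk holds [timestamp, open, high, low, close, volume]; fall back to 0.0 when volume is missing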
for xi in x:
xi = re.split(r"\[|:|,|\]", xi)
ts = datetime.datetime.utcfromtimestamp(float(xi[4]))
row = [ts]
for i in range(5, 10):
# skip converting volume data if it does not exist
if not volume_data and i == 9:
row.append(0.0)
continue
try:
row.append(float(xi[i]))
except ValueError:
volume_data = False
row.append(0.0)
data.append(row)
data = pd.DataFrame(data, columns=["datetime", "open", "high", "low", "close", "volume"])
data = data.set_index("datetime")
data.insert(0, "symbol", value=symbol)
return data
@classmethod
def format_symbol(cls, symbol: str, exchange: str, fut_contract: tp.Optional[int] = None) -> str:
"""Format a symbol."""
if ":" in symbol:
pass
elif fut_contract is None:
symbol = f"{exchange}:{symbol}"
elif isinstance(fut_contract, int):
symbol = f"{exchange}:{symbol}{fut_contract}!"
else:
raise ValueError(f"Invalid fut_contract: '{fut_contract}'")
return symbol
def get_hist(
self,
symbol: str,
exchange: str = "NSE",
interval: str = "1D",
fut_contract: tp.Optional[int] = None,
adjustment: str = "splits",
extended_session: bool = False,
pro_data: bool = True,
limit: int = 20000,
return_raw: bool = False,
) -> tp.Union[str, tp.Frame]:
"""Get historical data."""
symbol = self.format_symbol(symbol=symbol, exchange=exchange, fut_contract=fut_contract)
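# A trailing "!A" requests back-adjusted data: replace it with "!" and request the adjustment via resolve_symbol below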
backadjustment = False
if symbol.endswith("!A"):
backadjustment = True
symbol = symbol.replace("!A", "!")
self.create_connection(pro_data=pro_data)
self.send_message("set_auth_token", [self.auth_token])
self.send_message("chart_create_session", [self.chart_session, ""])
self.send_message("quote_create_session", [self.session])
self.send_message(
"quote_set_fields",
[
self.session,
"ch",
"chp",
"current_session",
"description",
"local_description",
"language",
"exchange",
"fractional",
"is_tradable",
"lp",
"lp_time",
"minmov",
"minmove2",
"original_name",
"pricescale",
"pro_name",
"short_name",
"type",
"update_mode",
"volume",
"currency_code",
"rchp",
"rtc",
],
)
self.send_message("quote_add_symbols", [self.session, symbol, {"flags": ["force_permission"]}])
self.send_message("quote_fast_symbols", [self.session, symbol])
self.send_message(
"resolve_symbol",
[
self.chart_session,
"symbol_1",
'={"symbol":"'
+ symbol
+ '","adjustment":"'
+ adjustment
+ ("" if not backadjustment else '","backadjustment":"default')
+ '","session":'
+ ('"regular"' if not extended_session else '"extended"')
+ "}",
],
)
self.send_message("create_series", [self.chart_session, "s1", "s1", "symbol_1", interval, limit])
self.send_message("switch_timezone", [self.chart_session, "exchange"])
raw_data = ""
while True:
try:
result = self.ws.recv()
raw_data += result + "\n"
except Exception as e:
break
if "series_completed" in result:
break
if return_raw:
return raw_data
return self.convert_raw_data(raw_data, symbol)
@classmethod
def search_symbol(
cls,
text: tp.Optional[str] = None,
exchange: tp.Optional[str] = None,
pages: tp.Optional[int] = None,
delay: tp.Optional[int] = None,
retries: int = 3,
show_progress: bool = True,
pbar_kwargs: tp.KwargsLike = None,
) -> tp.List[dict]:
"""Search for a symbol."""
if text is None:
text = ""
if exchange is None:
exchange = ""
if pbar_kwargs is None:
pbar_kwargs = {}
symbols_list = []
pbar = None
pages_fetched = 0
while True:
for i in range(retries):
try:
url = SEARCH_URL.format(text=text, exchange=exchange.upper(), start=len(symbols_list))
headers = {"Referer": REFERER_URL, "Origin": ORIGIN_URL, "User-Agent": USER_AGENT}
resp = requests.get(url, headers=headers)
symbols_data = json.loads(resp.text.replace("</em>", "").replace("<em>", ""))
break
except json.JSONDecodeError as e:
if i == retries - 1:
raise e
if delay is not None:
time.sleep(delay)
symbols_remaining = symbols_data.get("symbols_remaining", 0)
new_symbols = symbols_data.get("symbols", [])
symbols_list.extend(new_symbols)
if pages is None and symbols_remaining > 0:
show_pbar = True
elif pages is not None and pages > 1:
show_pbar = True
else:
show_pbar = False
if pbar is None and show_pbar:
if pages is not None:
total = pages
else:
total = math.ceil((len(new_symbols) + symbols_remaining) / len(new_symbols))
pbar = ProgressBar(
total=total,
show_progress=show_progress,
**pbar_kwargs,
)
pbar.enter()
if pbar is not None:
max_symbols = len(symbols_list) + symbols_remaining
if pages is not None:
max_symbols = min(max_symbols, pages * len(new_symbols))
pbar.set_description(dict(symbols="%d/%d" % (len(symbols_list), max_symbols)))
pbar.update()
if symbols_remaining == 0:
break
pages_fetched += 1
if pages is not None and pages_fetched >= pages:
break
if delay is not None:
time.sleep(delay)
if pbar is not None:
pbar.exit()
return symbols_list
@classmethod
def scan_symbols(cls, market: tp.Optional[str] = None, **kwargs) -> tp.List[dict]:
"""Scan symbols in a region/market."""
if market is None:
market = "global"
url = SCAN_URL.format(market=market.lower())
headers = {"Referer": REFERER_URL, "Origin": ORIGIN_URL, "User-Agent": USER_AGENT}
resp = requests.post(url, json.dumps(kwargs), headers=headers)
symbols_list = json.loads(resp.text)["data"]
return symbols_list
TVDataT = tp.TypeVar("TVDataT", bound="TVData")
class TVData(RemoteData):
"""Data class for fetching from TradingView.
See `TVData.fetch_symbol` for arguments.
!!! note
If you're getting the error "Please confirm that you are not a robot by clicking the captcha box."
when attempting to authenticate, use `auth_token` instead of `username` and `password`.
To get the authentication token, go to TradingView, log in, visit any chart, open your
browser's developer tools, and search for "auth_token".
Usage:
* Set up the credentials globally (optional):
```pycon
>>> from vectorbtpro import *
>>> vbt.TVData.set_custom_settings(
... client_config=dict(
... username="YOUR_USERNAME",
... password="YOUR_PASSWORD",
... auth_token="YOUR_AUTH_TOKEN", # optional, instead of username and password
... )
... )
```
* Pull data:
```pycon
>>> data = vbt.TVData.pull(
... "NASDAQ:AAPL",
... timeframe="1 hour"
... )
```
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.tv")
@classmethod
def list_symbols(
cls,
*,
exchange_pattern: tp.Optional[str] = None,
symbol_pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
client: tp.Optional[TVClient] = None,
client_config: tp.DictLike = None,
text: tp.Optional[str] = None,
exchange: tp.Optional[str] = None,
pages: tp.Optional[int] = None,
delay: tp.Optional[int] = None,
retries: tp.Optional[int] = None,
show_progress: tp.Optional[bool] = None,
pbar_kwargs: tp.KwargsLike = None,
market: tp.Optional[str] = None,
markets: tp.Optional[tp.List[str]] = None,
fields: tp.Optional[tp.MaybeIterable[str]] = None,
filter_by: tp.Union[None, tp.Callable, CustomTemplate] = None,
groups: tp.Optional[tp.MaybeIterable[tp.Dict[str, tp.MaybeIterable[str]]]] = None,
template_context: tp.KwargsLike = None,
return_field_data: bool = False,
**scanner_kwargs,
) -> tp.Union[tp.List[str], tp.List[tp.Kwargs]]:
"""List all symbols.
Uses symbol search when either `text` or `exchange` is provided (returns a subset of symbols).
Otherwise, uses the market scanner (returns all symbols, big payload).
When using the market scanner, use `market` to filter by one or multiple markets. For the list
of available markets, see `MARKET_LIST`.
Use `fields` to make the market scanner return additional information that can be used for
filtering with `filter_by`. Such information is passed to the function as a dictionary where
fields are keys. The function can also be a template that can use the same information provided
as a context, or a list of values that should be matched against the values corresponding to their fields.
For the list of available fields, see `FIELD_LIST`. Argument `fields` can also be "all".
Set `return_field_data` to True to return a list with (filtered) field data.
Use `groups` to provide a single dictionary or a list of dictionaries with groups.
Each dictionary can be provided either in a compressed format, such as `dict(index=index)`,
or in a full format, such as `dict(type="index", values=[index])`.
Keyword arguments `scanner_kwargs` are encoded and passed directly to the market scanner.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each exchange against
`exchange_pattern` and each symbol against `symbol_pattern`.
Usage:
* List all symbols (market scanner):
```pycon
>>> from vectorbtpro import *
>>> vbt.TVData.list_symbols()
```
* Search for symbols matching a pattern (market scanner, client-side):
```pycon
>>> vbt.TVData.list_symbols(symbol_pattern="BTC*")
```
* Search for exchanges matching a pattern (market scanner, client-side):
```pycon
>>> vbt.TVData.list_symbols(exchange_pattern="NASDAQ")
```
* Search for symbols containing a text (symbol search, server-side):
```pycon
>>> vbt.TVData.list_symbols(text="BTC")
```
* List symbols from an exchange (symbol search):
```pycon
>>> vbt.TVData.list_symbols(exchange="NASDAQ")
```
* List symbols from a market (market scanner):
```pycon
>>> vbt.TVData.list_symbols(market="poland")
```
* List index constituents (market scanner):
```pycon
>>> vbt.TVData.list_symbols(groups=dict(index="NASDAQ:NDX"))
```
* Filter symbols by fields using a function (market scanner):
```pycon
>>> vbt.TVData.list_symbols(
... market="america",
... fields=["sector"],
... filter_by=lambda context: context["sector"] == "Technology Services"
... )
```
* Filter symbols by fields using a template (market scanner):
```pycon
>>> vbt.TVData.list_symbols(
... market="america",
... fields=["sector"],
... filter_by=vbt.RepEval("sector == 'Technology Services'")
... )
```
"""
pages = cls.resolve_custom_setting(pages, "pages", sub_path="search", sub_path_only=True)
delay = cls.resolve_custom_setting(delay, "delay", sub_path="search", sub_path_only=True)
retries = cls.resolve_custom_setting(retries, "retries", sub_path="search", sub_path_only=True)
show_progress = cls.resolve_custom_setting(
show_progress, "show_progress", sub_path="search", sub_path_only=True
)
pbar_kwargs = cls.resolve_custom_setting(
pbar_kwargs, "pbar_kwargs", sub_path="search", sub_path_only=True, merge=True
)
markets = cls.resolve_custom_setting(markets, "markets", sub_path="scanner", sub_path_only=True)
fields = cls.resolve_custom_setting(fields, "fields", sub_path="scanner", sub_path_only=True)
filter_by = cls.resolve_custom_setting(filter_by, "filter_by", sub_path="scanner", sub_path_only=True)
groups = cls.resolve_custom_setting(groups, "groups", sub_path="scanner", sub_path_only=True)
template_context = cls.resolve_custom_setting(
template_context, "template_context", sub_path="scanner", sub_path_only=True, merge=True
)
scanner_kwargs = cls.resolve_custom_setting(
scanner_kwargs, "scanner_kwargs", sub_path="scanner", sub_path_only=True, merge=True
)
if market is None and text is None and exchange is None:
market = "global"
if market is not None and (text is not None or exchange is not None):
raise ValueError("Please provide either market, or text and/or exchange")
if client_config is None:
client_config = {}
client = cls.resolve_client(client=client, **client_config)
if market is None:
data = client.search_symbol(
text=text,
exchange=exchange,
pages=pages,
delay=delay,
retries=retries,
show_progress=show_progress,
pbar_kwargs=pbar_kwargs,
)
all_symbols = map(lambda x: x["exchange"] + ":" + x["symbol"], data)
return_field_data = False
else:
if markets is not None:
scanner_kwargs["markets"] = markets
if fields is not None:
if "columns" in scanner_kwargs:
raise ValueError("Use fields instead of columns")
if isinstance(fields, str):
if fields.lower() == "all":
fields = FIELD_LIST
else:
fields = [fields]
scanner_kwargs["columns"] = fields
if groups is not None:
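# Normalize compressed group specs such as dict(index=...) into the full dict(type=..., values=[...]) format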
if isinstance(groups, dict):
groups = [groups]
new_groups = []
for group in groups:
if "type" in group:
new_groups.append(group)
else:
for k, v in group.items():
if isinstance(v, str):
v = [v]
new_groups.append(dict(type=k, values=v))
groups = new_groups
if "symbols" in scanner_kwargs:
scanner_kwargs["symbols"] = dict(scanner_kwargs["symbols"])
else:
scanner_kwargs["symbols"] = dict()
scanner_kwargs["symbols"]["groups"] = groups
if filter_by is not None:
if isinstance(filter_by, str):
filter_by = [filter_by]
data = client.scan_symbols(market.lower(), **scanner_kwargs)
if data is None:
raise ValueError("No data returned by TradingView")
all_symbols = []
for item in data:
if fields is not None:
item = {"symbol": item["s"], **dict(zip(fields, item["d"]))}
else:
item = {"symbol": item["s"]}
if filter_by is not None:
if fields is not None:
context = merge_dicts(item, template_context)
else:
raise ValueError("Must provide fields for filter_by")
if isinstance(filter_by, CustomTemplate):
if not filter_by.substitute(context, eval_id="filter_by"):
continue
elif callable(filter_by):
if not filter_by(context):
continue
else:
if len(fields) != len(filter_by):
raise ValueError("Fields and filter_by must have the same number of values")
conditions_met = True
for i in range(len(fields)):
if context[fields[i]] != filter_by[i]:
conditions_met = False
break
if not conditions_met:
continue
if return_field_data:
all_symbols.append(item)
else:
all_symbols.append(item["symbol"])
found_symbols = []
for symbol in all_symbols:
if return_field_data:
item = symbol
symbol = item["symbol"]
else:
item = symbol
if '"symbol"' in symbol:
continue
if exchange_pattern is not None:
if not cls.key_match(symbol.split(":")[0], exchange_pattern, use_regex=use_regex):
continue
if symbol_pattern is not None:
if not cls.key_match(symbol.split(":")[1], symbol_pattern, use_regex=use_regex):
continue
found_symbols.append(item)
if sort:
if return_field_data:
return sorted(found_symbols, key=lambda x: x["symbol"])
return sorted(dict.fromkeys(found_symbols))
if return_field_data:
return found_symbols
return list(dict.fromkeys(found_symbols))
@classmethod
def resolve_client(cls, client: tp.Optional[TVClient] = None, **client_config) -> TVClient:
"""Resolve the client.
If provided, must be of the type `TVClient`. Otherwise, will be created using `client_config`."""
client = cls.resolve_custom_setting(client, "client")
if client_config is None:
client_config = {}
has_client_config = len(client_config) > 0
client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True)
if client is None:
client = TVClient(**client_config)
elif has_client_config:
raise ValueError("Cannot apply client_config to already initialized client")
return client
@classmethod
def fetch_symbol(
cls,
symbol: str,
client: tp.Optional[TVClient] = None,
client_config: tp.KwargsLike = None,
exchange: tp.Optional[str] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
fut_contract: tp.Optional[int] = None,
adjustment: tp.Optional[str] = None,
extended_session: tp.Optional[bool] = None,
pro_data: tp.Optional[bool] = None,
limit: tp.Optional[int] = None,
delay: tp.Optional[int] = None,
retries: tp.Optional[int] = None,
) -> tp.SymbolData:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from TradingView.
Args:
symbol (str): Symbol.
Symbol must be in the `EXCHANGE:SYMBOL` format if `exchange` is None.
client (TVClient): Client.
See `TVData.resolve_client`.
client_config (dict): Client config.
See `TVData.resolve_client`.
exchange (str): Exchange.
Can be omitted if already provided via `symbol`.
timeframe (str): Timeframe.
Allows human-readable strings such as "15 minutes".
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
fut_contract (int): None for cash, 1 for continuous current contract in front,
2 for continuous next contract in front.
adjustment (str): Adjustment.
Either "splits" (default) or "dividends".
extended_session (bool): Regular session if False, extended session if True.
pro_data (bool): Whether to use pro data.
limit (int): The maximum number of returned items.
delay (float): Time to sleep after each request (in seconds).
retries (int): The number of retries on failure to fetch data.
For defaults, see `custom.tv` in `vectorbtpro._settings.data`.
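Usage:
* A minimal sketch of pulling a continuous futures contract (the exchange and symbol below
are illustrative; per `TVClient.format_symbol`, `fut_contract=1` resolves to the front
contract, e.g. "CME_MINI:ES1!"):
```pycon
>>> from vectorbtpro import *
>>> data = vbt.TVData.pull(
...     "ES",  # illustrative symbol
...     exchange="CME_MINI",  # illustrative exchange
...     fut_contract=1,
...     timeframe="1 day"
... )
```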
"""
if client_config is None:
client_config = {}
client = cls.resolve_client(client=client, **client_config)
exchange = cls.resolve_custom_setting(exchange, "exchange")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
fut_contract = cls.resolve_custom_setting(fut_contract, "fut_contract")
adjustment = cls.resolve_custom_setting(adjustment, "adjustment")
extended_session = cls.resolve_custom_setting(extended_session, "extended_session")
pro_data = cls.resolve_custom_setting(pro_data, "pro_data")
limit = cls.resolve_custom_setting(limit, "limit")
delay = cls.resolve_custom_setting(delay, "delay")
retries = cls.resolve_custom_setting(retries, "retries")
freq = timeframe
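# Translate the human-readable timeframe into TradingView's interval notation (minutes are bare numbers, other units keep a suffix)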
if not isinstance(timeframe, str):
raise ValueError(f"Invalid timeframe: '{timeframe}'")
split = dt.split_freq_str(timeframe)
if split is None:
raise ValueError(f"Invalid timeframe: '{timeframe}'")
multiplier, unit = split
if unit == "s":
interval = f"{str(multiplier)}S"
elif unit == "m":
interval = str(multiplier)
elif unit == "h":
interval = f"{str(multiplier)}H"
elif unit == "D":
interval = f"{str(multiplier)}D"
elif unit == "W":
interval = f"{str(multiplier)}W"
elif unit == "M":
interval = f"{str(multiplier)}M"
else:
raise ValueError(f"Invalid timeframe: '{timeframe}'")
for i in range(retries):
try:
df = client.get_hist(
symbol=symbol,
exchange=exchange,
interval=interval,
fut_contract=fut_contract,
adjustment=adjustment,
extended_session=extended_session,
pro_data=pro_data,
limit=limit,
)
break
except Exception as e:
if i == retries - 1:
raise e
if delay is not None:
time.sleep(delay)
df.rename(
columns={
"symbol": "Symbol",
"open": "Open",
"high": "High",
"low": "Low",
"close": "Close",
"volume": "Volume",
},
inplace=True,
)
if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None:
df = df.tz_localize("utc")
if "Symbol" in df:
del df["Symbol"]
if "Open" in df.columns:
df["Open"] = df["Open"].astype(float)
if "High" in df.columns:
df["High"] = df["High"].astype(float)
if "Low" in df.columns:
df["Low"] = df["Low"].astype(float)
if "Close" in df.columns:
df["Close"] = df["Close"].astype(float)
if "Volume" in df.columns:
df["Volume"] = df["Volume"].astype(float)
return df, dict(tz=tz, freq=freq)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
</file>
<file path="data/custom/yf.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `YFData`."""
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts, Config, HybridConfig
from vectorbtpro.utils.parsing import get_func_kwargs
__all__ = [
"YFData",
]
__pdoc__ = {}
class YFData(RemoteData):
"""Data class for fetching from Yahoo Finance.
See https://github.com/ranaroussi/yfinance for API.
See `YFData.fetch_symbol` for arguments.
Usage:
```pycon
>>> from vectorbtpro import *
>>> data = vbt.YFData.pull(
... "BTC-USD",
... start="2020-01-01",
... end="2021-01-01",
... timeframe="1 day"
... )
```
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.yf")
_feature_config: tp.ClassVar[Config] = HybridConfig(
{
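# Dividends and capital gains are summed on resampling, while stock splits are combined as a product of nonzero values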
"Dividends": dict(
resample_func=lambda self, obj, resampler: obj.vbt.resample_apply(
resampler,
generic_nb.sum_reduce_nb,
)
),
"Stock Splits": dict(
resample_func=lambda self, obj, resampler: obj.vbt.resample_apply(
resampler,
generic_nb.nonzero_prod_reduce_nb,
)
),
"Capital Gains": dict(
resample_func=lambda self, obj, resampler: obj.vbt.resample_apply(
resampler,
generic_nb.sum_reduce_nb,
)
),
}
)
@property
def feature_config(self) -> Config:
return self._feature_config
@classmethod
def fetch_symbol(
cls,
symbol: str,
period: tp.Optional[str] = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
**history_kwargs,
) -> tp.SymbolData:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Yahoo Finance.
Args:
symbol (str): Symbol.
period (str): Period.
start (any): Start datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
end (any): End datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
timeframe (str): Timeframe.
Allows human-readable strings such as "15 minutes".
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
**history_kwargs: Keyword arguments passed to `yfinance.base.TickerBase.history`.
For defaults, see `custom.yf` in `vectorbtpro._settings.data`.
!!! warning
Data coming from Yahoo is not the most stable out there. Yahoo may manipulate data
however it wants, add noise, or return missing data points (volume, for instance).
It's only used in vectorbt for demonstration purposes.
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("yfinance")
import yfinance as yf
period = cls.resolve_custom_setting(period, "period")
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
history_kwargs = cls.resolve_custom_setting(history_kwargs, "history_kwargs", merge=True)
ticker = yf.Ticker(symbol)
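# Resolve the ticker's exchange timezone so naive start/end datetimes can be localized consistently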
def_history_kwargs = get_func_kwargs(yf.Tickers.history)
ticker_tz = ticker._get_ticker_tz(
history_kwargs.get("proxy", def_history_kwargs["proxy"]),
history_kwargs.get("timeout", def_history_kwargs["timeout"]),
)
if tz is None:
tz = ticker_tz
if start is not None:
start = dt.to_tzaware_datetime(start, naive_tz=tz, tz=ticker_tz)
if end is not None:
end = dt.to_tzaware_datetime(end, naive_tz=tz, tz=ticker_tz)
freq = timeframe
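# Translate vectorbt frequency units into yfinance interval codes (e.g. "D" -> "d", "W" -> "wk", "M" -> "mo")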
split = dt.split_freq_str(timeframe)
if split is not None:
multiplier, unit = split
if unit == "D":
unit = "d"
elif unit == "W":
unit = "wk"
elif unit == "M":
unit = "mo"
timeframe = str(multiplier) + unit
df = ticker.history(period=period, start=start, end=end, interval=timeframe, **history_kwargs)
if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None:
df = df.tz_localize(ticker_tz)
if not df.empty:
if start is not None:
if df.index[0] < start:
df = df[df.index >= start]
if end is not None:
if df.index[-1] >= end:
df = df[df.index < end]
return df, dict(tz=tz, freq=freq)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
YFData.override_feature_config_doc(__pdoc__)
</file>
<file path="data/__init__.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules for working with data sources."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.data.base import *
from vectorbtpro.data.custom import *
from vectorbtpro.data.decorators import *
from vectorbtpro.data.nb import *
from vectorbtpro.data.saver import *
from vectorbtpro.data.updater import *
</file>
<file path="data/base.py">
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base class for working with data sources."""
import inspect
import string
import traceback
from pathlib import Path
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.indexes import stack_indexes
from vectorbtpro.base.merging import column_stack_arrays, is_merge_func_from_config
from vectorbtpro.base.reshaping import to_any_array, to_pd_array, to_1d_array, to_2d_array, broadcast_to
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.data.decorators import attach_symbol_dict_methods
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.generic.analyzable import Analyzable
from vectorbtpro.generic.drawdowns import Drawdowns
from vectorbtpro.ohlcv.nb import mirror_ohlc_nb
from vectorbtpro.ohlcv.enums import PriceFeature
from vectorbtpro.returns.accessors import ReturnsAccessor
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.attr_ import get_dict_attr
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.config import merge_dicts, Config, HybridConfig, copy_dict
from vectorbtpro.utils.decorators import cached_property, hybrid_method
from vectorbtpro.utils.enum_ import map_enum_fields
from vectorbtpro.utils.execution import Task, NoResult, NoResultsException, filter_out_no_results, execute
from vectorbtpro.utils.merging import MergeFunc
from vectorbtpro.utils.parsing import get_func_arg_names, extend_args
from vectorbtpro.utils.path_ import check_mkdir
from vectorbtpro.utils.pickling import pdict, RecState
from vectorbtpro.utils.template import Rep, RepEval, CustomTemplate, substitute_templates
from vectorbtpro.utils.warnings_ import warn
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg
try:
if not tp.TYPE_CHECKING:
raise ImportError
from sqlalchemy import Engine as EngineT
except ImportError:
EngineT = "Engine"
try:
if not tp.TYPE_CHECKING:
raise ImportError
from duckdb import DuckDBPyConnection as DuckDBPyConnectionT
except ImportError:
DuckDBPyConnectionT = "DuckDBPyConnection"
__all__ = [
"key_dict",
"feature_dict",
"symbol_dict",
"run_func_dict",
"run_arg_dict",
"Data",
]
__pdoc__ = {}
class key_dict(pdict):
"""Dict that contains features or symbols as keys."""
pass
class feature_dict(key_dict):
"""Dict that contains features as keys."""
pass
class symbol_dict(key_dict):
"""Dict that contains symbols as keys."""
pass
class run_func_dict(pdict):
"""Dict that contains function names as keys for `Data.run`."""
pass
class run_arg_dict(pdict):
"""Dict that contains argument names as keys for `Data.run`."""
pass
BaseDataMixinT = tp.TypeVar("BaseDataMixinT", bound="BaseDataMixin")
class BaseDataMixin(Base):
"""Base mixin class for working with data."""
@property
def feature_wrapper(self) -> ArrayWrapper:
"""Column wrapper."""
raise NotImplementedError
@property
def symbol_wrapper(self) -> ArrayWrapper:
"""Symbol wrapper."""
raise NotImplementedError
@property
def features(self) -> tp.List[tp.Feature]:
"""List of features."""
return self.feature_wrapper.columns.tolist()
@property
def symbols(self) -> tp.List[tp.Symbol]:
"""List of symbols."""
return self.symbol_wrapper.columns.tolist()
@classmethod
def has_multiple_keys(cls, keys: tp.MaybeKeys) -> bool:
"""Check whether there are one or multiple keys."""
if checks.is_hashable(keys):
return False
elif checks.is_sequence(keys):
return True
raise TypeError("Keys must be either a hashable or a sequence of hashable")
@classmethod
def prepare_key(cls, key: tp.Key) -> tp.Key:
"""Prepare a key."""
if isinstance(key, tuple):
return tuple([cls.prepare_key(k) for k in key])
if isinstance(key, str):
return key.lower().strip().replace(" ", "_")
return key
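# For example, following the normalization above (lowercase, strip, spaces to underscores):
#
#     >>> BaseDataMixin.prepare_key("Trade Count")
#     'trade_count'
#     >>> BaseDataMixin.prepare_key(("Open", "BTC-USD"))
#     ('open', 'btc-usd')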
def get_feature_idx(self, feature: tp.Feature, raise_error: bool = False) -> int:
"""Return the index of a feature."""
# shortcut
columns = self.feature_wrapper.columns
if not columns.has_duplicates:
if feature in columns:
return columns.get_loc(feature)
feature = self.prepare_key(feature)
found_indices = []
for i, c in enumerate(self.features):
c = self.prepare_key(c)
if feature == c:
found_indices.append(i)
if len(found_indices) == 0:
if raise_error:
raise ValueError(f"No features match the feature '{str(feature)}'")
return -1
if len(found_indices) == 1:
return found_indices[0]
raise ValueError(f"Multiple features match the feature '{str(feature)}'")
def get_symbol_idx(self, symbol: tp.Symbol, raise_error: bool = False) -> int:
"""Return the index of a symbol."""
# shortcut
columns = self.symbol_wrapper.columns
if not columns.has_duplicates:
if symbol in columns:
return columns.get_loc(symbol)
symbol = self.prepare_key(symbol)
found_indices = []
for i, c in enumerate(self.symbols):
c = self.prepare_key(c)
if symbol == c:
found_indices.append(i)
if len(found_indices) == 0:
if raise_error:
raise ValueError(f"No symbols match the symbol '{str(symbol)}'")
return -1
if len(found_indices) == 1:
return found_indices[0]
raise ValueError(f"Multiple symbols match the symbol '{str(symbol)}'")
def select_feature_idxs(self: BaseDataMixinT, idxs: tp.MaybeSequence[int], **kwargs) -> BaseDataMixinT:
"""Select one or more features by index.
Returns a new instance."""
raise NotImplementedError
def select_symbol_idxs(self: BaseDataMixinT, idxs: tp.MaybeSequence[int], **kwargs) -> BaseDataMixinT:
"""Select one or more symbols by index.
Returns a new instance."""
raise NotImplementedError
def select_features(self: BaseDataMixinT, features: tp.MaybeFeatures, **kwargs) -> BaseDataMixinT:
"""Select one or more features.
Returns a new instance."""
if self.has_multiple_keys(features):
feature_idxs = [self.get_feature_idx(k, raise_error=True) for k in features]
else:
feature_idxs = self.get_feature_idx(features, raise_error=True)
return self.select_feature_idxs(feature_idxs, **kwargs)
def select_symbols(self: BaseDataMixinT, symbols: tp.MaybeSymbols, **kwargs) -> BaseDataMixinT:
"""Select one or more symbols.
Returns a new instance."""
if self.has_multiple_keys(symbols):
symbol_idxs = [self.get_symbol_idx(k, raise_error=True) for k in symbols]
else:
symbol_idxs = self.get_symbol_idx(symbols, raise_error=True)
return self.select_symbol_idxs(symbol_idxs, **kwargs)
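# Hedged usage sketch, assuming a populated multi-feature, multi-symbol instance `data`:
#
#     >>> close_only = data.select_features("Close")               # new instance, one feature
#     >>> btc_eth = data.select_symbols(["BTC-USD", "ETH-USD"])    # new instance, two symbols
#
# Both resolve labels case-insensitively via `get_feature_idx`/`get_symbol_idx` and then
# delegate to the index-based selectors above.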
def get(
self,
features: tp.Optional[tp.MaybeFeatures] = None,
symbols: tp.Optional[tp.MaybeSymbols] = None,
feature: tp.Optional[tp.Feature] = None,
symbol: tp.Optional[tp.Symbol] = None,
**kwargs,
) -> tp.MaybeTuple[tp.SeriesFrame]:
"""Get one or more features of one or more symbols of data."""
raise NotImplementedError
def has_feature(self, feature: tp.Feature) -> bool:
"""Whether feature exists."""
feature_idx = self.get_feature_idx(feature, raise_error=False)
return feature_idx != -1
def has_symbol(self, symbol: tp.Symbol) -> bool:
"""Whether symbol exists."""
symbol_idx = self.get_symbol_idx(symbol, raise_error=False)
return symbol_idx != -1
def assert_has_feature(self, feature: tp.Feature) -> None:
"""Assert that feature exists."""
self.get_feature_idx(feature, raise_error=True)
def assert_has_symbol(self, symbol: tp.Symbol) -> None:
"""Assert that symbol exists."""
self.get_symbol_idx(symbol, raise_error=True)
def get_feature(
self,
feature: tp.Union[int, tp.Feature],
raise_error: bool = False,
) -> tp.Optional[tp.SeriesFrame]:
"""Get feature that match a feature index or label."""
if checks.is_int(feature):
return self.get(features=self.features[feature])
feature_idx = self.get_feature_idx(feature, raise_error=raise_error)
if feature_idx == -1:
return None
return self.get(features=self.features[feature_idx])
def get_symbol(
self,
symbol: tp.Union[int, tp.Symbol],
raise_error: bool = False,
) -> tp.Optional[tp.SeriesFrame]:
"""Get symbol that match a symbol index or label."""
if checks.is_int(symbol):
return self.get(symbol=self.symbols[symbol])
symbol_idx = self.get_symbol_idx(symbol, raise_error=raise_error)
if symbol_idx == -1:
return None
return self.get(symbol=self.symbols[symbol_idx])
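# Hedged sketch of plausible calls given the signatures above (the concrete behavior is
# implemented by `Data`): `get` is the generic accessor, while `get_feature`/`get_symbol`
# resolve a single label or integer position first.
#
#     >>> data.get(feature="Close", symbol="BTC-USD")   # one feature of one symbol
#     >>> data.get_feature("close")                     # case-insensitive label lookup
#     >>> data.get_symbol(0)                            # first symbol by position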
OHLCDataMixinT = tp.TypeVar("OHLCDataMixinT", bound="OHLCDataMixin")
class OHLCDataMixin(BaseDataMixin):
"""Mixin class for working with OHLC data."""
@property
def open(self) -> tp.Optional[tp.SeriesFrame]:
"""Open."""
return self.get_feature("Open")
@property
def high(self) -> tp.Optional[tp.SeriesFrame]:
"""High."""
return self.get_feature("High")
@property
def low(self) -> tp.Optional[tp.SeriesFrame]:
"""Low."""
return self.get_feature("Low")
@property
def close(self) -> tp.Optional[tp.SeriesFrame]:
"""Close."""
return self.get_feature("Close")
@property
def volume(self) -> tp.Optional[tp.SeriesFrame]:
"""Volume."""
return self.get_feature("Volume")
@property
def trade_count(self) -> tp.Optional[tp.SeriesFrame]:
"""Trade count."""
return self.get_feature("Trade count")
@property
def vwap(self) -> tp.Optional[tp.SeriesFrame]:
"""VWAP."""
return self.get_feature("VWAP")
@property
def hlc3(self) -> tp.SeriesFrame:
"""HLC/3."""
high = self.get_feature("High", raise_error=True)
low = self.get_feature("Low", raise_error=True)
close = self.get_feature("Close", raise_error=True)
return (high + low + close) / 3
@property
def ohlc4(self) -> tp.SeriesFrame:
"""OHLC/4."""
open = self.get_feature("Open", raise_error=True)
high = self.get_feature("High", raise_error=True)
low = self.get_feature("Low", raise_error=True)
close = self.get_feature("Close", raise_error=True)
return (open + high + low + close) / 4
@property
def has_any_ohlc(self) -> bool:
"""Whether the instance has any of the OHLC features."""
return (
self.has_feature("Open") or self.has_feature("High") or self.has_feature("Low") or self.has_feature("Close")
)
@property
def has_ohlc(self) -> bool:
"""Whether the instance has all the OHLC features."""
return (
self.has_feature("Open")
and self.has_feature("High")
and self.has_feature("Low")
and self.has_feature("Close")
)
@property
def has_any_ohlcv(self) -> bool:
"""Whether the instance has any of the OHLCV features."""
return self.has_any_ohlc or self.has_feature("Volume")
@property
def has_ohlcv(self) -> bool:
"""Whether the instance has all the OHLCV features."""
return self.has_ohlc and self.has_feature("Volume")
@property
def ohlc(self: OHLCDataMixinT) -> OHLCDataMixinT:
"""Return a `OHLCDataMixin` instance with the OHLC features only."""
open_idx = self.get_feature_idx("Open", raise_error=True)
high_idx = self.get_feature_idx("High", raise_error=True)
low_idx = self.get_feature_idx("Low", raise_error=True)
close_idx = self.get_feature_idx("Close", raise_error=True)
return self.select_feature_idxs([open_idx, high_idx, low_idx, close_idx])
@property
def ohlcv(self: OHLCDataMixinT) -> OHLCDataMixinT:
"""Return a `OHLCDataMixin` instance with the OHLCV features only."""
open_idx = self.get_feature_idx("Open", raise_error=True)
high_idx = self.get_feature_idx("High", raise_error=True)
low_idx = self.get_feature_idx("Low", raise_error=True)
close_idx = self.get_feature_idx("Close", raise_error=True)
volume_idx = self.get_feature_idx("Volume", raise_error=True)
return self.select_feature_idxs([open_idx, high_idx, low_idx, close_idx, volume_idx])
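# Sketch: `data.ohlc` and `data.ohlcv` return new instances restricted to the respective
# features (raising if any of them is missing), so e.g. `data.ohlcv.get()` yields only the
# Open/High/Low/Close/Volume columns.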
def get_returns_acc(self, **kwargs) -> ReturnsAccessor:
"""Return accessor of type `vectorbtpro.returns.accessors.ReturnsAccessor`."""
return ReturnsAccessor.from_value(
self.get_feature("Close", raise_error=True),
wrapper=self.symbol_wrapper,
return_values=False,
**kwargs,
)
@property
def returns_acc(self) -> ReturnsAccessor:
"""`OHLCDataMixin.get_returns_acc` with default arguments."""
return self.get_returns_acc()
def get_returns(self, **kwargs) -> tp.SeriesFrame:
"""Returns."""
return ReturnsAccessor.from_value(
self.get_feature("Close", raise_error=True),
wrapper=self.symbol_wrapper,
return_values=True,
**kwargs,
)
@property
def returns(self) -> tp.SeriesFrame:
"""`OHLCDataMixin.get_returns` with default arguments."""
return self.get_returns()
def get_log_returns(self, **kwargs) -> tp.SeriesFrame:
"""Log returns."""
return ReturnsAccessor.from_value(
self.get_feature("Close", raise_error=True),
wrapper=self.symbol_wrapper,
return_values=True,
log_returns=True,
**kwargs,
)
@property
def log_returns(self) -> tp.SeriesFrame:
"""`OHLCDataMixin.get_log_returns` with default arguments."""
return self.get_log_returns()
def get_daily_returns(self, **kwargs) -> tp.SeriesFrame:
"""Daily returns."""
return ReturnsAccessor.from_value(
self.get_feature("Close", raise_error=True),
wrapper=self.symbol_wrapper,
return_values=False,
**kwargs,
).daily()
@property
def daily_returns(self) -> tp.SeriesFrame:
"""`OHLCDataMixin.get_daily_returns` with default arguments."""
return self.get_daily_returns()
def get_daily_log_returns(self, **kwargs) -> tp.SeriesFrame:
"""Daily log returns."""
return ReturnsAccessor.from_value(
self.get_feature("Close", raise_error=True),
wrapper=self.symbol_wrapper,
return_values=False,
log_returns=True,
**kwargs,
).daily()
@property
def daily_log_returns(self) -> tp.SeriesFrame:
"""`OHLCDataMixin.get_daily_log_returns` with default arguments."""
return self.get_daily_log_returns()
def get_drawdowns(self, **kwargs) -> Drawdowns:
"""Generate drawdown records.
See `vectorbtpro.generic.drawdowns.Drawdowns`."""
return Drawdowns.from_price(
open=self.get_feature("Open", raise_error=True),
high=self.get_feature("High", raise_error=True),
low=self.get_feature("Low", raise_error=True),
close=self.get_feature("Close", raise_error=True),
**kwargs,
)
@property
def drawdowns(self) -> Drawdowns:
"""`OHLCDataMixin.get_drawdowns` with default arguments."""
return self.get_drawdowns()
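# Hedged usage sketch, assuming `data` holds at least a "Close" feature (and full OHLC for
# drawdowns):
#
#     >>> rets = data.returns            # simple returns derived from "Close"
#     >>> log_rets = data.log_returns    # log returns
#     >>> dd = data.drawdowns            # Drawdowns records built from OHLC prices
#
# Keyword arguments can be forwarded through the corresponding `get_*` methods instead of
# the argument-free properties.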
DataT = tp.TypeVar("DataT", bound="Data")
class MetaData(type(Analyzable)):
"""Metaclass for `Data`."""
@property
def feature_config(cls) -> Config:
"""Feature config."""
return cls._feature_config
@attach_symbol_dict_methods
class Data(Analyzable, OHLCDataMixin, metaclass=MetaData):
"""Class that downloads, updates, and manages data coming from a data source."""
_settings_path: tp.SettingsPath = dict(base="data")
_writeable_attrs: tp.WriteableAttrs = {"_feature_config"}
_feature_config: tp.ClassVar[Config] = HybridConfig()
_key_dict_attrs = [
"fetch_kwargs",
"returned_kwargs",
"last_index",
"delisted",
"classes",
]
"""Attributes that subclass either `feature_dict` or `symbol_dict`."""
_data_dict_type_attrs = [
"classes",
]
"""Attributes that subclass the data dict type."""
_updatable_attrs = [
"fetch_kwargs",
"returned_kwargs",
"classes",
]
"""Attributes that have a method for updating."""
@property
def feature_config(self) -> Config:
"""Column config of `${cls_name}`.
```python
${feature_config}
```
Returns `${cls_name}._feature_config`, which gets (hybrid-) copied upon creation of each instance.
Thus, changing this config won't affect the class.
To change fields, you can either change the config in-place, override this property,
or overwrite the instance variable `${cls_name}._feature_config`.
"""
return self._feature_config
def use_feature_config_of(self, cls: tp.Type[DataT]) -> None:
"""Copy feature config from another `Data` class."""
self._feature_config = cls.feature_config.copy()
@classmethod
def modify_state(cls, rec_state: RecState) -> RecState:
# Ensure backward compatibility
if "_column_config" in rec_state.attr_dct and "_feature_config" not in rec_state.attr_dct:
new_attr_dct = dict(rec_state.attr_dct)
new_attr_dct["_feature_config"] = new_attr_dct.pop("_column_config")
rec_state = RecState(
init_args=rec_state.init_args,
init_kwargs=rec_state.init_kwargs,
attr_dct=new_attr_dct,
)
if "single_symbol" in rec_state.init_kwargs and "single_key" not in rec_state.init_kwargs:
new_init_kwargs = dict(rec_state.init_kwargs)
new_init_kwargs["single_key"] = new_init_kwargs.pop("single_symbol")
rec_state = RecState(
init_args=rec_state.init_args,
init_kwargs=new_init_kwargs,
attr_dct=rec_state.attr_dct,
)
if "symbol_classes" in rec_state.init_kwargs and "classes" not in rec_state.init_kwargs:
new_init_kwargs = dict(rec_state.init_kwargs)
new_init_kwargs["classes"] = new_init_kwargs.pop("symbol_classes")
rec_state = RecState(
init_args=rec_state.init_args,
init_kwargs=new_init_kwargs,
attr_dct=rec_state.attr_dct,
)
return rec_state
@classmethod
def fix_data_dict_type(cls, data: dict) -> tp.Union[feature_dict, symbol_dict]:
"""Fix dict type for data."""
checks.assert_instance_of(data, dict, arg_name="data")
if not isinstance(data, key_dict):
data = symbol_dict(data)
return data
@classmethod
def fix_dict_types_in_kwargs(
cls,
data_type: tp.Type[tp.Union[feature_dict, symbol_dict]],
**kwargs: tp.Kwargs,
) -> tp.Kwargs:
"""Fix dict types in keyword arguments."""
for attr in cls._key_dict_attrs:
if attr in kwargs:
attr_value = kwargs[attr]
if attr_value is None:
attr_value = {}
checks.assert_instance_of(attr_value, dict, arg_name=attr)
if not isinstance(attr_value, key_dict):
attr_value = data_type(attr_value)
if attr in cls._data_dict_type_attrs:
checks.assert_instance_of(attr_value, data_type, arg_name=attr)
kwargs[attr] = attr_value
return kwargs
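# For example, a plain dict passed for one of the `_key_dict_attrs` gets wrapped into the
# dict type of the data, while attrs in `_data_dict_type_attrs` must already match it:
#
#     >>> Data.fix_dict_types_in_kwargs(symbol_dict, fetch_kwargs={"BTC-USD": dict(start="2020")})
#     {'fetch_kwargs': symbol_dict({'BTC-USD': {'start': '2020'}})}
#
# (Output shown schematically; the exact repr may differ.)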
@hybrid_method
def row_stack(
cls_or_self: tp.MaybeType[DataT],
*objs: tp.MaybeTuple[DataT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> DataT:
"""Stack multiple `Data` instances along rows.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.row_stack` to stack the wrappers."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, Data):
raise TypeError("Each object to be merged must be an instance of Data")
if "wrapper" not in kwargs:
if wrapper_kwargs is None:
wrapper_kwargs = {}
kwargs["wrapper"] = ArrayWrapper.row_stack(*[obj.wrapper for obj in objs], **wrapper_kwargs)
keys = set()
for obj in objs:
keys = keys.union(set(obj.data.keys()))
data_type = None
for obj in objs:
if len(keys.difference(set(obj.data.keys()))) > 0:
if isinstance(obj.data, feature_dict):
raise ValueError("Objects to be merged must have the same features")
else:
raise ValueError("Objects to be merged must have the same symbols")
if data_type is None:
data_type = type(obj.data)
elif not isinstance(obj.data, data_type):
raise TypeError("Objects to be merged must have the same dict type for data")
if "data" not in kwargs:
new_data = data_type()
for k in objs[0].data.keys():
new_data[k] = kwargs["wrapper"].row_stack_arrs(*[obj.data[k] for obj in objs], group_by=False)
kwargs["data"] = new_data
kwargs["data"] = cls.fix_data_dict_type(kwargs["data"])
for attr in cls._key_dict_attrs:
if attr not in kwargs:
attr_data_type = None
for obj in objs:
v = getattr(obj, attr)
if attr_data_type is None:
attr_data_type = type(v)
elif not isinstance(v, attr_data_type):
raise TypeError(f"Objects to be merged must have the same dict type for '{attr}'")
kwargs[attr] = getattr(objs[-1], attr)
kwargs = cls.resolve_row_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
kwargs = cls.fix_dict_types_in_kwargs(type(kwargs["data"]), **kwargs)
return cls(**kwargs)
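# Hedged usage sketch: stacking two instances that share keys/columns but cover consecutive
# date ranges (e.g. two separate pulls of the same symbols):
#
#     >>> full_data = Data.row_stack(data_2020, data_2021)
#
# Being a hybrid method, it can also be called on an instance: `data_2020.row_stack(data_2021)`.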
@hybrid_method
def column_stack(
cls_or_self: tp.MaybeType[DataT],
*objs: tp.MaybeTuple[DataT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> DataT:
"""Stack multiple `Data` instances along columns.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, Data):
raise TypeError("Each object to be merged must be an instance of Data")
if "wrapper" not in kwargs:
if wrapper_kwargs is None:
wrapper_kwargs = {}
kwargs["wrapper"] = ArrayWrapper.column_stack(
*[obj.wrapper for obj in objs],
**wrapper_kwargs,
)
keys = set()
for obj in objs:
keys = keys.union(set(obj.data.keys()))
data_type = None
for obj in objs:
if len(keys.difference(set(obj.data.keys()))) > 0:
if isinstance(obj.data, feature_dict):
raise ValueError("Objects to be merged must have the same features")
else:
raise ValueError("Objects to be merged must have the same symbols")
if data_type is None:
data_type = type(obj.data)
elif not isinstance(obj.data, data_type):
raise TypeError("Objects to be merged must have the same dict type for data")
if "data" not in kwargs:
new_data = data_type()
for k in objs[0].data.keys():
new_data[k] = kwargs["wrapper"].column_stack_arrs(*[obj.data[k] for obj in objs], group_by=False)
kwargs["data"] = new_data
kwargs["data"] = cls.fix_data_dict_type(kwargs["data"])
for attr in cls._key_dict_attrs:
if attr not in kwargs:
attr_data_type = None
for obj in objs:
v = getattr(obj, attr)
if attr_data_type is None:
attr_data_type = type(v)
elif not isinstance(v, attr_data_type):
raise TypeError(f"Objects to be merged must have the same dict type for '{attr}'")
if (issubclass(data_type, feature_dict) and issubclass(attr_data_type, symbol_dict)) or (
issubclass(data_type, symbol_dict) and issubclass(attr_data_type, feature_dict)
):
kwargs[attr] = attr_data_type()
for obj in objs:
kwargs[attr].update(**getattr(obj, attr))
kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
kwargs = cls.fix_dict_types_in_kwargs(type(kwargs["data"]), **kwargs)
return cls(**kwargs)
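# Hedged usage sketch: combining instances that share the index but hold different columns
# (e.g. different symbols fetched separately):
#
#     >>> combined = Data.column_stack(data_btc, data_eth)
#
# As implemented above, per-key attributes stored under the opposite orientation (e.g.
# `fetch_kwargs` as a `symbol_dict` under feature-oriented data) are merged across the
# stacked objects.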
def __init__(
self,
wrapper: ArrayWrapper,
data: tp.Union[feature_dict, symbol_dict],
single_key: bool = True,
classes: tp.Union[None, feature_dict, symbol_dict] = None,
level_name: tp.Union[None, bool, tp.MaybeIterable[tp.Hashable]] = None,
fetch_kwargs: tp.Union[None, feature_dict, symbol_dict] = None,
returned_kwargs: tp.Union[None, feature_dict, symbol_dict] = None,
last_index: tp.Union[None, feature_dict, symbol_dict] = None,
delisted: tp.Union[None, feature_dict, symbol_dict] = None,
tz_localize: tp.Union[None, bool, tp.TimezoneLike] = None,
tz_convert: tp.Union[None, bool, tp.TimezoneLike] = None,
missing_index: tp.Optional[str] = None,
missing_columns: tp.Optional[str] = None,
**kwargs,
) -> None:
Analyzable.__init__(
self,
wrapper,
data=data,
single_key=single_key,
classes=classes,
level_name=level_name,
fetch_kwargs=fetch_kwargs,
returned_kwargs=returned_kwargs,
last_index=last_index,
delisted=delisted,
tz_localize=tz_localize,
tz_convert=tz_convert,
missing_index=missing_index,
missing_columns=missing_columns,
**kwargs,
)
if len(set(map(self.prepare_key, data.keys()))) < len(list(map(self.prepare_key, data.keys()))):
raise ValueError("Found duplicate keys in data dictionary")
data = self.fix_data_dict_type(data)
for obj in data.values():
checks.assert_meta_equal(obj, data[list(data.keys())[0]])
if len(data) > 1:
single_key = False
self._data = data
self._single_key = single_key
self._level_name = level_name
self._tz_localize = tz_localize
self._tz_convert = tz_convert
self._missing_index = missing_index
self._missing_columns = missing_columns
attr_kwargs = dict()
for attr in self._key_dict_attrs:
attr_value = locals()[attr]
attr_kwargs[attr] = attr_value
attr_kwargs = self.fix_dict_types_in_kwargs(type(data), **attr_kwargs)
for k, v in attr_kwargs.items():
setattr(self, "_" + k, v)
# Copy writeable attrs
self._feature_config = type(self)._feature_config.copy()
def replace(self: DataT, **kwargs) -> DataT:
"""See `vectorbtpro.utils.config.Configured.replace`.
Replaces the data's index and/or columns if they were changed in the wrapper."""
if "wrapper" in kwargs and "data" not in kwargs:
wrapper = kwargs["wrapper"]
if isinstance(wrapper, dict):
new_index = wrapper.get("index", self.wrapper.index)
new_columns = wrapper.get("columns", self.wrapper.columns)
else:
new_index = wrapper.index
new_columns = wrapper.columns
data = self.config["data"]
new_data = {}
index_changed = False
columns_changed = False
for k, v in data.items():
if isinstance(v, (pd.Series, pd.DataFrame)):
if not checks.is_index_equal(v.index, new_index):
v = v.copy(deep=False)
v.index = new_index
index_changed = True
if isinstance(v, pd.DataFrame):
if not checks.is_index_equal(v.columns, new_columns):
v = v.copy(deep=False)
v.columns = new_columns
columns_changed = True
new_data[k] = v
if index_changed or columns_changed:
kwargs["data"] = self.fix_data_dict_type(new_data)
if columns_changed:
rename = dict(zip(self.keys, new_columns))
for attr in self._key_dict_attrs:
if attr not in kwargs:
attr_value = getattr(self, attr)
if (self.feature_oriented and isinstance(attr_value, symbol_dict)) or (
self.symbol_oriented and isinstance(attr_value, feature_dict)
):
kwargs[attr] = self.rename_in_dict(getattr(self, attr), rename)
kwargs = self.fix_dict_types_in_kwargs(type(kwargs.get("data", self.data)), **kwargs)
return Analyzable.replace(self, **kwargs)
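# Hedged sketch: when only the wrapper changes, the stored pandas objects are re-labeled to
# match, e.g.
#
#     >>> new_data = data.replace(wrapper=data.wrapper.replace(index=new_index))
#
# re-assigns `new_index` to every Series/DataFrame in `data.data` (lengths must match), and
# renamed columns are propagated to the per-key attribute dicts of the opposite orientation.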
def indexing_func(self: DataT, *args, replace_kwargs: tp.KwargsLike = None, **kwargs) -> DataT:
"""Perform indexing on `Data`."""
if replace_kwargs is None:
replace_kwargs = {}
wrapper_meta = self.wrapper.indexing_func_meta(*args, **kwargs)
new_wrapper = wrapper_meta["new_wrapper"]
new_data = self.dict_type()
for k, v in self._data.items():
if wrapper_meta["rows_changed"]:
v = v.iloc[wrapper_meta["row_idxs"]]
if wrapper_meta["columns_changed"]:
v = v.iloc[:, wrapper_meta["col_idxs"]]
new_data[k] = v
attr_dicts = dict()
attr_dicts["last_index"] = type(self.last_index)()
for k in self.last_index:
attr_dicts["last_index"][k] = min([self.last_index[k], new_wrapper.index[-1]])
if wrapper_meta["columns_changed"]:
new_symbols = new_wrapper.columns
for attr in self._key_dict_attrs:
attr_value = getattr(self, attr)
if (self.feature_oriented and isinstance(attr_value, symbol_dict)) or (
self.symbol_oriented and isinstance(attr_value, feature_dict)
):
if attr in attr_dicts:
attr_dicts[attr] = self.select_from_dict(attr_dicts[attr], new_symbols)
else:
attr_dicts[attr] = self.select_from_dict(attr_value, new_symbols)
return self.replace(wrapper=new_wrapper, data=new_data, **attr_dicts, **replace_kwargs)
@property
def data(self) -> tp.Union[feature_dict, symbol_dict]:
"""Data dictionary.
Has the type `feature_dict` for feature-oriented data or `symbol_dict` for symbol-oriented data."""
return self._data
@property
def dict_type(self) -> tp.Type[tp.Union[feature_dict, symbol_dict]]:
"""Return the dict type."""
return type(self.data)
@property
def column_type(self) -> tp.Type[tp.Union[feature_dict, symbol_dict]]:
"""Return the column type."""
if isinstance(self.data, feature_dict):
return symbol_dict
return feature_dict
@property
def feature_oriented(self) -> bool:
"""Whether data has features as keys."""
return issubclass(self.dict_type, feature_dict)
@property
def symbol_oriented(self) -> bool:
"""Whether data has symbols as keys."""
return issubclass(self.dict_type, symbol_dict)
def get_keys(self, dict_type: tp.Type[tp.Union[feature_dict, symbol_dict]]) -> tp.List[tp.Key]:
"""Get keys depending on the provided dict type."""
checks.assert_subclass_of(dict_type, (feature_dict, symbol_dict), arg_name="dict_type")
if issubclass(dict_type, feature_dict):
return self.features
return self.symbols
@property
def keys(self) -> tp.List[tp.Union[tp.Feature, tp.Symbol]]:
"""Keys in data.
Features if `feature_dict` and symbols if `symbol_dict`."""
return list(self.data.keys())
@property
def single_key(self) -> bool:
"""Whether there is only one key in `Data.data`."""
return self._single_key
@property
def single_feature(self) -> bool:
"""Whether there is only one feature in `Data.data`."""
if self.feature_oriented:
return self.single_key
return self.wrapper.ndim == 1
@property
def single_symbol(self) -> bool:
"""Whether there is only one symbol in `Data.data`."""
if self.symbol_oriented:
return self.single_key
return self.wrapper.ndim == 1
@property
def classes(self) -> tp.Union[feature_dict, symbol_dict]:
"""Key classes."""
return self._classes
@property
def feature_classes(self) -> tp.Optional[feature_dict]:
"""Feature classes."""
if self.feature_oriented:
return self.classes
return None
@property
def symbol_classes(self) -> tp.Optional[symbol_dict]:
"""Symbol classes."""
if self.symbol_oriented:
return self.classes
return None
@hybrid_method
def get_level_name(
cls_or_self,
keys: tp.Optional[tp.Keys] = None,
level_name: tp.Union[None, bool, tp.MaybeIterable[tp.Hashable]] = None,
feature_oriented: tp.Optional[bool] = None,
) -> tp.Optional[tp.MaybeIterable[tp.Hashable]]:
"""Get level name(s) for keys."""
if isinstance(cls_or_self, type):
checks.assert_not_none(keys, arg_name="keys")
checks.assert_not_none(feature_oriented, arg_name="feature_oriented")
else:
if keys is None:
keys = cls_or_self.keys
if level_name is None:
level_name = cls_or_self._level_name
if feature_oriented is None:
feature_oriented = cls_or_self.feature_oriented
first_key = keys[0]
if isinstance(level_name, bool):
if level_name:
level_name = None
else:
return None
if feature_oriented:
key_prefix = "feature"
else:
key_prefix = "symbol"
if isinstance(first_key, tuple):
if level_name is None:
level_name = ["%s_%d" % (key_prefix, i) for i in range(len(first_key))]
if not checks.is_iterable(level_name) or isinstance(level_name, str):
raise TypeError("Level name should be list-like for a MultiIndex")
return tuple(level_name)
if level_name is None:
level_name = key_prefix
return level_name
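# Following the logic above: plain keys get a single level name ("feature" or "symbol" unless
# overridden), tuple keys get one name per tuple element ("symbol_0", "symbol_1", ... by
# default), and `level_name=False` disables level names entirely (the key index is then built
# without a name).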
@property
def level_name(self) -> tp.Optional[tp.MaybeIterable[tp.Hashable]]:
"""Level name(s) for keys.
Keys are symbols or features depending on the data dict type.
Must be a sequence if keys are tuples, otherwise a hashable.
If False, no level names will be used."""
return self.get_level_name()
@hybrid_method
def get_key_index(
cls_or_self,
keys: tp.Optional[tp.Keys] = None,
level_name: tp.Union[None, bool, tp.MaybeIterable[tp.Hashable]] = None,
feature_oriented: tp.Optional[bool] = None,
) -> tp.Index:
"""Get key index."""
if isinstance(cls_or_self, type):
checks.assert_not_none(keys, arg_name="keys")
else:
if keys is None:
keys = cls_or_self.keys
level_name = cls_or_self.get_level_name(keys=keys, level_name=level_name, feature_oriented=feature_oriented)
if isinstance(level_name, tuple):
return pd.MultiIndex.from_tuples(keys, names=level_name)
return pd.Index(keys, name=level_name)
@property
def key_index(self) -> tp.Index:
"""Key index."""
return self.get_key_index()
@property
def fetch_kwargs(self) -> tp.Union[feature_dict, symbol_dict]:
"""Keyword arguments of type `symbol_dict` initially passed to `Data.fetch_symbol`."""
return self._fetch_kwargs
@property
def returned_kwargs(self) -> tp.Union[feature_dict, symbol_dict]:
"""Keyword arguments of type `symbol_dict` returned by `Data.fetch_symbol`."""
return self._returned_kwargs
@property
def last_index(self) -> tp.Union[feature_dict, symbol_dict]:
"""Last fetched index per symbol of type `symbol_dict`."""
return self._last_index
@property
def delisted(self) -> tp.Union[feature_dict, symbol_dict]:
"""Delisted flag per symbol of type `symbol_dict`."""
return self._delisted
@property
def tz_localize(self) -> tp.Union[None, bool, tp.TimezoneLike]:
"""Timezone to localize a datetime-naive index to, which is initially passed to `Data.pull`."""
return self._tz_localize
@property
def tz_convert(self) -> tp.Union[None, bool, tp.TimezoneLike]:
"""Timezone to convert a datetime-aware to, which is initially passed to `Data.pull`."""
return self._tz_convert
@property
def missing_index(self) -> tp.Optional[str]:
"""Argument `missing` passed to `Data.align_index`."""
return self._missing_index
@property
def missing_columns(self) -> tp.Optional[str]:
"""Argument `missing` passed to `Data.align_columns`."""
return self._missing_columns
# ############# Settings ############# #
@classmethod
def get_base_settings(cls, *args, **kwargs) -> dict:
"""`CustomData.get_settings` with `path_id="base"`."""
return cls.get_settings(*args, path_id="base", **kwargs)
@classmethod
def has_base_settings(cls, *args, **kwargs) -> bool:
"""`CustomData.has_settings` with `path_id="base"`."""
return cls.has_settings(*args, path_id="base", **kwargs)
@classmethod
def get_base_setting(cls, *args, **kwargs) -> tp.Any:
"""`CustomData.get_setting` with `path_id="base"`."""
return cls.get_setting(*args, path_id="base", **kwargs)
@classmethod
def has_base_setting(cls, *args, **kwargs) -> bool:
"""`CustomData.has_setting` with `path_id="base"`."""
return cls.has_setting(*args, path_id="base", **kwargs)
@classmethod
def resolve_base_setting(cls, *args, **kwargs) -> tp.Any:
"""`CustomData.resolve_setting` with `path_id="base"`."""
return cls.resolve_setting(*args, path_id="base", **kwargs)
@classmethod
def set_base_settings(cls, *args, **kwargs) -> None:
"""`CustomData.set_settings` with `path_id="base"`."""
cls.set_settings(*args, path_id="base", **kwargs)
# ############# Iteration ############# #
def items(
self,
over: str = "symbols",
group_by: tp.GroupByLike = None,
apply_group_by: bool = False,
keep_2d: bool = False,
key_as_index: bool = False,
) -> tp.Items:
"""Iterate over columns (or groups if grouped and `Wrapping.group_select` is True), keys,
features, or symbols. The respective mode can be selected with `over`.
See `vectorbtpro.base.wrapping.Wrapping.items` for iteration over columns.
Iteration over keys supports `group_by` but doesn't support `apply_group_by`."""
if (
over.lower() == "columns"
or (over.lower() == "symbols" and self.feature_oriented)
or (over.lower() == "features" and self.symbol_oriented)
):
for k, v in Analyzable.items(
self,
group_by=group_by,
apply_group_by=apply_group_by,
keep_2d=keep_2d,
key_as_index=key_as_index,
):
yield k, v
elif (
over.lower() == "keys"
or (over.lower() == "features" and self.feature_oriented)
or (over.lower() == "symbols" and self.symbol_oriented)
):
if apply_group_by:
raise ValueError("Cannot apply grouping to keys")
if group_by is not None:
key_wrapper = self.get_key_wrapper(group_by=group_by)
if key_wrapper.get_ndim() == 1:
if key_as_index:
yield key_wrapper.get_columns(), self
else:
yield key_wrapper.get_columns()[0], self
else:
for group, group_idxs in key_wrapper.grouper.iter_groups(key_as_index=key_as_index):
if keep_2d or len(group_idxs) > 1:
yield group, self.select_keys([self.keys[i] for i in group_idxs])
else:
yield group, self.select_keys(self.keys[group_idxs[0]])
else:
key_wrapper = self.get_key_wrapper(attach_classes=False)
if key_wrapper.ndim == 1:
if key_as_index:
yield key_wrapper.columns, self
else:
yield key_wrapper.columns[0], self
else:
for i in range(len(key_wrapper.columns)):
if key_as_index:
key = key_wrapper.columns[[i]]
else:
key = key_wrapper.columns[i]
if keep_2d:
yield key, self.select_keys([key])
else:
yield key, self.select_keys(key)
else:
raise ValueError(f"Invalid over: '{over}'")
# ############# Getting ############# #
def get_key_wrapper(
self,
keys: tp.Optional[tp.MaybeKeys] = None,
attach_classes: bool = True,
clean_index_kwargs: tp.KwargsLike = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> ArrayWrapper:
"""Get wrapper with keys as columns.
If `attach_classes` is True, attaches `Data.classes` by stacking them over
the keys using `vectorbtpro.base.indexes.stack_indexes`.
Other keyword arguments are passed to the constructor of the wrapper."""
if clean_index_kwargs is None:
clean_index_kwargs = {}
if keys is None:
keys = self.keys
ndim = 1 if self.single_key else 2
else:
if self.has_multiple_keys(keys):
ndim = 2
else:
keys = [keys]
ndim = 1
for key in keys:
if self.feature_oriented:
self.assert_has_feature(key)
else:
self.assert_has_symbol(key)
new_columns = self.get_key_index(keys=keys)
wrapper = self.wrapper.replace(
columns=new_columns,
ndim=ndim,
grouper=None,
**kwargs,
)
if attach_classes:
classes = []
all_have_classes = True
for key in wrapper.columns:
if key in self.classes:
key_classes = self.classes[key]
if len(key_classes) > 0:
classes.append(key_classes)
else:
all_have_classes = False
else:
all_have_classes = False
if len(classes) > 0 and not all_have_classes:
if self.feature_oriented:
raise ValueError("Some features have classes while others not")
else:
raise ValueError("Some symbols have classes while others not")
if len(classes) > 0:
classes_frame = pd.DataFrame(classes)
if len(classes_frame.columns) == 1:
classes_columns = pd.Index(classes_frame.iloc[:, 0])
else:
classes_columns = pd.MultiIndex.from_frame(classes_frame)
new_columns = stack_indexes((classes_columns, wrapper.columns), **clean_index_kwargs)
wrapper = wrapper.replace(columns=new_columns)
if group_by is not None:
wrapper = wrapper.replace(group_by=group_by)
return wrapper
@cached_property
def key_wrapper(self) -> ArrayWrapper:
"""Key wrapper."""
return self.get_key_wrapper()
def get_feature_wrapper(self, features: tp.Optional[tp.MaybeFeatures] = None, **kwargs) -> ArrayWrapper:
"""Get wrapper with features as columns."""
if self.feature_oriented: