
Switch to T_DataArray and T_Dataset in concat #6784


Merged (12 commits) on Jul 18, 2022
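This PR swaps the concrete Dataset / DataArray annotations in concat's overloads and private helpers for the T_Dataset / T_DataArray TypeVars imported from xarray.core.types, so a static type checker can carry the exact input type (including subclasses) through to the return value. A minimal sketch of the idea, assuming the TypeVars are bound to the xarray classes (the real definitions live in xarray/core/types.py):

from typing import TypeVar

from xarray import DataArray, Dataset

# Sketch only: TypeVars bound to the xarray classes. Because each TypeVar is
# bound, an overload annotated with T_Dataset that receives a Dataset subclass
# is reported by the checker as returning that same subclass.
T_Dataset = TypeVar("T_Dataset", bound=Dataset)
T_DataArray = TypeVar("T_DataArray", bound=DataArray)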
xarray/core/concat.py: 27 changes (14 additions, 13 deletions)
@@ -14,6 +14,7 @@
     merge_attrs,
     merge_collected,
 )
+from .types import T_DataArray, T_Dataset
 from .variable import Variable
 from .variable import concat as concat_vars

@@ -25,31 +26,31 @@

 @overload
 def concat(
-    objs: Iterable[Dataset],
-    dim: Hashable | DataArray | pd.Index,
+    objs: Iterable[T_Dataset],
+    dim: Hashable | T_DataArray | pd.Index,
     data_vars: ConcatOptions | list[Hashable] = "all",
     coords: ConcatOptions | list[Hashable] = "different",
     compat: CompatOptions = "equals",
     positions: Iterable[Iterable[int]] | None = None,
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
-) -> Dataset:
+) -> T_Dataset:
     ...


 @overload
 def concat(
-    objs: Iterable[DataArray],
-    dim: Hashable | DataArray | pd.Index,
+    objs: Iterable[T_DataArray],
+    dim: Hashable | T_DataArray | pd.Index,
     data_vars: ConcatOptions | list[Hashable] = "all",
     coords: ConcatOptions | list[Hashable] = "different",
     compat: CompatOptions = "equals",
     positions: Iterable[Iterable[int]] | None = None,
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
-) -> DataArray:
+) -> T_DataArray:
     ...


@@ -402,7 +403,7 @@ def process_subset_opt(opt, subset):

 # determine dimensional coordinate names and a dict mapping name to DataArray
 def _parse_datasets(
-    datasets: Iterable[Dataset],
+    datasets: Iterable[T_Dataset],
 ) -> tuple[dict[Hashable, Variable], dict[Hashable, int], set[Hashable], set[Hashable]]:

     dims: set[Hashable] = set()
@@ -429,16 +430,16 @@ def _parse_datasets(


 def _dataset_concat(
-    datasets: list[Dataset],
-    dim: str | DataArray | pd.Index,
+    datasets: list[T_Dataset],
+    dim: str | T_DataArray | pd.Index,
     data_vars: str | list[str],
     coords: str | list[str],
     compat: CompatOptions,
     positions: Iterable[Iterable[int]] | None,
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
-) -> Dataset:
+) -> T_Dataset:
     """
     Concatenate a sequence of datasets along a new or existing dimension
     """
@@ -618,16 +619,16 @@ def get_indexes(name):


 def _dataarray_concat(
-    arrays: Iterable[DataArray],
-    dim: str | DataArray | pd.Index,
+    arrays: Iterable[T_DataArray],
+    dim: str | T_DataArray | pd.Index,
     data_vars: str | list[str],
     coords: str | list[str],
     compat: CompatOptions,
     positions: Iterable[Iterable[int]] | None,
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
-) -> DataArray:
+) -> T_DataArray:
     from .dataarray import DataArray

     arrays = list(arrays)
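With the overloads parameterized by T_DataArray and T_Dataset, the declared return type of concat follows the concrete class of its inputs. A usage sketch under that assumption (TaggedArray is a hypothetical subclass introduced only for illustration; the example describes what a static checker reports, not a runtime guarantee):

import xarray as xr


class TaggedArray(xr.DataArray):
    # Hypothetical DataArray subclass; xarray subclasses must declare __slots__.
    __slots__ = ()


a = TaggedArray([1, 2], dims="x")
b = TaggedArray([3, 4], dims="x")

# The Iterable[T_DataArray] overload binds T_DataArray to TaggedArray, so a
# type checker reports TaggedArray (not just DataArray) for `combined`. This
# reflects the annotations only; it does not change what concat builds at runtime.
combined = xr.concat([a, b], dim="x")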