diff --git a/pyproject.toml b/pyproject.toml
index 9df8241aad1..1bdc2634263 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -255,6 +255,7 @@ extend-select = [
   "PIE", # flake8-pie
   "TID", # flake8-tidy-imports (absolute imports)
   "PYI", # flake8-pyi
+  "SIM", # flake8-simplify
   "FLY", # flynt
   "I", # isort
   "PERF", # Perflint
@@ -276,6 +277,9 @@ ignore = [
   "PIE790", # unnecessary pass statement
   "PYI019", # use `Self` instead of custom TypeVar
   "PYI041", # use `float` instead of `int | float`
+  "SIM108", # use ternary operator instead of `if`-`else`-block
+  "SIM117", # use a single `with` statement instead of nested `with` statements
+  "SIM300", # yoda condition detected
   "PERF203", # try-except within a loop incurs performance overhead
   "E402", # module level import not at top of file
   "E731", # do not assign a lambda expression, use a def
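
The newly enabled "SIM" family drives every hunk below, while the three ignored codes are style preferences the project opts out of. A minimal sketch of what the most relevant rules ask for (illustrative Python only, not part of the patch; the dict here is invented):

    sizes = {"time": 10, "lat": 5}

    # SIM118 (enabled): test membership against the mapping itself,
    # not against .keys(). This is the most common rewrite in this diff.
    assert "time" in sizes  # rather than: "time" in sizes.keys()

    # SIM108 (ignored): ruff would collapse this into a ternary,
    # `size = sizes["lon"] if "lon" in sizes else 1`; the explicit
    # if/else form stays allowed.
    if "lon" in sizes:
        size = sizes["lon"]
    else:
        size = 1

    # SIM300 (ignored): constant-first "Yoda" comparisons stay allowed.
    assert 10 == sizes["time"]

SIM117 (also ignored) similarly permits nested `with` statements instead of forcing them into a single combined `with`.
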
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index 8c3a01eba66..65937897a49 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -298,7 +298,7 @@ def _extract_nc4_variable_encoding(
         del encoding["chunksizes"]
 
     var_has_unlim_dim = any(dim in unlimited_dims for dim in variable.dims)
-    if not raise_on_invalid and var_has_unlim_dim and "contiguous" in encoding.keys():
+    if not raise_on_invalid and var_has_unlim_dim and "contiguous" in encoding:
         del encoding["contiguous"]
 
     for k in safe_to_drop:
diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py
index 301ea430c4c..b348f77d941 100644
--- a/xarray/backends/pydap_.py
+++ b/xarray/backends/pydap_.py
@@ -348,14 +348,14 @@ def group_fqn(store, path=None, g_fqn=None) -> dict[str, str]:
         if not g_fqn:
             g_fqn = {}
         groups = [
-            store[key].id
-            for key in store.keys()
-            if isinstance(store[key], GroupType)
+            var.id for var in store.values() if isinstance(var, GroupType)
         ]
         for g in groups:
             g_fqn.update({g: path})
             subgroups = [
-                var for var in store[g] if isinstance(store[g][var], GroupType)
+                key
+                for key, var in store[g].items()
+                if isinstance(var, GroupType)
             ]
             if len(subgroups) > 0:
                 npath = path + g
diff --git a/xarray/coding/times.py b/xarray/coding/times.py
index e6bc8ca59bd..bb46a456b6e 100644
--- a/xarray/coding/times.py
+++ b/xarray/coding/times.py
@@ -733,7 +733,7 @@ def infer_calendar_name(dates) -> CFCalendar:
     """Given an array of datetimes, infer the CF calendar name"""
     if is_np_datetime_like(dates.dtype):
         return "proleptic_gregorian"
-    elif dates.dtype == np.dtype("O") and dates.size > 0:
+    elif dates.dtype == np.dtype("O") and dates.size > 0:  # noqa: SIM102
         # Logic copied from core.common.contains_cftime_datetimes.
         if cftime is not None:
             sample = np.asarray(dates).flat[0]
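
The `# noqa: SIM102` in coding/times.py is the one place where the new rule set is suppressed rather than applied: SIM102 wants nested `if`s collapsed into a single condition, but merging them here would detach the copied-logic comment from the check it explains. A reduced sketch of the pattern (hypothetical function, not xarray code):

    import numpy as np

    def first_object_sample(dates: np.ndarray):
        if dates.dtype == np.dtype("O") and dates.size > 0:  # noqa: SIM102
            # Keeping the nested guard keeps this comment next to it.
            if dates.flat[0] is not None:
                return dates.flat[0]
        return None
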
diff --git a/xarray/computation/rolling.py b/xarray/computation/rolling.py
index 519d1f7eae6..ef324d88b6a 100644
--- a/xarray/computation/rolling.py
+++ b/xarray/computation/rolling.py
@@ -1081,7 +1081,7 @@ def __init__(
         self.side = side
         self.boundary = boundary
 
-        missing_dims = tuple(dim for dim in windows.keys() if dim not in self.obj.dims)
+        missing_dims = tuple(dim for dim in windows if dim not in self.obj.dims)
         if missing_dims:
             raise ValueError(
                 f"Window dimensions {missing_dims} not found in {self.obj.__class__.__name__} "
diff --git a/xarray/core/common.py b/xarray/core/common.py
index 6181aa6a8c1..bfbf4f3ecd1 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -1247,9 +1247,7 @@ def _dataset_indexer(dim: Hashable) -> DataArray:
             _dataarray_indexer if isinstance(cond, DataArray) else _dataset_indexer
         )
 
-        indexers = {}
-        for dim in cond.sizes.keys():
-            indexers[dim] = _get_indexer(dim)
+        indexers = {dim: _get_indexer(dim) for dim in cond.sizes}
 
         self = self.isel(**indexers)
         cond = cond.isel(**indexers)
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index c13d33872b6..5aa7a2970f9 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -212,7 +212,7 @@ def _check_data_shape(
         data_shape = tuple(
             (
                 as_variable(coords[k], k, auto_convert=False).size
-                if k in coords.keys()
+                if k in coords
                 else 1
             )
             for k in dims
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 95e6f516403..34b6b4f71de 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -4088,7 +4088,7 @@ def _rename(
         is raised at the right stack level.
         """
         name_dict = either_dict_or_kwargs(name_dict, names, "rename")
-        for k in name_dict.keys():
+        for k, new_k in name_dict.items():
             if k not in self and k not in self.dims:
                 raise ValueError(
                     f"cannot rename {k!r} because it is not a "
@@ -4096,13 +4096,12 @@ def _rename(
                 )
 
             create_dim_coord = False
-            new_k = name_dict[k]
 
             if k == new_k:
                 continue  # Same name, nothing to do
 
             if k in self.dims and new_k in self._coord_names:
-                coord_dims = self._variables[name_dict[k]].dims
+                coord_dims = self._variables[new_k].dims
                 if coord_dims == (k,):
                     create_dim_coord = True
             elif k in self._coord_names and new_k in self.dims:
@@ -4112,7 +4111,7 @@ def _rename(
 
             if create_dim_coord:
                 warnings.warn(
-                    f"rename {k!r} to {name_dict[k]!r} does not create an index "
+                    f"rename {k!r} to {new_k!r} does not create an index "
                     "anymore. Try using swap_dims instead or use set_index "
                     "after rename to create an indexed coordinate.",
                     UserWarning,
@@ -8980,16 +8979,18 @@ def pad(
                 variables[name] = var
             elif name in self.data_vars:
                 if utils.is_dict_like(constant_values):
-                    if name in constant_values.keys():
+                    if name in constant_values:
                         filtered_constant_values = constant_values[name]
                     elif not set(var.dims).isdisjoint(constant_values.keys()):
                         filtered_constant_values = {
-                            k: v for k, v in constant_values.items() if k in var.dims
+                            k: v  # type: ignore[misc]
+                            for k, v in constant_values.items()
+                            if k in var.dims
                         }
                     else:
                         filtered_constant_values = 0  # TODO: https://github.com/pydata/xarray/pull/9353#discussion_r1724018352
                 else:
-                    filtered_constant_values = constant_values
+                    filtered_constant_values = constant_values  # type: ignore[assignment]
                 variables[name] = var.pad(
                     pad_width=var_pad_width,
                     mode=mode,
diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py
index 14e70b0550c..94ff1e65ac0 100644
--- a/xarray/core/formatting.py
+++ b/xarray/core/formatting.py
@@ -989,9 +989,10 @@ def diff_array_repr(a, b, compat):
     ):
         summary.append(coords_diff)
 
-    if compat == "identical":
-        if attrs_diff := diff_attrs_repr(a.attrs, b.attrs, compat):
-            summary.append(attrs_diff)
+    if compat == "identical" and (
+        attrs_diff := diff_attrs_repr(a.attrs, b.attrs, compat)
+    ):
+        summary.append(attrs_diff)
 
     return "\n".join(summary)
@@ -1029,9 +1030,10 @@ def diff_dataset_repr(a, b, compat):
     ):
         summary.append(data_diff)
 
-    if compat == "identical":
-        if attrs_diff := diff_attrs_repr(a.attrs, b.attrs, compat):
-            summary.append(attrs_diff)
+    if compat == "identical" and (
+        attrs_diff := diff_attrs_repr(a.attrs, b.attrs, compat)
+    ):
+        summary.append(attrs_diff)
 
     return "\n".join(summary)
diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index 1756fb54c1b..847c3a4edb9 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -1247,7 +1247,7 @@ def create_variables(
                 level = name
                 dtype = self.level_coords_dtype[name]  # type: ignore[index]  # TODO: are Hashables ok?
 
-            var = variables.get(name, None)
+            var = variables.get(name)
             if var is not None:
                 attrs = var.attrs
                 encoding = var.encoding
diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index 8bb98118081..882fa817b60 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -624,7 +624,7 @@ def _wrapper(
         {**hlg.layers, **new_layers},
         dependencies={
             **hlg.dependencies,
-            **{name: {gname} for name in new_layers.keys()},
+            **{name: {gname} for name in new_layers},
         },
     )
diff --git a/xarray/core/resample_cftime.py b/xarray/core/resample_cftime.py
index 210cea2c76a..a56c4117dfc 100644
--- a/xarray/core/resample_cftime.py
+++ b/xarray/core/resample_cftime.py
@@ -84,22 +84,17 @@ def __init__(
         self.freq = to_offset(freq)
         self.origin = origin
 
-        if isinstance(self.freq, MonthEnd | QuarterEnd | YearEnd):
-            if closed is None:
-                self.closed = "right"
-            else:
-                self.closed = closed
-            if label is None:
-                self.label = "right"
-            else:
-                self.label = label
-        # The backward resample sets ``closed`` to ``'right'`` by default
-        # since the last value should be considered as the edge point for
-        # the last bin. When origin in "end" or "end_day", the value for a
-        # specific ``cftime.datetime`` index stands for the resample result
-        # from the current ``cftime.datetime`` minus ``freq`` to the current
-        # ``cftime.datetime`` with a right close.
-        elif self.origin in ["end", "end_day"]:
+        if (
+            isinstance(self.freq, MonthEnd | QuarterEnd | YearEnd)
+            or
+            # The backward resample sets ``closed`` to ``'right'`` by default
+            # since the last value should be considered as the edge point for
+            # the last bin. When origin in "end" or "end_day", the value for a
+            # specific ``cftime.datetime`` index stands for the resample result
+            # from the current ``cftime.datetime`` minus ``freq`` to the current
+            # ``cftime.datetime`` with a right close.
+            self.origin in ["end", "end_day"]
+        ):
             if closed is None:
                 self.closed = "right"
             else:
                 self.closed = closed
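
The resample_cftime.py hunk above is SIM114: two branches with identical bodies merge into a single condition joined by `or`, and the block comment moves inside the parenthesized condition so it stays next to the clause it describes. A reduced sketch of the same shape (invented names, not the actual resampler API):

    def resolve_closed(freq_is_end: bool, origin: str, closed: str | None) -> str:
        # Before: an `if freq_is_end:` branch and an `elif origin in ...:`
        # branch with identical bodies; after: one branch joined by `or`.
        if freq_is_end or origin in ("end", "end_day"):
            return "right" if closed is None else closed
        return "left" if closed is None else closed
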
diff --git a/xarray/groupers.py b/xarray/groupers.py
index 9ed948956a8..7d7b00ce3b1 100644
--- a/xarray/groupers.py
+++ b/xarray/groupers.py
@@ -705,25 +705,23 @@ def find_independent_seasons(seasons: Sequence[str]) -> Sequence[SeasonsGroup]:
     grouped = defaultdict(list)
     codes = defaultdict(list)
     seen: set[tuple[int, ...]] = set()
-    idx = 0
 
     # This is quadratic, but the number of seasons is at most 12
     for i, current in enumerate(season_inds):
         # Start with a group
         if current not in seen:
-            grouped[idx].append(current)
-            codes[idx].append(i)
+            grouped[i].append(current)
+            codes[i].append(i)
             seen.add(current)
 
         # Loop through remaining groups, and look for overlaps
         for j, second in enumerate(season_inds[i:]):
-            if not (set(chain(*grouped[idx])) & set(second)) and second not in seen:
-                grouped[idx].append(second)
-                codes[idx].append(j + i)
+            if not (set(chain(*grouped[i])) & set(second)) and second not in seen:
+                grouped[i].append(second)
+                codes[i].append(j + i)
                 seen.add(second)
 
         if len(seen) == len(seasons):
             break
 
-        # found all non-overlapping groups for this row, increment and start over
-        idx += 1
+        # found all non-overlapping groups for this row, start over
 
     grouped_ints = tuple(tuple(idx) for idx in grouped.values() if idx)
     return [
diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py
index 719b1fde619..0a190baef41 100644
--- a/xarray/plot/facetgrid.py
+++ b/xarray/plot/facetgrid.py
@@ -549,7 +549,7 @@ def map_plot1d(
                 )
 
         if add_legend:
-            use_legend_elements = not func.__name__ == "hist"
+            use_legend_elements = func.__name__ != "hist"
             if use_legend_elements:
                 self.add_legend(
                     use_legend_elements=use_legend_elements,
diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py
index dd47df703b5..bac5ab8a075 100644
--- a/xarray/plot/utils.py
+++ b/xarray/plot/utils.py
@@ -419,9 +419,10 @@ def _infer_xy_labels(
     _assert_valid_xy(darray, x, "x")
     _assert_valid_xy(darray, y, "y")
 
-    if darray._indexes.get(x, 1) is darray._indexes.get(y, 2):
-        if isinstance(darray._indexes[x], PandasMultiIndex):
-            raise ValueError("x and y cannot be levels of the same MultiIndex")
+    if darray._indexes.get(x, 1) is darray._indexes.get(y, 2) and isinstance(
+        darray._indexes[x], PandasMultiIndex
+    ):
+        raise ValueError("x and y cannot be levels of the same MultiIndex")
 
     return x, y
@@ -1824,7 +1825,7 @@ def _guess_coords_to_plot(
     """
     coords_to_plot_exist = {k: v for k, v in coords_to_plot.items() if v is not None}
     available_coords = tuple(
-        k for k in darray.coords.keys() if k not in coords_to_plot_exist.values()
+        k for k in darray.coords if k not in coords_to_plot_exist.values()
     )
 
     # If dims_plot[k] isn't defined then fill with one of the available dims, unless
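
The groupers.py hunk is safe because the removed counter was redundant: `idx` started at 0 and was incremented exactly once per outer iteration, so it always equaled the `enumerate` index `i`, and both `defaultdict`s end up with the same keys and values either way. A quick sketch of the invariant (illustrative only):

    from collections import defaultdict

    by_counter = defaultdict(list)
    by_index = defaultdict(list)
    idx = 0
    for i, season in enumerate(["DJF", "MAM", "JJA"]):
        by_counter[idx].append(season)
        by_index[i].append(season)
        idx += 1  # one increment per iteration, so idx == i held above
    assert by_counter == by_index
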
diff --git a/xarray/structure/merge.py b/xarray/structure/merge.py
index 403186272b9..5c998075151 100644
--- a/xarray/structure/merge.py
+++ b/xarray/structure/merge.py
@@ -300,7 +300,7 @@ def merge_collected(
                 variables = [variable for variable, _ in elements_list]
                 try:
                     merged_vars[name] = unique_variable(
-                        name, variables, compat, equals.get(name, None)
+                        name, variables, compat, equals.get(name)
                     )
                 except MergeError:
                     if compat != "minimal":
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index fe4c1684cbd..4211968c74b 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -2671,13 +2671,13 @@ def test_hidden_zarr_keys(self) -> None:
             # check that a variable hidden attribute is present and correct
             # JSON only has a single array type, which maps to list in Python.
             # In contrast, dims in xarray is always a tuple.
-            for var in expected.variables.keys():
+            for var in expected.variables:
                 dims = zarr_group[var].attrs[self.DIMENSION_KEY]
                 assert dims == list(expected[var].dims)
 
             with xr.decode_cf(store):
                 # make sure it is hidden
-                for var in expected.variables.keys():
+                for var in expected.variables:
                     assert self.DIMENSION_KEY not in expected[var].attrs
 
             # put it back and try removing from a variable
@@ -3731,7 +3731,7 @@ def test_chunk_key_encoding_v2(self) -> None:
             # Verify the chunk keys in store use the slash separator
             if not has_zarr_v3:
-                chunk_keys = [k for k in store.keys() if k.startswith("var1/")]
+                chunk_keys = [k for k in store if k.startswith("var1/")]
                 assert len(chunk_keys) > 0
                 for key in chunk_keys:
                     assert "/" in key
diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py
index 6b3674e1a8c..2d11c96ebb6 100644
--- a/xarray/tests/test_backends_datatree.py
+++ b/xarray/tests/test_backends_datatree.py
@@ -66,7 +66,7 @@ def assert_chunks_equal(
             and node1.variables[name].chunksizes == node2.variables[name].chunksizes
         )
         for path, (node1, node2) in xr.group_subtrees(actual, expected)
-        for name in node1.variables.keys()
+        for name in node1.variables
     }
 
     assert all(comparison.values()), diff_chunks(comparison, actual, expected)
@@ -312,9 +312,9 @@ def test_open_groups(self, unaligned_datatree_nc) -> None:
         unaligned_dict_of_datasets = open_groups(unaligned_datatree_nc)
 
         # Check that group names are keys in the dictionary of `xr.Datasets`
-        assert "/" in unaligned_dict_of_datasets.keys()
-        assert "/Group1" in unaligned_dict_of_datasets.keys()
-        assert "/Group1/subgroup1" in unaligned_dict_of_datasets.keys()
+        assert "/" in unaligned_dict_of_datasets
+        assert "/Group1" in unaligned_dict_of_datasets
+        assert "/Group1/subgroup1" in unaligned_dict_of_datasets
         # Check that group name returns the correct datasets
         with xr.open_dataset(unaligned_datatree_nc, group="/") as expected:
             assert_identical(unaligned_dict_of_datasets["/"], expected)
@@ -453,9 +453,9 @@ def test_open_groups(self, url=unaligned_datatree_url) -> None:
         unaligned_dict_of_datasets = open_groups(url, engine=self.engine)
 
         # Check that group names are keys in the dictionary of `xr.Datasets`
-        assert "/" in unaligned_dict_of_datasets.keys()
-        assert "/Group1" in unaligned_dict_of_datasets.keys()
-        assert "/Group1/subgroup1" in unaligned_dict_of_datasets.keys()
+        assert "/" in unaligned_dict_of_datasets
+        assert "/Group1" in unaligned_dict_of_datasets
+        assert "/Group1/subgroup1" in unaligned_dict_of_datasets
         # Check that group name returns the correct datasets
         with xr.open_dataset(url, engine=self.engine, group="/") as expected:
             assert_identical(unaligned_dict_of_datasets["/"], expected)
@@ -782,10 +782,10 @@ def test_open_groups(self, unaligned_datatree_zarr_factory, zarr_format) -> None
         storepath = unaligned_datatree_zarr_factory(zarr_format=zarr_format)
         unaligned_dict_of_datasets = open_groups(storepath, engine="zarr")
 
-        assert "/" in unaligned_dict_of_datasets.keys()
-        assert "/Group1" in unaligned_dict_of_datasets.keys()
-        assert "/Group1/subgroup1" in unaligned_dict_of_datasets.keys()
-        assert "/Group2" in unaligned_dict_of_datasets.keys()
+        assert "/" in unaligned_dict_of_datasets
+        assert "/Group1" in unaligned_dict_of_datasets
+        assert "/Group1/subgroup1" in unaligned_dict_of_datasets
+        assert "/Group2" in unaligned_dict_of_datasets
         # Check that group name returns the correct datasets
         with xr.open_dataset(storepath, group="/", engine="zarr") as expected:
             assert_identical(unaligned_dict_of_datasets["/"], expected)
engine="zarr") - assert "/" in unaligned_dict_of_datasets.keys() - assert "/Group1" in unaligned_dict_of_datasets.keys() - assert "/Group1/subgroup1" in unaligned_dict_of_datasets.keys() - assert "/Group2" in unaligned_dict_of_datasets.keys() + assert "/" in unaligned_dict_of_datasets + assert "/Group1" in unaligned_dict_of_datasets + assert "/Group1/subgroup1" in unaligned_dict_of_datasets + assert "/Group2" in unaligned_dict_of_datasets # Check that group name returns the correct datasets with xr.open_dataset(storepath, group="/", engine="zarr") as expected: assert_identical(unaligned_dict_of_datasets["/"], expected) diff --git a/xarray/tests/test_coding.py b/xarray/tests/test_coding.py index acb32504948..24ef2c8397e 100644 --- a/xarray/tests/test_coding.py +++ b/xarray/tests/test_coding.py @@ -94,8 +94,8 @@ def test_coder_roundtrip() -> None: assert_identical(original, roundtripped) -@pytest.mark.parametrize("dtype", "u1 u2 i1 i2 f2 f4".split()) -@pytest.mark.parametrize("dtype2", "f4 f8".split()) +@pytest.mark.parametrize("dtype", ["u1", "u2", "i1", "i2", "f2", "f4"]) +@pytest.mark.parametrize("dtype2", ["f4", "f8"]) def test_scaling_converts_to_float(dtype: str, dtype2: str) -> None: dt = np.dtype(dtype2) original = xr.Variable( diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index 40d63ed6981..d463f2edd91 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -29,8 +29,8 @@ def assert_combined_tile_ids_equal(dict1, dict2): assert len(dict1) == len(dict2) - for k in dict1.keys(): - assert k in dict2.keys() + for k in dict1: + assert k in dict2 assert_equal(dict1[k], dict2[k]) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index eefa3c2b4f8..ccb832ee522 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1636,7 +1636,7 @@ def test_normalize_token_with_backend(map_ds): with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp_file: map_ds.to_netcdf(tmp_file) read = xr.open_dataset(tmp_file) - assert not dask.base.tokenize(map_ds) == dask.base.tokenize(read) + assert dask.base.tokenize(map_ds) != dask.base.tokenize(read) read.close() diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 3e0734c8a1a..ab2e673fcb6 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -3164,7 +3164,7 @@ def test_drop_encoding(self) -> None: vencoding = {"scale_factor": 10} orig.encoding = {"foo": "bar"} - for k in orig.variables.keys(): + for k in orig.variables: orig[k].encoding = vencoding actual = orig.drop_encoding() diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index ac90d216144..b3e223659b2 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3235,7 +3235,7 @@ def test_shuffle_simple() -> None: da = xr.DataArray( dims="x", data=dask.array.from_array([1, 2, 3, 4, 5, 6], chunks=2), - coords={"label": ("x", "a b c a b c".split(" "))}, + coords={"label": ("x", ["a", "b", "c", "a", "b", "c"])}, ) actual = da.groupby(label=UniqueGrouper()).shuffle_to_chunks() expected = da.isel(x=[0, 3, 1, 4, 2, 5]) diff --git a/xarray/tests/test_strategies.py b/xarray/tests/test_strategies.py index 48819333ca2..699e3df9769 100644 --- a/xarray/tests/test_strategies.py +++ b/xarray/tests/test_strategies.py @@ -68,7 +68,7 @@ def test_number_of_dims(self, data, ndims): def test_restrict_names(self, data): capitalized_names = st.text(st.characters(), min_size=1).map(str.upper) dim_sizes = 
diff --git a/xarray/tests/test_strategies.py b/xarray/tests/test_strategies.py
index 48819333ca2..699e3df9769 100644
--- a/xarray/tests/test_strategies.py
+++ b/xarray/tests/test_strategies.py
@@ -68,7 +68,7 @@ def test_number_of_dims(self, data, ndims):
     def test_restrict_names(self, data):
         capitalized_names = st.text(st.characters(), min_size=1).map(str.upper)
         dim_sizes = data.draw(dimension_sizes(dim_names=capitalized_names))
-        for dim in dim_sizes.keys():
+        for dim in dim_sizes:
             assert dim.upper() == dim
diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py
index ab4ec36ea97..cce94ed9150 100644
--- a/xarray/tests/test_units.py
+++ b/xarray/tests/test_units.py
@@ -234,9 +234,7 @@ def convert_units(obj, to):
     elif isinstance(obj, xr.DataArray):
         name = obj.name
 
-        new_units = (
-            to.get(name, None) or to.get("data", None) or to.get(None, None) or None
-        )
+        new_units = to.get(name) or to.get("data") or to.get(None) or None
 
         data = convert_units(obj.variable, {None: new_units})
         coords = {
@@ -3052,7 +3050,7 @@ def is_compatible(a, b):
     other_units = extract_units(other)
 
     equal_arrays = all(
-        is_compatible(units[name], other_units[name]) for name in units.keys()
+        is_compatible(units[name], other_units[name]) for name in units
     ) and (
         strip_units(data_array).equals(
             strip_units(convert_units(other, extract_units(data_array)))
diff --git a/xarray/util/print_versions.py b/xarray/util/print_versions.py
index 8f3be73ee68..0be9713c3dc 100644
--- a/xarray/util/print_versions.py
+++ b/xarray/util/print_versions.py
@@ -20,7 +20,7 @@ def get_sys_info():
     if os.path.isdir(".git") and os.path.isdir("xarray"):
         try:
             pipe = subprocess.Popen(
-                'git log --format="%H" -n 1'.split(" "),
+                ("git", "log", '--format="%H"', "-n", "1"),
                 stdout=subprocess.PIPE,
                 stderr=subprocess.PIPE,
            )
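
The print_versions.py rewrite swaps runtime string splitting for an explicit argv tuple. The argument vector is unchanged, including the literal double quotes around %H (which were never shell-stripped, since `subprocess.Popen` with a sequence runs without a shell), so git receives exactly the same arguments. A quick check, safe to run anywhere without git installed:

    argv = ("git", "log", '--format="%H"', "-n", "1")
    assert list(argv) == 'git log --format="%H" -n 1'.split(" ")
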