Daniel Busch 2 nedēļas atpakaļ
vecāks
revīzija
7ae0f229eb

+ 1 - 2
src/unfccc_ghg_data/unfccc_crf_reader/crf_raw_for_year_sparse_arrays.py

@@ -368,7 +368,7 @@ def crf_raw_for_year_pandas(
 # wie würde man über Länder summieren wenn die Länder im dt sind
 
 
-def crf_raw_for_year_datatree(  # noqa: PLR0912
+def crf_raw_for_year_datatree(
     submission_year: int,
     output_folder: Path,
     submission_type: str = "CRF",
@@ -447,4 +447,3 @@ def crf_raw_for_year_datatree(  # noqa: PLR0912
     dt = xr.DataTree(name="All", children=countries)
 
     return dt
-

+ 1 - 1
tests/integration/test_crf_for_year.py

@@ -26,7 +26,7 @@ def test_crf_for_year_original_version(tmp_path):
 
 def test_crf_for_year_sparse_arrays(tmp_path):
     start = timer()
-    n_countries = 20 # 100 will find 26 countries
+    n_countries = 20  # 100 will find 26 countries
     crf_raw_for_year_sparse_arrays(
         submission_year=1,
         submission_type="CRT",

+ 78 - 31
tests/integration/test_datatree_for_crt_crf.py

@@ -1,11 +1,12 @@
+import numpy as np
+import primap2 as pm
 import pytest
+import xarray as xr
 
 from src.unfccc_ghg_data.unfccc_crf_reader.crf_raw_for_year_sparse_arrays import (
-    crf_raw_for_year_datatree
+    crf_raw_for_year_datatree,
 )
-import primap2 as pm
-import xarray as xr
-import numpy as np
+
 
 @pytest.fixture
 def crt_dt(tmp_path):
@@ -20,76 +21,120 @@ def crt_dt(tmp_path):
 
     return crt_dt
 
-def test_country_aggregation(crt_dt):
 
+def test_country_aggregation(crt_dt):
     # Nigeria (NGA)
     expected_NGA = 123840.3389332373
-    assert np.squeeze(crt_dt["/NGA"]["CO2"].sel(
-        {"category (CRT1)" : "1", "class" : "Total", "time" : "2022-01-01"}).to_numpy()).item() == expected_NGA
+    assert (
+        np.squeeze(
+            crt_dt["/NGA"]["CO2"]
+            .sel({"category (CRT1)": "1", "class": "Total", "time": "2022-01-01"})
+            .to_numpy()
+        ).item()
+        == expected_NGA
+    )
     # NGA has data from 2000 to 2022
     assert len(crt_dt["/NGA"].time) == 23
 
     # Georgia (GEO)
     expected_GEO = 10954.899111969342
-    assert np.squeeze(crt_dt["/GEO"]["CO2"].sel(
-        {"category (CRT1)" : "1", "class" : "Total", "time" : "2022-01-01"}).to_numpy()).item() == expected_GEO
+    assert (
+        np.squeeze(
+            crt_dt["/GEO"]["CO2"]
+            .sel({"category (CRT1)": "1", "class": "Total", "time": "2022-01-01"})
+            .to_numpy()
+        ).item()
+        == expected_GEO
+    )
     # GEO has data from 1990 to 2022
     assert len(crt_dt["/GEO"].time) == 33
     # we have a value for 1990
-    assert np.squeeze(crt_dt["/GEO"]["CO2"].sel(
-        {"category (CRT1)" : "1", "class" : "Total", "time" : "1990-01-01"}).to_numpy()).item()
+    assert np.squeeze(
+        crt_dt["/GEO"]["CO2"]
+        .sel({"category (CRT1)": "1", "class": "Total", "time": "1990-01-01"})
+        .to_numpy()
+    ).item()
 
     # Remove country dimension and add
-    crt_dt["/NGA_GEO"] = crt_dt["/NGA"].to_dataset().pr.loc[{"area (ISO3)" : "NGA"}] + crt_dt["/GEO"].to_dataset().pr.loc[{"area (ISO3)" : "GEO"}]
+    crt_dt["/NGA_GEO"] = (
+        crt_dt["/NGA"].to_dataset().pr.loc[{"area (ISO3)": "NGA"}]
+        + crt_dt["/GEO"].to_dataset().pr.loc[{"area (ISO3)": "GEO"}]
+    )
 
     # Check if the sum of NGA and GEO is equal to NGA_GEO
     # This will perform an inner join, NGA has 23 time steps, GEO has 33 time steps,
     # and the result will have 23 time steps
-    assert np.squeeze(crt_dt["/NGA_GEO"]["CO2"].sel(
-        {"category (CRT1)" : "1", "class" : "Total", "time" : "2022-01-01"}).to_numpy()).item() == expected_NGA + expected_GEO
+    assert (
+        np.squeeze(
+            crt_dt["/NGA_GEO"]["CO2"]
+            .sel({"category (CRT1)": "1", "class": "Total", "time": "2022-01-01"})
+            .to_numpy()
+        ).item()
+        == expected_NGA + expected_GEO
+    )
     assert len(crt_dt["/NGA_GEO"].time) == 23
 
     # The default option for arithmetic operations is to perform an inner join
     # set options to perform outer join
     with xr.set_options(arithmetic_join="outer"):
-        crt_dt["/NGA_GEO_outer"] = crt_dt["/NGA"].to_dataset().pr.loc[{"area (ISO3)" : "NGA"}] + crt_dt["/GEO"].to_dataset().pr.loc[{"area (ISO3)" : "GEO"}]
+        crt_dt["/NGA_GEO_outer"] = (
+            crt_dt["/NGA"].to_dataset().pr.loc[{"area (ISO3)": "NGA"}]
+            + crt_dt["/GEO"].to_dataset().pr.loc[{"area (ISO3)": "GEO"}]
+        )
 
     # Where there is overlap between NGA and GEO, the values will be added
-    assert np.squeeze(crt_dt["/NGA_GEO_outer"]["CO2"].sel(
-        {"category (CRT1)" : "1", "class" : "Total", "time" : "2022-01-01"}).to_numpy()).item() == expected_NGA + expected_GEO
+    assert (
+        np.squeeze(
+            crt_dt["/NGA_GEO_outer"]["CO2"]
+            .sel({"category (CRT1)": "1", "class": "Total", "time": "2022-01-01"})
+            .to_numpy()
+        ).item()
+        == expected_NGA + expected_GEO
+    )
     # When only one time series has a value, the value will be set to nan
     # TODO Is that the behaviour we need?
-    assert np.isnan(np.squeeze(crt_dt["/NGA_GEO_outer"]["CO2"].sel(
-        {"category (CRT1)" : "1", "class" : "Total", "time" : "1990-01-01"}).to_numpy()).item())
+    assert np.isnan(
+        np.squeeze(
+            crt_dt["/NGA_GEO_outer"]["CO2"]
+            .sel({"category (CRT1)": "1", "class": "Total", "time": "1990-01-01"})
+            .to_numpy()
+        ).item()
+    )
     # Extra time steps will be added
     assert len(crt_dt["/NGA_GEO_outer"].time) == 33
 
     # How can we use data tree methods for this?
 
     # filter countries we need
-    dt_geo_nga = crt_dt.filter(lambda x: x.name in ["NGA", "GEO"])
-
-    # fails because, x will be a data set view object that does not have the attribute to_dataset
-    # dt_geo_nga = dt_geo_nga.map_over_datasets(lambda x: x.to_dataset().pr.loc[{"area (ISO3)" : x.name}])
+    dt_geo_nga = crt_dt.filter(lambda x: x.name in ["NGA", "GEO"])  # noqa: F841
 
+    # fails because, x will be a data set view object that does
+    # not have the attribute to_dataset
+    # dt_geo_nga = dt_geo_nga.map_over_datasets(
+    # lambda x: x.to_dataset().pr.loc[{"area (ISO3)" : x.name}])
 
 
 def test_crf_for_year_original_version(crt_dt):
-
     print(crt_dt.groups)
     # output:
-    # ('/', '/DZA', '/ARG', '/AZE', '/BTN', '/BRA', '/BRN', '/CHL', '/CHN', '/COL', '/CIV', '/ECU', '/EGY', '/GEO',
-    # '/GHA', '/GNB', '/GUY', '/IDN', '/KEN', '/LBN', '/MYS', '/MDV', '/MUS', '/MAR', '/NAM', '/NGA')
+    # ('/', '/DZA', '/ARG', '/AZE', '/BTN', '/BRA', '/BRN', '/CHL', '/CHN',
+    # '/COL', '/CIV', '/ECU', '/EGY', '/GEO',
+    # '/GHA', '/GNB', '/GUY', '/IDN', '/KEN', '/LBN', '/MYS', '/MDV',
+    # '/MUS', '/MAR', '/NAM', '/NGA')
     # out of the 100 countries, 26 have CRT data
 
     # to make it more interesting, we also add primap-hist to the data tree
-    primap_hist = pm.open_dataset("../../data/Guetschow_et_al_2025-PRIMAP-hist_v2.6.1_final_13-Mar-2025.nc")
+    primap_hist = pm.open_dataset(
+        "../../data/Guetschow_et_al_2025-PRIMAP-hist_v2.6.1_final_13-Mar-2025.nc"
+    )
     # Assign primap hist to the data tree
     crt_dt["primap_hist"] = primap_hist
     print(crt_dt.groups)
     # output:
-    # ('/', '/DZA', '/ARG', '/AZE', '/BTN', '/BRA', '/BRN', '/CHL', '/CHN', '/COL', '/CIV', '/ECU', '/EGY', '/GEO',
-    # '/GHA', '/GNB', '/GUY', '/IDN', '/KEN', '/LBN', '/MYS', '/MDV', '/MUS', '/MAR', '/NAM', '/NGA', '/primap_hist')
+    # ('/', '/DZA', '/ARG', '/AZE', '/BTN', '/BRA', '/BRN', '/CHL',
+    # '/CHN', '/COL', '/CIV', '/ECU', '/EGY', '/GEO',
+    # '/GHA', '/GNB', '/GUY', '/IDN', '/KEN', '/LBN', '/MYS', '/MDV',
+    # '/MUS', '/MAR', '/NAM', '/NGA', '/primap_hist')
 
     # try some queries
 
@@ -101,7 +146,9 @@ def test_crf_for_year_original_version(crt_dt):
     # print(dt_crt.sel({"category (IPCC2006_PRIMAP)" : "1"}))
 
     # we can re-assign the primap hist data set with a renamed category dimension
-    crt_dt["primap_hist"] = primap_hist.rename({"category (IPCC2006_PRIMAP)" : "category (CRT1)"})
+    crt_dt["primap_hist"] = primap_hist.rename(
+        {"category (IPCC2006_PRIMAP)": "category (CRT1)"}
+    )
 
     # Now we should be able to filter for categories over all nodes
-    dt_cat1 = crt_dt.sel({"category (CRT1)" : "1"})
+    dt_cat1 = crt_dt.sel({"category (CRT1)": "1"})  # noqa: F841