[Refactor] Improve the performance of temporal group averaging #689
Conversation
Force-pushed `7594df4` to `0d56ed5`.
Codecov Report: All modified and coverable lines are covered by tests ✅

```
@@           Coverage Diff            @@
##              main      #689   +/-  ##
=========================================
  Coverage   100.00%   100.00%
=========================================
  Files           15        15
  Lines         1544      1546        +2
=========================================
+ Hits         1544      1546        +2
```

☔ View full report in Codecov by Sentry.
Replace `.load()` with `.astype("timedelta64[ns]")` for clarity
My initial self-review. The GH Actions build is passing.
xcdat/temporal.py
Outdated
```python
# 5. Calculate the departures for the data variable.
# ----------------------------------------------------------------------
# This step allows us to perform xarray's grouped arithmetic to
# calculate departures.
dv_obs = ds_obs[data_var].copy()
self._labeled_time = self._label_time_coords(dv_obs[self.dim])
dv_obs_grouped = self._group_data(dv_obs)

# 5. Align time dimension names using the labeled time dimension name.
# ----------------------------------------------------------------------
# The climatology's time dimension is renamed to the labeled time
# dimension in step #4 above (e.g., "time" -> "season"). xarray requires
# dimension names to be aligned to perform grouped arithmetic, which we
# use for calculating departures in step #5. Otherwise, this error is
# raised: "`ValueError: incompatible dimensions for a grouped binary
# operation: the group variable '<FREQ ARG>' is not a dimension on the
# other argument`".
dv_climo = ds_climo[data_var]
dv_climo = dv_climo.rename({self.dim: self._labeled_time.name})

# 6. Calculate the departures for the data variable.
# ----------------------------------------------------------------------
# departures = observation - climatology
with xr.set_options(keep_attrs=True):
    dv_departs = dv_obs_grouped - dv_climo

dv_departs = self._add_operation_attrs(dv_departs)
ds_obs[data_var] = dv_departs

ds_departs = self._calculate_departures(ds_obs, ds_climo, data_var)
```
Refactored this block of code into `self._calculate_departures()` for readability. The sketch below shows the grouped-arithmetic pattern the helper relies on.
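For context, this is a minimal, self-contained sketch of that grouped arithmetic (the monthly example data and names are hypothetical, not xcdat internals):

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical example: two years of monthly observations.
time = pd.date_range("2000-01-01", periods=24, freq="MS")
obs = xr.DataArray(np.random.rand(24), coords={"time": time}, dims="time", name="tas")

# Climatology indexed by the group label ("month" plays the role of the
# labeled time dimension here).
climo = obs.groupby("time.month").mean()

# departures = observation - climatology. xarray matches the group variable
# "month" against the "month" dimension on `climo`; if the names are not
# aligned, it raises the "incompatible dimensions for a grouped binary
# operation" ValueError quoted above.
departures = obs.groupby("time.month") - climo
```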
```python
self._labeled_time = self._label_time_coords(dv[self.dim])
dv = dv.assign_coords({self.dim: self._labeled_time})
```
Address bottleneck #1 from PR description.

Replace the time coordinates with the labeled time coordinates directly for grouping, rather than adding the labeled time coordinates as auxiliary coordinates on the time dimension (the auxiliary coordinates slow grouping down in Xarray for some reason; I still need to ask on the Xarray forum). A sketch of both approaches follows.
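A minimal sketch of the two approaches, using a hypothetical monthly grouping (none of these names come from xcdat):

```python
import numpy as np
import pandas as pd
import xarray as xr

# A daily series to group by month.
time = pd.date_range("2000-01-01", periods=365, freq="D")
da = xr.DataArray(np.random.rand(365), coords={"time": time}, dims="time", name="tas")

# One label per timestamp, collapsing each day to the first of its month.
month_labels = time.to_period("M").to_timestamp()

# Before (slower): attach the labels as an auxiliary coordinate on the time
# dimension and group by that coordinate.
avg_aux = da.assign_coords(month=("time", month_labels)).groupby("month").mean()

# After (faster): replace the "time" dimension coordinates with the labels,
# then group by the dimension coordinate directly.
avg_direct = da.assign_coords(time=month_labels).groupby("time").mean()
```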
```diff
@@ -1285,19 +1248,14 @@ def _get_weights(self, time_bounds: xr.DataArray) -> xr.DataArray:
     # or time unit (with rare exceptions see release notes). To avoid this
     # warning please use the scalar types `np.float64`, or string notation.`
     if isinstance(time_lengths.data, Array):
-        time_lengths.load()
+        time_lengths = time_lengths.astype("timedelta64[ns]")
```
Address bottleneck #2 from PR description; the sketch below illustrates why the lazy cast is faster than `.load()`.
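A rough sketch of the difference, assuming `time_lengths` is a lazy, dask-backed array of time-bound differences (the setup below is illustrative only):

```python
import dask.array as dsa
import numpy as np
import xarray as xr

# Illustrative lazy timedelta array standing in for time-bound differences.
bounds = np.array(["2000-01-01", "2000-02-01", "2000-03-01"], dtype="datetime64[ns]")
time_lengths = xr.DataArray(dsa.from_array(np.diff(bounds)), dims="time")

# Before: time_lengths.load() eagerly computed the whole array just to settle
# the dtype, which is slow for large datasets.
# After: .astype() is lazy and still normalizes the units to nanoseconds, so
# the subsequent float64 cast yields consistent weights without an eager load.
time_lengths = time_lengths.astype("timedelta64[ns]")
time_lengths = time_lengths.astype(np.float64)
```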
```python
dv = dv.assign_coords({self.dim: self._labeled_time})
dv_gb = dv.groupby(self.dim)
```
Address bottleneck #1 from PR description (same change as above: group by the replaced time coordinates rather than by auxiliary labeled coordinates, which slows things down in Xarray for some reason; see the earlier sketch).
```diff
 time_grouped = xr.DataArray(
-    name="_".join(df_dt_components.columns),
+    name=self.dim,
     data=dt_objects,
-    coords={self.dim: time_coords[self.dim]},
+    coords={self.dim: dt_objects},
     dims=[self.dim],
     attrs=time_coords[self.dim].attrs,
 )
```
Address bottleneck #1 from PR description. The group labels now become the dimension coordinates themselves, as sketched below.
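For illustration, the labeled time coordinate constructed after this change looks roughly like the following (with made-up label values):

```python
import pandas as pd
import xarray as xr

# Made-up group labels: two timestamps collapse into January, one into February.
dt_objects = pd.to_datetime(["2000-01-01", "2000-01-01", "2000-02-01"])

# The labels serve as both the data and the dimension coordinates, and the
# array is named after the time dimension itself rather than a joined
# component name such as "year_month".
time_grouped = xr.DataArray(
    name="time",
    data=dt_objects,
    coords={"time": dt_objects},
    dims=["time"],
)
```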
xcdat/temporal.py
Outdated
```diff
 if self._mode in ["group_average", "climatology"]:
-    self._weights = self._weights.rename({self.dim: f"{self.dim}_original"})
-    # Only keep the original time coordinates, not the ones labeled
-    # by group.
-    self._weights = self._weights.drop_vars(self._labeled_time.name)
+    weights = self._weights.assign_coords({self.dim: self._dataset[self.dim]})
+    weights = weights.rename({self.dim: f"{self.dim}_original"})

-    ds[self._weights.name] = self._weights
+    ds[weights.name] = weights
```
Reassign the original, unlabeled time coordinates back to the weights `xr.DataArray`, then rename the dimension to `"time_original"` to avoid conflicting with the labeled time coordinates (now called `"time"`). A minimal sketch follows.
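A minimal sketch of that reassignment, with hypothetical uniform weights:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical monthly weights whose "time" coordinates were replaced by
# group labels during grouping.
weights = xr.DataArray(np.full(12, 1 / 12), dims="time", name="time_wts")
original_time = pd.date_range("2000-01-01", periods=12, freq="MS")

# Restore the original coordinates, then rename the dimension so it does not
# conflict with the labeled "time" coordinates on the output dataset.
weights = weights.assign_coords(time=original_time)
weights = weights.rename({"time": "time_original"})
```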
```python
dv_departs = dv_departs.assign_coords({self.dim: ds_obs[self.dim]})
ds_departs[data_var] = dv_departs
```
Reassign the grouped, unlabeled time coordinates back to the final departures' time coordinates (since the labeled, grouped time coordinates sometimes drop the year). A minimal sketch follows.
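Roughly, with hypothetical values:

```python
import pandas as pd
import xarray as xr

# Departures whose "time" coordinates hold group labels (month numbers here)
# that no longer carry the year.
dv_departs = xr.DataArray([0.1, -0.2, 0.1], coords={"time": [1, 2, 3]}, dims="time")

# Reassign the observation's original time coordinates so the departures
# keep full dates.
original_time = pd.date_range("2000-01-01", periods=3, freq="MS")
dv_departs = dv_departs.assign_coords(time=original_time)
```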
Hi @chengzhuzhang, this PR is ready for review. After refactoring, I managed to cut down the runtime as follows (the timings are recorded in the docstrings of the benchmarking script below). I also performed a regression test using the same e3sm_diags dataset between the `dev` and `main` branches.

Benchmarking script:

```python
# %%
import xarray as xr
import xcdat as xc
### 1. Using temporal.climatology from xcdat
file_path = "/global/cfs/cdirs/e3sm/e3sm_diags/postprocessed_e3sm_v2_data_for_e3sm_diags/20221103.v2.LR.amip.NGD_v3atm.chrysalis/arm-diags-data/PRECT_sgpc1_198501_201412.nc"
ds = xc.open_dataset(file_path)
branch = "dev"
# %%
# 1. Calculate annual climatology
# -------------------------------
ds_annual_cycle = ds.temporal.climatology("PRECT", "month", keep_weights=True)
ds_annual_cycle.to_netcdf(f"temporal_climatology_{branch}.nc")
"""
main
--------------------------
CPU times: user 33 s, sys: 2.41 s, total: 35.4 s
Wall time: 35.4 s
refactor/688-temp-api-perf
--------------------------
CPU times: user 5.85 s, sys: 2.88 s, total: 8.72 s
Wall time: 8.78 s
"""
# %%
# 2. Calculate annual departures
# ------------------------------
ds_annual_cycle_anom = ds.temporal.departures("PRECT", "month", keep_weights=True)
ds_annual_cycle_anom.to_netcdf(f"temporal_departures_{branch}.nc")
"""
main
--------------------------
CPU times: user 1min 9s, sys: 4.8 s, total: 1min 14s
Wall time: 1min 14s
refactor/688-temp-api-perf
--------------------------
CPU times: user 11.6 s, sys: 4.32 s, total: 15.9 s
Wall time: 15.9 s
"""
# %%
# 3. Calculate monthly group averages
# -----------------------------------
ds_annual_avg = ds.temporal.group_average("PRECT", "month", keep_weights=True)
ds_annual_avg.to_netcdf(f"temporal_group_average_{branch}.nc")
"""
main
--------------------------
CPU times: user 33.5 s, sys: 2.27 s, total: 35.8 s
Wall time: 35.9 s
refactor/688-temp-api-perf
--------------------------
CPU times: user 5.59 s, sys: 2.06 s, total: 7.65 s
Wall time: 7.65 s
""" Regression testing scriptimport glob
import xarray as xr
# Get the filepaths for the dev and main branches
dev_filepaths = sorted(glob.glob("qa/issue-688/dev/*.nc"))
main_filepaths = sorted(glob.glob("qa/issue-688/main/*.nc"))
for fp, mp in zip(dev_filepaths, main_filepaths):
print(f"Comparing {fp} and {mp}")
# Load the datasets
dev_ds = xr.open_dataset(fp)
main_ds = xr.open_dataset(mp)
# Compare the datasets
try:
xr.testing.assert_identical(dev_ds, main_ds)
except AssertionError as e:
print(f"Datasets are not identical: {e}")
else:
print("Datasets are identical") Next step
```python
time_lengths = time_lengths.astype(np.float64)

grouped_time_lengths = self._group_data(time_lengths)
weights: xr.DataArray = grouped_time_lengths / grouped_time_lengths.sum()
weights.name = f"{self.dim}_wts"

# Validate the sum of weights for each group is 1.0.
```
It seems like a good feature to check that the sums match, but if it degrades performance a lot, we can exclude it. Maybe this check can just be implemented in testing (if it is not included yet). Also, the `_get_weights()` description needs to be updated to reflect that the sum is no longer validated.
We should expect the logic of `_get_weights()` to be correct, so this assertion should not be necessary at runtime (especially with the performance hit). I like your suggestion of making it a unit test instead; I will push a commit with this change soon. A sketch of such a test follows.
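A possible unit test along those lines, as a standalone sketch with synthetic monthly data (not the actual xcdat test suite):

```python
import numpy as np
import pandas as pd
import xarray as xr

def test_weights_sum_to_one_per_group():
    # Synthetic monthly time lengths for one year, grouped by quarter.
    time = pd.date_range("2000-01-01", periods=12, freq="MS")
    lengths = xr.DataArray(
        time.days_in_month.to_numpy().astype(np.float64),
        coords={"time": time},
        dims="time",
    )
    groups = xr.DataArray(time.quarter, coords={"time": time}, dims="time", name="quarter")

    # Same construction as _get_weights(): group lengths over group sums.
    grouped = lengths.groupby(groups)
    weights = grouped / grouped.sum()

    # The weights within each group must sum to 1.0.
    sums = weights.groupby(groups).sum()
    np.testing.assert_allclose(sums.values, 1.0)
```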
xcdat/temporal.py
Outdated
```diff
-if weighted and keep_weights:
+if keep_weights:
     self._weights = ds_climo.time_wts
     ds_obs = self._keep_weights(ds_obs)
```
I notice this if statement changed from `if weighted and keep_weights`; should it be kept the same?
Thank you for catching this. I reverted the conditional.
Hi Tom, thank you for the PR! I think it looks great; I just have minor comments for you to consider.
- Check if sum of each weight group equals 1.0
- Update `_get_weights()` docs to remove validation portion
Description

TODO:

- In `_get_weights()`, loading time lengths into memory is slow (lines) -- replace with casting to `"timedelta64[ns]"`, then `float64`
- In `_get_weights()`, performing validation to check that the sums of weights for each group add up to 1 is slow (lines) -- remove this unnecessary assertion
- Identify performance optimizations -- I don't think this is necessary right now
  - `groupby` with vs. without the `flox` package
- Compare runtimes against `main`

Checklist

If applicable: