1+ from collections import defaultdict
2+ import xarray as xr
3+ import numpy as np
4+
5+ def get_1d_dims (d ):
6+ """
7+ Find all dimensions in a dataset that are purely 1-dimensional,
8+ i.e., those dimensions that are not part of a 2D or higher-D
9+ variable.
10+
11+ arguments
12+ d: xarray Dataset
13+ returns
14+ dims1d: a list of dimension names
15+ """
16+ # Assume all dims coorespond to 1D vars
17+ dims1d = list (d .dims .keys ())
18+ for varname , var in d .variables .items ():
19+ if len (var .dims ) > 1 :
20+ for vardim in var .dims :
21+ if vardim in dims1d :
22+ dims1d .remove (str (vardim ))
23+ return dims1d
24+
25+ def gen_1d_datasets (d ):
26+ """
27+ Generate a sequence of datasets having only those variables
28+ along each dimension that is only used for 1-dimensional variables.
29+
30+ arguments
31+ d: xarray Dataset
32+ returns
33+ generator function yielding a sequence of single-dimension datasets
34+ """
35+ dims1d = get_1d_dims (d )
36+ # print(dims1d)
37+ for dim in dims1d :
38+ all_dims = list (d .dims .keys ())
39+ all_dims .remove (dim )
40+ yield d .drop_dims (all_dims )
41+
42+ def get_1d_datasets (d ):
43+ """
44+ Generate a list of datasets having only those variables
45+ along each dimension that is only used for 1-dimensional variables.
46+
47+ arguments
48+ d: xarray Dataset
49+ returns
50+ a list of single-dimension datasets
51+ """
52+ return [d1 for d1 in gen_1d_datasets (d , * args , ** kwargs )]
53+
54+ def get_scalar_vars (d ):
55+ scalars = []
56+ for varname , var in d .variables .items ():
57+ if len (var .dims ) == 0 :
58+ scalars .append (varname )
59+ return scalars
60+
61+ def concat_1d_dims (datasets , stack_scalars = None ):
62+ """
63+ For each xarray Dataset in datasets, concatenate (preserving the order of datasets)
64+ all variables along dimensions that are only used for one-dimensional variables.
65+
66+ arguments
67+ d: iterable of xarray Datasets
68+ stack_scalars: create a new dimension named with this value
69+ that aggregates all scalar variables and coordinates
70+ returns
71+ a new xarray Dataset with only the single-dimension variables
72+ """
73+ # dictionary mapping dimension names to a list of all
74+ # datasets having only that dimension
75+ all_1d_datasets = defaultdict (list )
76+
77+ for d in datasets :
78+ scalars = get_scalar_vars (d )
79+ for d_1d_initial in gen_1d_datasets (d ):
80+ # Get rid of scalars
81+ d_1d = d_1d_initial .drop (scalars )
82+ dims = tuple (d_1d .dims .keys ())
83+ all_1d_datasets [dims [0 ]].append (d_1d )
84+ if stack_scalars :
85+ # restore scalars along new dimension stack_scalars
86+ scalar_dataset = xr .Dataset ()
87+ for scalar_var in scalars :
88+ # promote from scalar to an array with a dimension, and remove
89+ # the coordinate info so that it's just a regular variable.
90+ as_1d = d [scalar_var ].expand_dims (stack_scalars ).reset_coords (drop = True )
91+ scalar_dataset [scalar_var ] = as_1d # xr.DataArray(as_1d, dims=[stack_scalars])
92+ all_1d_datasets [stack_scalars ].append (scalar_dataset )
93+
94+ unified = xr .Dataset ()
95+ for dim in all_1d_datasets :
96+ combined = xr .concat (all_1d_datasets [dim ], dim , coords = 'minimal' , data_vars = 'minimal' )
97+ unified .update (combined )
98+ return unified
99+
100+ # datasets=[]
101+ # for i, size in enumerate((4, 6)):
102+ # a = xr.DataArray(10*i + np.arange(size), dims='x')
103+ # b = xr.DataArray(10*i + np.arange(size/2), dims='y')
104+ # c = xr.DataArray(20*i + np.arange(size*3), dims='t')
105+ # d = xr.DataArray(11*i + np.arange(size*3), dims='t')
106+ # T = xr.DataArray(10*i + np.arange(size)**2, dims='x')
107+ # D = xr.DataArray(10*i + np.arange(size/2)**2, dims='y')
108+ # z = xr.DataArray(10*i + np.arange(size*4)**2, dims='z')
109+ # u = xr.DataArray(10*i + np.arange(size*5)**2, dims='u')
110+ # v = xr.DataArray(12*i + np.arange(size*5)**2, dims='u')
111+ # P = xr.DataArray(10*i + np.ones((size,int(size/2))), dims=['x', 'y'])
112+ # Q = xr.DataArray(20*i + np.ones((size,int(size/2))), dims=['x', 'y'])
113+ # d = xr.Dataset({'x':a,'y':b, 't':c, 'd':d, 'u':u, 'v':v, 'z':z, 'T':T, 'D':D, 'P':P, 'Q':Q})
114+ # datasets.append(d)
115+ # # datasets.append(d[{'x':slice(None, None), 'y':slice(0,0)}])
116+ # for d in datasets: print(d,'\n')
117+ # # xr.combine_by_coords(datasets, coords='all')
118+ # # xr.combine_nested(datasets, coords='all', data_vars='all')
119+
120+ # # print(get_1d_dims(d))
121+ # assert(get_1d_dims(d)==['t', 'u', 'z'])
122+ # # for d1 in get_1d_datasets(d):
123+ # # print(d1,'\n')
124+
125+ # combined = concat_1d_dims(datasets)
126+ # print(combined)
0 commit comments