import numpy as np
import xarray as xr
import itertools

from logging import info, debug
from .....utils.datastores.dump import read_datastore
from .apply_AK import apply_ak_ad
from .vinterp import vertical_interp
import copy


def adjoint(
        transf,
        inout_datastore,
        controlvect,
        obsvect,
        mapper,
        di,
        df,
        mode,
        runsubdir,
        workdir,
        onlyinit=False,
        **kwargs
):
    """De-aggregate total columns to the model level.

    Adjoint of the satellite column operator: takes the increments stored
    in the ``adj_out`` column of the output datastore (one row per
    observation), applies the adjoint of the averaging-kernel formula and
    of the vertical log-pressure interpolation, and spreads the result over
    the model levels in the input datastores (one row per observation and
    per model level). Optionally fills levels above the model top from a
    "stratosphere" input when ``transf.fill_strato`` is set.

    Args:
        transf: transform object carrying the configuration
            (``parameter``, ``component``, ``model``, ``product``,
            ``formula``, ``pressure``, ``fill_strato``, ``correct_pthick``,
            ``nchunks``, ``cropstrato``, ...).
        inout_datastore: dict with ``"inputs"``/``"outputs"`` datastores,
            keyed by ``(component, parameter)`` then by period start date.
        controlvect: control vector (unused here; kept for the generic
            transform signature).
        obsvect: observation vector (unused here; kept for the generic
            transform signature).
        mapper: dict describing input/output domains per
            ``(component, parameter)``.
        di, df: period start/end dates; ``min(di, df)`` indexes the
            datastores (handles backward runs).
        mode: run mode; only ``"adj"`` is meaningful here.
        runsubdir, workdir: run directories (unused in this function; kept
            for the generic transform signature).
        onlyinit: if True, only initialize/forward the level-extended
            datastores to precursor transforms and return, without
            computing the adjoint itself.
        **kwargs: ignored.

    Returns:
        None. Results are written in place into
        ``inout_datastore["inputs"]``.
    """
    ddi = min(di, df)

    ref_parameter = transf.parameter[0]
    ref_component = transf.component[0]

    ref_ds = inout_datastore["outputs"][(ref_component, ref_parameter)][ddi]
    out_datastore = inout_datastore["inputs"]

    # Copy data to output datastore; columns only need air density and
    # layer heights ("airm", "hlay") for the unit conversion at the end
    input_components = ["concs", "pressure", "dpressure"] \
                       + (transf.product == "column") * ["airm", "hlay"]
    for incomp in input_components:
        out_datastore[(incomp, ref_parameter)][ddi] = \
            copy.deepcopy(ref_ds)
        if incomp != "concs":
            # Only "concs" carries the actual increments; reset the others
            out_datastore[(incomp, ref_parameter)][ddi].loc[
                :, ("maindata", "adj_out")] = 0

    y0 = out_datastore[("concs", ref_parameter)][ddi]

    # Exit if empty observations
    if len(y0) == 0:
        return

    # Number of levels to extract for satellites (same for every obs)
    dlev = np.ones(len(y0), dtype=int) * transf.model.domain.nlev

    # Index in the original data of the level-extended dataframe
    native_inds_main = np.append([0], dlev.cumsum())

    # Output index: repeats each observation index nlev times
    idx = np.zeros((native_inds_main[-1]), dtype=int)
    idx[native_inds_main[:-1]] = np.arange(len(y0))
    np.maximum.accumulate(idx, out=idx)
    native_inds_main = native_inds_main[:-1]

    # Output dataframe: one row per (observation, model level)
    df_main = copy.deepcopy(y0.iloc[idx])

    # Levels: 0..nlev-1 repeated for every observation
    sublevels = np.meshgrid(
        list(range(transf.model.domain.nlev)),
        np.ones(len(y0)))[0].flatten()
    df_main[("metadata", "level")] = sublevels

    # Building the extended dataframe; the "station" column carries the
    # satellite identifier used to group observations
    iq1 = y0[("metadata", "station")]
    list_satIDs = iq1.unique()

    # Saving original values for later re-aggregation
    df_main.loc[:, ("metadata", "indorig")] = idx
    df_main.loc[:, ("metadata", "iq1")] = iq1.iloc[idx]
    for incomp in input_components:
        out_datastore[(incomp, ref_parameter)][ddi] = \
            copy.deepcopy(df_main)
        if incomp != "concs":
            out_datastore[(incomp, ref_parameter)][ddi].loc[
                :, ("maindata", "adj_out")] = 0

    # Un-stack columns into dataframe for stratosphere
    if transf.fill_strato:
        nlev_strato = \
            mapper["inputs"][("stratosphere", ref_parameter)]["domain"].nlev

        # Number of levels to extract for satellites
        dlev = np.ones(len(y0), dtype=int) * nlev_strato

        # Index in the original data of the level-extended dataframe
        native_inds_strato = np.append([0], dlev.cumsum())

        # Output index
        idx = np.zeros((native_inds_strato[-1]), dtype=int)
        idx[native_inds_strato[:-1]] = np.arange(len(y0))
        np.maximum.accumulate(idx, out=idx)
        native_inds_strato = native_inds_strato[:-1]

        # Output dataframe
        df_strato = copy.deepcopy(y0.iloc[idx])

        # Levels
        sublevels = np.meshgrid(list(range(nlev_strato)),
                                np.ones(len(y0)))[0].flatten()
        df_strato[("metadata", "level")] = sublevels

        # Building the extended dataframe
        iq1 = y0[("metadata", "station")]
        list_satIDs = iq1.unique()

        # Saving original values for later re-aggregation
        df_strato.loc[:, ("metadata", "indorig")] = idx
        df_strato.loc[:, ("metadata", "iq1")] = iq1.iloc[idx]

        out_datastore[("stratosphere", ref_parameter)][ddi] = df_strato

    # Stop here if no adjoint to be fully computed
    # Just forward datastore to precursor transforms
    if onlyinit:
        return

    # Load pressure coordinates from previous (forward) run
    file_monit = ddi.strftime(
        "{}/chain/satellites/{}/monit_%Y%m%d%H%M.nc".format(
            transf.model.adj_refdir, transf.transform_id)
    )
    fwd_pressure = read_datastore(
        file_monit,
        col2dump=["pressure", "dp", "indorig",
                  "hlay", "airm", "sim", "pthick", "exclude_zeros"]
    )
    # First row of each level-extended observation in the forward dump
    ref_indexes = ~fwd_pressure["metadata"].duplicated(subset=["indorig"])

    for satID in list_satIDs:
        satmask = iq1 == satID
        nobs = np.sum(satmask)

        # Getting the vector of increments
        obs_incr = y0.loc[satmask, ("maindata", "adj_out")]

        # If all increments are zeros/NaNs, just pass to next satellite
        if not np.any(obs_incr != 0.0):
            continue

        # Get target pressure: (nlev, nobs) indices into the forward dump
        native_ind_stack = (
                native_inds_main[satmask]
                + np.arange(transf.model.domain.nlev)[:, np.newaxis]
        )
        # NOTE(review): coords use len(y0) while the stacked data spans only
        # this satellite's observations; consistent only when a single
        # satellite ID is present in the datastore -- confirm for
        # multi-satellite runs
        datasim = xr.Dataset(
            {
                "pressure": (
                    ["level", "index"],
                    np.log(fwd_pressure["metadata"]["pressure"].values[native_ind_stack]),
                ),
                "dp": (
                    ["level", "index"],
                    fwd_pressure["metadata"]["dp"].values[native_ind_stack],
                ),
                "sim": (["level", "index"],
                        fwd_pressure["maindata"]["sim"].values[native_ind_stack]),
            },
            coords={
                "index": np.arange(len(y0)),
                "level": np.arange(transf.model.domain.nlev),
            },
        )
        if transf.product == "column":
            datasim = datasim.assign(
                {
                    "airm": (
                        ["level", "index"],
                        fwd_pressure["metadata"]["airm"].values[native_ind_stack],
                    ),
                    "hlay": (
                        ["level", "index"],
                        fwd_pressure["metadata"]["hlay"].values[native_ind_stack],
                    ),
                }
            )

        # Getting averaging kernels
        ref_mapper = mapper["outputs"][(ref_component, ref_parameter)]
        files_aks = \
            list(set(list(itertools.chain(
                *ref_mapper["tracer"].input_files.values()))))
        files_aks.sort()
        info("Fetching satellite infos from files: {}".format(files_aks))

        try:
            colsat = ["qa0", "ak", "pavg0", "pwf", "date", "index"]
            if transf.formula == 5:
                # Formula 5 additionally needs the dry-air mole fraction
                colsat = colsat + ["dryair"]
            coord2dump = ["index"]
            all_sat_aks = []
            for file_aks in files_aks:
                sat_aks = \
                    read_datastore(file_aks,
                                   col2dump=colsat,
                                   coord2dump=coord2dump,
                                   keep_default=False,
                                   to_pandas=False).reset_index("index")

                if len(all_sat_aks) == 0:
                    all_sat_aks = sat_aks
                else:
                    all_sat_aks = xr.concat([all_sat_aks, sat_aks], "index")

            # Selecting only lines used in simulation
            mask = all_sat_aks["date"].isin(ref_ds["metadata"]["date"]) \
                   & all_sat_aks.index.isin(ref_ds.index)
            sat_aks = all_sat_aks.loc[{"index": mask}]

        except IOError:
            # No averaging-kernel files available for this satellite;
            # skip it (best-effort, matches historical behaviour)
            continue

        # Retrieval pressure levels, flipped to surface-first; convert
        # from hPa to Pa unless the file already stores Pa
        if transf.pressure == "Pa":
            pavgs = sat_aks["pavg0"][:, ::-1].T
        else:
            pavgs = 100 * sat_aks["pavg0"][:, ::-1].T

        # Coordinates of the layer boundaries
        coords_pb = {"index": np.arange(nobs),
                     "level": np.arange(sat_aks.level_pressure.size)}
        # Coordinates of the retrieval. Can be pressure layer centers
        # ("level") or layer boundaries ("level_pressure", e.g. OCO-2).
        # Read which it is from dimensions of averaging kernel.
        ret_level_dim = sat_aks["ak"].dims[1]
        coords_ret = {"index": np.arange(nobs),
                      "level": np.arange(sat_aks[ret_level_dim].size)}
        dims = ("level", "index")

        # Define pressure axis of the retrieval
        pavgs_pb = xr.DataArray(pavgs, coords_pb, dims).bfill("level")
        # Case: retrieval on layer midpoints
        if ret_level_dim == "level":
            pavgs_ret = xr.DataArray(
                np.log(0.5 * (pavgs_pb[:-1].values + pavgs_pb[1:].values)),
                coords_ret, dims)
        # Case: retrieval on layer boundaries (e.g. OCO-2)
        elif ret_level_dim == "level_pressure":
            pavgs_ret = np.log(pavgs_pb)
        else:
            raise ValueError("Unknown dimension: '{}'".format(ret_level_dim))

        # If present, read pressure weights from file
        if "pwf" in sat_aks.keys():
            dpavgs = xr.DataArray(sat_aks["pwf"][:, ::-1].T, coords_ret, dims)
        # Else, construct pressure weights
        else:
            # Case: retrieval on layer midpoints
            if ret_level_dim == "level":
                dpavgs = xr.DataArray(np.diff(-pavgs, axis=0), coords_ret, dims)
            # Case: retrieval on layer boundaries
            elif ret_level_dim == "level_pressure":
                raise NotImplementedError(
                    "Constructing pressure weights is not implemented "
                    "for retrieval on layer boundaries. "
                    "Provide 'pwf' in observation file.")

        # Defining ak info (averaging kernels and prior profiles)
        aks = xr.DataArray(sat_aks["ak"][:, ::-1].T, coords_ret, dims)
        qa0 = xr.DataArray(sat_aks["qa0"][:, ::-1].T, coords_ret, dims)

        # Exclude observations where there are all zeros in the simulation
        exclude_zeros = np.where(
            ~fwd_pressure["metadata"].groupby(['indorig']).min()["exclude_zeros"])[0]
        datasim = datasim.isel(index=exclude_zeros)
        sat_aks = sat_aks.isel(index=exclude_zeros)
        aks = aks.isel(index=exclude_zeros)
        pavgs_ret = pavgs_ret.isel(index=exclude_zeros)
        dpavgs = dpavgs.isel(index=exclude_zeros)
        qa0 = qa0.isel(index=exclude_zeros)

        # Adding dry air mole fraction if formula 5
        if transf.formula == 5:
            drycols = sat_aks["dryair"][:, ::-1].T
        else:
            drycols = qa0 * 0.0

        # Applying aks
        nbformula = transf.formula
        chosenlevel = getattr(transf, "chosenlev", 0)

        debug("nbformula: {}".format(nbformula))
        debug("chosenlev: {}".format(chosenlevel))

        # If nbformula 3, load sim_ak from forward run
        if nbformula == 3:
            file_dump = ddi.strftime(
                "{}/chain/satellites/{}/sim_ak_{}_%Y%m%d%H%M.nc".format(
                    transf.model.adj_refdir, transf.transform_id, satID)
            )
            sim_ak = xr.open_dataarray(file_dump).values
        else:
            sim_ak = 0

        # Adjoint of the averaging-kernel application
        obs_incr = apply_ak_ad(
            sim_ak, dpavgs.values, aks.values,
            nbformula, qa0.values, chosenlevel, obs_incr.values[exclude_zeros],
            drycols.values
        )
        obs_incr[np.isnan(obs_incr)] = 0.

        # Correction with the pressure thickness
        # WARNING: there is an inconsistency in the number of levels
        if transf.correct_pthick:
            scale_pthick = fwd_pressure["metadata"]["pthick"].iloc[
                np.flatnonzero(ref_indexes)[satmask]]
            obs_incr *= scale_pthick.values

        # Adjoint of the log-pressure interpolation
        obs_incr_interp = 0.0 * datasim["pressure"].values

        nchunks = transf.nchunks
        chunks = np.linspace(0, len(datasim.index), num=nchunks + 1, dtype=int)
        cropstrato = transf.cropstrato
        for k1, k2 in zip(chunks[:-1], chunks[1:]):
            # Skip chunks that are too short
            if k1 == k2:
                continue

            debug("Compute chunk for satellite {}: {}-{}".format(satID, k1, k2))

            # Fetch missing values from stratosphere
            sim_pressure = datasim["pressure"].values[:, k1:k2]
            if transf.fill_strato:
                strato_mapper = mapper["inputs"][("stratosphere",
                                                  ref_parameter)]

                # Hybrid sigma-pressure levels of the stratospheric input
                sigma_a = strato_mapper["domain"].sigma_a
                sigma_b = strato_mapper["domain"].sigma_b
                psurf_strato = np.exp(sim_pressure[0])
                pstrato = np.log(
                    sigma_b * psurf_strato[:, np.newaxis] + sigma_a).T

                # Here merge sim_pressure
                # and pstrato properly, then apply vertical_interp
                missing_levels = np.argmax(pstrato < sim_pressure[-1],
                                           axis=0)[0]

                sim_pressure = \
                    np.concatenate([sim_pressure, pstrato[missing_levels:]],
                                   axis=0)

            # Vertical interpolation
            xlow, xhigh, alphalow, alphahigh = vertical_interp(
                sim_pressure,
                pavgs_ret[:, k1:k2].values,
                cropstrato,
            )

            # Applying coefficients
            # WARNING: There might be repeated indexes in a given column
            # To deal with repeated index, np.add.at is recommended
            levmeshout = np.array(
                1 * [list(range(pavgs_ret.shape[0]))]
            ).T
            meshout = np.array(pavgs_ret.shape[0] * [list(range(k2 - k1))])
            meshin = np.array(pavgs_ret.shape[0] * [list(range(k1, k2))])

            tmp_obs_incr = 0.0 * sim_pressure
            np.add.at(
                tmp_obs_incr,
                (xlow, meshout),
                obs_incr[levmeshout, meshin] * alphalow,
            )

            np.add.at(
                tmp_obs_incr,
                (xhigh, meshout),
                obs_incr[levmeshout, meshin] * alphahigh,
            )

            # Deal with the stratosphere
            if transf.fill_strato:
                nstrato = strato_mapper["domain"].nlev
                incr_strato = np.zeros((nstrato, k2 - k1))
                incr_strato[missing_levels:] += \
                    tmp_obs_incr[-nstrato + missing_levels:]

                # ppb to molec/cm2 if column product
                if transf.product == "column":
                    dpstrato = \
                        np.diff(np.concatenate(
                            [psurf_strato[np.newaxis, :], np.exp(pstrato)],
                            axis=0), axis=0) / 100  # hPa
                    # NOTE(review): gravity constant written as 9.88
                    # (standard g is ~9.81) -- confirm intended value
                    G = 9.88
                    dmass = np.abs(dpstrato / G)  # kg/m2
                    column = dmass * 1e3  # g/m2
                    column /= transf.molmass * 1e4  # mol/cm2
                    column *= 6.02214076 * 10 ** 23  # molec / cm2

                    incr_strato *= column

                # Fill adj_out
                native_ind_stack_strato = (
                        native_inds_strato[satmask]
                        + np.arange(nstrato)[:, np.newaxis]
                )
                out_strato = out_datastore[
                    ("stratosphere", transf.parameter[0])][ddi]
                out_strato[("maindata", "adj_out")].iloc[
                    native_ind_stack_strato[
                        :, exclude_zeros[k1:k2]].flatten()] = incr_strato.flatten()
                # Crop the stratospheric extension back off before storing
                tmp_obs_incr = tmp_obs_incr[:-nstrato + missing_levels]

            obs_incr_interp[:, k1:k2] = tmp_obs_incr[:]

        # Convert CHIMERE fields to the correct unit
        # from ppb to molec.cm-2 if the satellite product is a column
        if transf.product == "column":
            obs_incr_interp *= datasim["hlay"].values / (
                    1e9 / datasim["airm"].values
            )

        # Applying increments to the flattened datastore
        df_main[("maindata", "adj_out")].iloc[
            native_ind_stack[:, exclude_zeros].flatten()] = \
            obs_incr_interp.flatten()

    out_datastore[("concs", transf.parameter[0])][ddi] = df_main