
Advanced API

For now, we list the remaining package documentation here.

Scida is a Python package for reading and analyzing scientific big data.

config

Configuration handling.

combine_configs(configs, mode='overwrite_keys')

Combine multiple configurations recursively, replacing entries in the first config with entries from later ones.

Parameters:

configs: List[Dict], required
    The list of configurations to combine.
mode: str, default 'overwrite_keys'
    The mode for combining the configurations.
    "overwrite_keys": overwrite keys in the first config with keys from the latter (default).
    "overwrite_values": overwrite values in the first config with values from the latter.

Returns:

dict
    The combined configuration.

Source code in src/scida/config.py, lines 80-124
def combine_configs(configs: List[Dict], mode="overwrite_keys") -> Dict:
    """
    Combine multiple configurations recursively.
    Replacing entries in the first config with entries from the latter

    Parameters
    ----------
    configs: list
        The list of configurations to combine.
    mode: str
        The mode for combining the configurations.
        "overwrite_keys": overwrite keys in the first config with keys from the latter (default).
        "overwrite_values": overwrite values in the first config with values from the latter.

    Returns
    -------
    dict
        The combined configuration.
    """
    if mode == "overwrite_values":

        def mergefunc_values(a, b):
            """merge values"""
            if b is None:
                return a
            return b  # just overwrite by latter entry

        mergefunc_keys = None
    elif mode == "overwrite_keys":

        def mergefunc_keys(a, b):
            """merge keys"""
            if b is None:
                return a
            return b

        mergefunc_values = None
    else:
        raise ValueError("Unknown mode '%s'" % mode)
    conf = configs[0]
    for c in configs[1:]:
        merge_dicts_recursively(
            conf, c, mergefunc_keys=mergefunc_keys, mergefunc_values=mergefunc_values
        )
    return conf
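
A minimal usage sketch (the configuration contents below are made up for illustration):

from scida.config import combine_configs

conf_a = {"units": {"length": "kpc", "mass": "Msun"}, "nthreads": 4}
conf_b = {"units": {"length": "Mpc"}, "cache_path": "~/.cache/scida"}

# "overwrite_keys" replaces the whole "units" entry of conf_a with the one from conf_b;
# the first config is modified in place and returned.
merged = combine_configs([conf_a, conf_b], mode="overwrite_keys")
# merged["units"] == {"length": "Mpc"}; merged["nthreads"] == 4
# With mode="overwrite_values", nested dicts are entered recursively and only
# colliding leaf values are replaced, so "mass" would be kept.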

copy_defaultconfig(overwrite=False)

Copy the configuration example to the user's home directory.

Parameters:

overwrite: bool, default False
    Overwrite existing configuration file.

Returns:

None

Source code in src/scida/config.py, lines 173-197
def copy_defaultconfig(overwrite=False) -> None:
    """
    Copy the configuration example to the user's home directory.
    Parameters
    ----------
    overwrite: bool
        Overwrite existing configuration file.

    Returns
    -------
    None
    """

    path_user = os.path.expanduser("~")
    path_confdir = os.path.join(path_user, ".config/scida")
    if not os.path.exists(path_confdir):
        os.makedirs(path_confdir, exist_ok=True)
    path_conf = os.path.join(path_confdir, "config.yaml")
    if os.path.exists(path_conf) and not overwrite:
        raise ValueError("Configuration file already exists at '%s'" % path_conf)
    with importlib.resources.path("scida.configfiles", "config.yaml") as fp:
        with open(fp, "r") as file:
            content = file.read()
            with open(path_conf, "w") as newfile:
                newfile.write(content)
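
A short usage sketch:

from scida.config import copy_defaultconfig

# writes the packaged config.yaml to ~/.config/scida/config.yaml;
# raises a ValueError if the file already exists and overwrite is False
copy_defaultconfig(overwrite=False)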

get_config(reload=False, update_global=True)

Load the configuration from the default path.

Parameters:

reload: bool, default False
    Reload the configuration, even if it has already been loaded.
update_global: bool, default True
    Update the global configuration dictionary.

Returns:

dict
    The configuration dictionary.

Source code in src/scida/config.py, lines 33-77
def get_config(reload: bool = False, update_global=True) -> dict:
    """
    Load the configuration from the default path.

    Parameters
    ----------
    reload: bool
        Reload the configuration, even if it has already been loaded.
    update_global: bool
        Update the global configuration dictionary.

    Returns
    -------
    dict
        The configuration dictionary.
    """
    global _conf
    prefix = "SCIDA_"
    envconf = {
        k.replace(prefix, "").lower(): v
        for k, v in os.environ.items()
        if k.startswith(prefix)
    }

    # in any case, we make sure that there is some config in the default path.
    path_confdir = _access_confdir()
    path_conf = os.path.join(path_confdir, "config.yaml")

    # next, we load the config from the default path, unless explicitly overridden.
    path = envconf.pop("config_path", None)
    if path is None:
        path = path_conf
    if not reload and len(_conf) > 0:
        return _conf
    config = get_config_fromfile(path)
    if config.get("copied_default", False):
        print(
            "Warning! Using default configuration. Please adjust/replace in '%s'."
            % path
        )

    config.update(**envconf)
    if update_global:
        _conf = config
    return config
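
A sketch of the environment-variable override (the value below is illustrative): any SCIDA_* environment variable is stripped of its prefix, lowercased, and written into the returned configuration, taking precedence over the file-based entry.

import os
from scida.config import get_config

os.environ["SCIDA_CACHE_PATH"] = "/tmp/scida_cache"
conf = get_config(reload=True)  # force a re-read even if a config was already loaded
print(conf["cache_path"])       # -> "/tmp/scida_cache"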

get_config_fromfile(resource)

Load config from a YAML file.

Parameters:

resource: str, required
    The name of the resource or file path.

Returns:

dict
    The loaded configuration.

Source code in src/scida/config.py, lines 200-238
def get_config_fromfile(resource: str) -> Dict:
    """
    Load config from a YAML file.
    Parameters
    ----------
    resource
        The name of the resource or file path.

    Returns
    -------
    dict
        The loaded configuration.
    """
    if resource == "":
        raise ValueError("Config name cannot be empty.")
    # order (in descending order of priority):
    # 1. absolute path?
    path = os.path.expanduser(resource)
    if os.path.isabs(path):
        with open(path, "r") as file:
            conf = yaml.safe_load(file)
        return conf
    # 2. non-absolute path?
    # 2.1. check ~/.config/scida/
    bpath = os.path.expanduser("~/.config/scida")
    path = os.path.join(bpath, resource)
    if os.path.isfile(path):
        with open(path, "r") as file:
            conf = yaml.safe_load(file)
        return conf
    # 2.2 check scida package resources
    resource_path = "scida.configfiles"
    resource_elements = resource.split("/")
    rname = resource_elements[-1]
    if len(resource_elements) > 1:
        resource_path += "." + ".".join(resource_elements[:-1])
    with importlib.resources.path(resource_path, rname) as fp:
        with open(fp, "r") as file:
            conf = yaml.safe_load(file)
    return conf
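
A sketch of the lookup order (the second path is hypothetical): an absolute or ~-expandable path is read directly, otherwise ~/.config/scida/ is checked, and finally the resources shipped with the scida package.

from scida.config import get_config_fromfile

conf_sims = get_config_fromfile("simulations.yaml")         # packaged resource
conf_mine = get_config_fromfile("~/myproject/myconf.yaml")  # explicit file path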

get_config_fromfiles(paths, subconf_keys=None)

Load and merge multiple YAML config files.

Parameters:

paths: List[str], required
    Paths to the config files.
subconf_keys: Optional[List[str]], default None
    The keys to the correct sub configuration within each config.

Returns:

dict
    The merged configuration.

Source code in src/scida/config.py, lines 295-315
def get_config_fromfiles(paths: List[str], subconf_keys: Optional[List[str]] = None):
    """
    Load and merge multiple YAML config files
    Parameters
    ----------
    paths
        Paths to the config files.
    subconf_keys
        The keys to the correct sub configuration within each config.

    Returns
    -------
    dict
        The merged configuration.
    """
    confs = []
    for path in paths:
        confs.append(get_config_fromfile(path))
    conf = {}
    for confdict in confs:
        conf = merge_dicts_recursively(conf, confdict)
    return conf

get_simulationconfig()

Get the simulation configuration. Search regular user config file, scida simulation config file, and user's simulation config file.

Returns:

dict
    The simulation configuration.

Source code in src/scida/config.py, lines 144-170
def get_simulationconfig():
    """
    Get the simulation configuration.
    Search regular user config file, scida simulation config file, and user's simulation config file.

    Returns
    -------
    dict
        The simulation configuration.
    """
    conf_user = get_config()
    conf_sims = get_config_fromfile("simulations.yaml")
    conf_sims_user = _get_simulationconfig_user()

    confs_base = [conf_sims, conf_user]
    if conf_sims_user is not None:
        confs_base.append(conf_sims_user)
    # we only want to overwrite keys within "data", otherwise no merging of simkeys would take place
    confs = []
    for c in confs_base:
        if "data" in c:
            confs.append(c["data"])

    conf_sims = combine_configs(confs, mode="overwrite_keys")
    conf_sims = {"data": conf_sims}

    return conf_sims
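
A short usage sketch; the available keys depend on the user's configuration files:

from scida.config import get_simulationconfig

conf_sims = get_simulationconfig()
print(list(conf_sims["data"].keys()))  # names of configured simulations/datasets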

merge_dicts_recursively(dict_a, dict_b, path=None, mergefunc_keys=None, mergefunc_values=None)

Merge two dictionaries recursively.

Parameters:

dict_a: Dict, required
    The first dictionary.
dict_b: Dict, required
    The second dictionary.
path: Optional[List], default None
    The path to the current node.
mergefunc_keys: Optional[callable], default None
    The function to use for merging dict keys. If None, we recursively enter the dictionary.
mergefunc_values: Optional[callable], default None
    The function to use for merging dict values. If None, collisions will raise an exception.

Returns:

dict
    The merged dictionary (dict_a, updated in place).

Source code in src/scida/config.py, lines 241-292
def merge_dicts_recursively(
    dict_a: Dict,
    dict_b: Dict,
    path: Optional[List] = None,
    mergefunc_keys: Optional[callable] = None,
    mergefunc_values: Optional[callable] = None,
) -> Dict:
    """
    Merge two dictionaries recursively.
    Parameters
    ----------
    dict_a
        The first dictionary.
    dict_b
        The second dictionary.
    path
        The path to the current node.
    mergefunc_keys: callable
        The function to use for merging dict keys.
        If None, we recursively enter the dictionary.
    mergefunc_values: callable
        The function to use for merging dict values.
        If None, collisions will raise an exception.

    Returns
    -------
    dict
    """
    if path is None:
        path = []
    for key in dict_b:
        if key in dict_a:
            if mergefunc_keys is not None:
                dict_a[key] = mergefunc_keys(dict_a[key], dict_b[key])
            elif isinstance(dict_a[key], dict) and isinstance(dict_b[key], dict):
                merge_dicts_recursively(
                    dict_a[key],
                    dict_b[key],
                    path + [str(key)],
                    mergefunc_keys=mergefunc_keys,
                    mergefunc_values=mergefunc_values,
                )
            elif dict_a[key] == dict_b[key]:
                pass  # same leaf value
            else:
                if mergefunc_values is not None:
                    dict_a[key] = mergefunc_values(dict_a[key], dict_b[key])
                else:
                    raise Exception("Conflict at %s" % ".".join(path + [str(key)]))
        else:
            dict_a[key] = dict_b[key]
    return dict_a
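
A minimal sketch of the merge behavior (the dictionaries are illustrative):

from scida.config import merge_dicts_recursively

a = {"io": {"chunksize": "auto"}, "name": "run1"}
b = {"io": {"nthreads": 8}, "name": "run1"}

# disjoint nested keys are merged into the first dict in place;
# identical leaf values ("name") are left untouched
merge_dicts_recursively(a, b)
# a == {"io": {"chunksize": "auto", "nthreads": 8}, "name": "run1"}

# a genuine collision without a merge function raises an Exception:
#   merge_dicts_recursively({"x": 1}, {"x": 2})
# passing mergefunc_values=lambda old, new: new resolves it in favor of the latter value.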

convenience

download_and_extract(url, path, progressbar=True, overwrite=False)

Download and extract a file from a given url.

Parameters:

url: str, required
    The url to download from.
path: pathlib.Path, required
    The path to download to.
progressbar: bool, default True
    Whether to show a progress bar.
overwrite: bool, default False
    Whether to overwrite an existing file.

Returns:

str
    The path to the downloaded and extracted file(s).

Source code in src/scida/convenience.py, lines 28-83
def download_and_extract(
    url: str, path: pathlib.Path, progressbar: bool = True, overwrite: bool = False
):
    """
    Download and extract a file from a given url.
    Parameters
    ----------
    url: str
        The url to download from.
    path: pathlib.Path
        The path to download to.
    progressbar: bool
        Whether to show a progress bar.
    overwrite: bool
        Whether to overwrite an existing file.
    Returns
    -------
    str
        The path to the downloaded and extracted file(s).
    """
    if path.exists() and not overwrite:
        raise ValueError("Target path '%s' already exists." % path)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        totlength = int(r.headers.get("content-length", 0))
        lread = 0
        t1 = time.time()
        with open(path, "wb") as f:
            for chunk in r.iter_content(chunk_size=2**22):  # chunks of 4MB
                t2 = time.time()
                f.write(chunk)
                lread += len(chunk)
                if progressbar:
                    rate = (lread / 2**20) / (t2 - t1)
                    sys.stdout.write(
                        "\rDownloaded %.2f/%.2f Megabytes (%.2f%%, %.2f MB/s)"
                        % (
                            lread / 2**20,
                            totlength / 2**20,
                            100.0 * lread / totlength,
                            rate,
                        )
                    )
                    sys.stdout.flush()
        sys.stdout.write("\n")
    tar = tarfile.open(path, "r:gz")
    for t in tar:
        if t.isdir():
            t.mode = int("0755", base=8)
        else:
            t.mode = int("0644", base=8)
    tar.extractall(path.parents[0])
    foldername = tar.getmembers()[0].name  # parent folder of extracted tar.gz
    tar.close()
    os.remove(path)
    return os.path.join(path.parents[0], foldername)
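
A usage sketch; the URL is a placeholder and must point to a .tar.gz archive, which is downloaded to the given path, extracted next to it, and then removed:

import pathlib
from scida.convenience import download_and_extract

target = pathlib.Path("/tmp/scida_download/archive.tar.gz")
target.parent.mkdir(parents=True, exist_ok=True)
extracted = download_and_extract("https://example.org/some_dataset.tar.gz", target)
print(extracted)  # path of the extracted top-level folder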

find_path(path, overwrite=False)

Find path to dataset.

Parameters:

path: str, required
    Path to the dataset: a local path, a known dataset name, or a remote/testdata/resource URI.
overwrite: bool, default False
    Only for remote datasets. Whether to overwrite an existing download.

Returns:

str
    The resolved local path to the dataset.

Source code in src/scida/convenience.py, lines 110-202
def find_path(path, overwrite=False) -> str:
    """
    Find path to dataset.

    Parameters
    ----------
    path: str
    overwrite: bool
        Only for remote datasets. Whether to overwrite an existing download.

    Returns
    -------
    str

    """
    config = get_config()
    path = os.path.expanduser(path)
    if os.path.exists(path):
        # datasets on disk
        pass
    elif len(path.split(":")) > 1:
        # check different alternative backends
        databackend = path.split("://")[0]
        dataname = path.split("://")[1]
        if databackend in ["http", "https"]:
            # dataset on the internet
            savepath = config.get("download_path", None)
            if savepath is None:
                print(
                    "Have not specified 'download_path' in config. Using 'cache_path' instead."
                )
                savepath = config.get("cache_path")
            savepath = os.path.expanduser(savepath)
            savepath = pathlib.Path(savepath)
            savepath.mkdir(parents=True, exist_ok=True)
            urlhash = str(
                int(hashlib.sha256(path.encode("utf-8")).hexdigest(), 16) % 10**8
            )
            savepath = savepath / ("download" + urlhash)
            filename = "archive.tar.gz"
            if not savepath.exists():
                os.makedirs(savepath, exist_ok=True)
            elif overwrite:
                # delete all files in folder
                for f in os.listdir(savepath):
                    fp = os.path.join(savepath, f)
                    if os.path.isfile(fp):
                        os.unlink(fp)
                    else:
                        shutil.rmtree(fp)
            foldercontent = [f for f in savepath.glob("*")]
            if len(foldercontent) == 0:
                savepath = savepath / filename
                extractpath = download_and_extract(
                    path, savepath, progressbar=True, overwrite=overwrite
                )
            else:
                extractpath = savepath
            extractpath = pathlib.Path(extractpath)

            # count folders in savepath
            nfolders = len([f for f in extractpath.glob("*") if f.is_dir()])
            nobjects = len([f for f in extractpath.glob("*")])
            if nobjects == 1 and nfolders == 1:
                extractpath = (
                    extractpath / [f for f in extractpath.glob("*") if f.is_dir()][0]
                )
            path = extractpath
        elif databackend == "testdata":
            path = get_testdata(dataname)
        else:
            # potentially custom dataset.
            resources = config.get("resources", {})
            if databackend not in resources:
                raise ValueError("Unknown resource '%s'" % databackend)
            r = resources[databackend]
            if dataname not in r:
                raise ValueError(
                    "Unknown dataset '%s' in resource '%s'" % (dataname, databackend)
                )
            path = os.path.expanduser(r[dataname]["path"])
    else:
        found = False
        if "datafolders" in config:
            for folder in config["datafolders"]:
                folder = os.path.expanduser(folder)
                if os.path.exists(os.path.join(folder, path)):
                    path = os.path.join(folder, path)
                    found = True
                    break
        if not found:
            raise ValueError("Specified path '%s' unknown." % path)
    return path
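
A sketch of the path flavors that are resolved (all values are placeholders):

from scida.convenience import find_path

p1 = find_path("/data/sims/somesim/output")        # existing path on disk
p2 = find_path("somesim/output")                   # resolved against the "datafolders" config entry
p3 = find_path("testdata://some_testdata")         # entry below the configured testdata_path
p4 = find_path("https://example.org/data.tar.gz")  # downloaded and extracted on first use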

get_dataset(name=None, props=None)

Get dataset by name or properties.

Parameters:

name: Optional[str], default None
    Name or alias of dataset.
props: Optional[dict], default None
    Properties to match.

Returns:

str
    Dataset name.

Source code in src/scida/convenience.py, lines 408-429
def get_dataset(name=None, props=None):
    """
    Get dataset by name or properties.
    Parameters
    ----------
    name: Optional[str]
        Name or alias of dataset.
    props: Optional[dict]
        Properties to match.

    Returns
    -------
    str:
        Dataset name.

    """
    dnames = get_dataset_candidates(name=name, props=props)
    if len(dnames) > 1:
        raise ValueError("Too many dataset candidates.")
    elif len(dnames) == 0:
        raise ValueError("No dataset candidate found.")
    return dnames[0]

get_dataset_by_name(name)

Get dataset name from alias or name found in the configuration files.

Parameters:

name: str, required
    Name or alias of dataset.

Returns:

Optional[str]
    The dataset name, or None if no match is found.

Source code in src/scida/convenience.py, lines 317-345
def get_dataset_by_name(name: str) -> Optional[str]:
    """
    Get dataset name from alias or name found in the configuration files.

    Parameters
    ----------
    name: str
        Name or alias of dataset.

    Returns
    -------
    str
    """
    dname = None
    c = get_config()
    if "datasets" not in c:
        return dname
    datasets = copy.deepcopy(c["datasets"])
    if name in datasets:
        dname = name
    else:
        # could still be an alias
        for k, v in datasets.items():
            if "aliases" not in v:
                continue
            if name in v["aliases"]:
                dname = k
                break
    return dname

get_dataset_candidates(name=None, props=None)

Get dataset candidates by name or properties.

Parameters:

name: Optional[str], default None
    Name or alias of dataset.
props: Optional[dict], default None
    Properties to match.

Returns:

list[str]
    List of candidate dataset names.

Source code in src/scida/convenience.py, lines 382-405
def get_dataset_candidates(name=None, props=None):
    """
    Get dataset candidates by name or properties.

    Parameters
    ----------
    name: Optional[str]
        Name or alias of dataset.
    props: Optional[dict]
        Properties to match.

    Returns
    -------
    list[str]:
        List of candidate dataset names.

    """
    if name is not None:
        dnames = [get_dataset_by_name(name)]
        return dnames
    if props is not None:
        dnames = get_datasets_by_props(**props)
        return dnames
    raise ValueError("Need to specify name or properties.")

get_datasets_by_props(**kwargs)

Get dataset names by properties.

Parameters:

kwargs: dict
    Properties to match.

Returns:

list[str]
    List of dataset names.

Source code in src/scida/convenience.py, lines 348-379
def get_datasets_by_props(**kwargs):
    """
    Get dataset names by properties.

    Parameters
    ----------
    kwargs: dict
        Properties to match.

    Returns
    -------
    list[str]:
        List of dataset names.
    """
    dnames = []
    c = get_config()
    if "datasets" not in c:
        return dnames
    datasets = copy.deepcopy(c["datasets"])
    for k, v in datasets.items():
        props = v.get("properties", {})
        match = True
        for pk, pv in kwargs.items():
            if pk not in props:
                match = False
                break
            if props[pk] != pv:
                match = False
                break
        if match:
            dnames.append(k)
    return dnames
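
A usage sketch, assuming the user configuration defines datasets with a "properties" mapping (the property names below are hypothetical):

from scida.convenience import get_dataset, get_datasets_by_props

names = get_datasets_by_props(simulation="somesim", redshift=0.0)   # all matches
name = get_dataset(props=dict(simulation="somesim", redshift=0.0))  # errors unless exactly one match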

get_testdata(name)

Get path to test data identifier.

Parameters:

name: str, required
    Name of test data.

Returns:

str
    Path to the requested test data.

Source code in src/scida/convenience.py, lines 86-107
def get_testdata(name: str) -> str:
    """
    Get path to test data identifier.

    Parameters
    ----------
    name: str
        Name of test data.
    Returns
    -------
    str
    """
    config = get_config()
    tdpath: Optional[str] = config.get("testdata_path", None)
    if tdpath is None:
        raise ValueError("Test data directory not specified in configuration")
    if not os.path.isdir(tdpath):
        raise ValueError("Invalid test data path")
    res = {f: os.path.join(tdpath, f) for f in os.listdir(tdpath)}
    if name not in res.keys():
        raise ValueError("Specified test data not available.")
    return res[name]

load(path, units=True, unitfile='', overwrite=False, force_class=None, **kwargs)

Load a dataset or dataset series from a given path. This function will automatically determine the best-matching class to use and return the initialized instance.

Parameters:

path: str, required
    Path to dataset or dataset series. Usually the base folder containing all files of a given dataset/series.
units: Union[bool, str], default True
    Whether to load units.
unitfile: str, default ''
    Can explicitly pass path to a unitfile to use.
overwrite: bool, default False
    Whether to overwrite an existing cache.
force_class: Optional[object], default None
    Force a specific class to be used.
kwargs: dict
    Additional keyword arguments to pass to the class.

Returns:

Union[Dataset, DatasetSeries]
    Initialized dataset or dataset series.

Source code in src/scida/convenience.py, lines 205-314
def load(
    path: str,
    units: Union[bool, str] = True,
    unitfile: str = "",
    overwrite: bool = False,
    force_class: Optional[object] = None,
    **kwargs
):
    """
    Load a dataset or dataset series from a given path.
    This function will automatically determine the best-matching
    class to use and return the initialized instance.

    Parameters
    ----------
    path: str
        Path to dataset or dataset series. Usually the base folder containing all files of a given dataset/series.
    units: bool
        Whether to load units.
    unitfile: str
        Can explicitly pass path to a unitfile to use.
    overwrite: bool
        Whether to overwrite an existing cache.
    force_class: object
        Force a specific class to be used.
    kwargs: dict
        Additional keyword arguments to pass to the class.

    Returns
    -------
    Union[Dataset, DatasetSeries]:
        Initialized dataset or dataset series.
    """

    path = find_path(path, overwrite=overwrite)

    if "catalog" in kwargs:
        c = kwargs["catalog"]
        query_path = True
        query_path &= c is not None
        query_path &= not isinstance(c, bool)
        query_path &= not isinstance(c, list)
        query_path &= c != "none"
        if query_path:
            kwargs["catalog"] = find_path(c, overwrite=overwrite)

    # determine dataset class
    reg = dict()
    reg.update(**dataset_type_registry)
    reg.update(**dataseries_type_registry)

    path = os.path.realpath(path)
    cls = _determine_type(path, **kwargs)[1][0]

    msg = "Dataset is identified as '%s' via _determine_type." % cls
    log.debug(msg)

    # any identifying metadata?
    classtype = "dataset"
    if issubclass(cls, DatasetSeries):
        classtype = "series"
    cls_simconf = _determine_type_from_simconfig(path, classtype=classtype, reg=reg)

    if cls_simconf and not issubclass(cls, cls_simconf):
        oldcls = cls
        cls = cls_simconf
        if oldcls != cls:
            msg = "Dataset is identified as '%s' via the simulation config replacing prior candidate '%s'."
            log.debug(msg % (cls, oldcls))
        else:
            msg = "Dataset is identified as '%s' via the simulation config, identical to prior candidate."
            log.debug(msg % cls)

    if force_class is not None:
        cls = force_class

    # determine additional mixins not set by class
    mixins = []
    if hasattr(cls, "_unitfile"):
        unitfile = cls._unitfile
    if unitfile:
        if not units:
            units = True
        kwargs["unitfile"] = unitfile

    if units:
        mixins.append(UnitMixin)
        kwargs["units"] = units

    msg = "Inconsistent overwrite_cache, please only use 'overwrite' in load()."
    assert kwargs.get("overwrite_cache", overwrite) == overwrite, msg
    kwargs["overwrite_cache"] = overwrite

    # we append since unit mixin is added outside of this func right now
    metadata_raw = dict()
    if classtype == "dataset":
        metadata_raw = load_metadata(path, fileprefix=None)
    other_mixins = _determine_mixins(path=path, metadata_raw=metadata_raw)
    mixins += other_mixins

    log.debug("Adding mixins '%s' to dataset." % mixins)
    if hasattr(cls, "_mixins"):
        cls_mixins = cls._mixins
        for m in cls_mixins:
            # remove duplicates
            if m in mixins:
                mixins.remove(m)

    instance = cls(path, mixins=mixins, **kwargs)
    return instance
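
A minimal usage sketch; the path is a placeholder, and the field access assumes an Arepo-style snapshot:

from scida.convenience import load

ds = load("path/to/snapshot", units=True)
gas = ds.data["PartType0"]   # container of a particle type
coords = gas["Coordinates"]  # a dask array carrying units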

customs

arepo

MTNG

dataset

Support for MTNG-Arepo datasets, see https://www.mtng-project.org/

MTNGArepoCatalog

Bases: ArepoCatalog

A dataset representing a MTNG-Arepo catalog.

Source code in src/scida/customs/arepo/MTNG/dataset.py, lines 126-177
class MTNGArepoCatalog(ArepoCatalog):
    """
    A dataset representing a MTNG-Arepo catalog.
    """

    _fileprefix = "fof_subhalo_tab"

    def __init__(self, *args, **kwargs):
        """
        Initialize an MTNGArepoCatalog object.

        Parameters
        ----------
        args: list
        kwargs: dict
        """
        kwargs["iscatalog"] = True
        if "fileprefix" not in kwargs:
            kwargs["fileprefix"] = "fof_subhalo_tab"
        kwargs["choose_prefix"] = True
        super().__init__(*args, **kwargs)

    @classmethod
    def validate_path(
        cls, path: Union[str, os.PathLike], *args, **kwargs
    ) -> CandidateStatus:
        """
        Validate a path as a candidate for the MTNG-Arepo catalog class.

        Parameters
        ----------
        path: str
            Path to validate.
        args: list
        kwargs: dict

        Returns
        -------
        CandidateStatus
            Whether the path is a candidate for this dataset class.
        """
        tkwargs = dict(fileprefix=cls._fileprefix)
        tkwargs.update(**kwargs)
        valid = super().validate_path(path, *args, **tkwargs)
        if valid == CandidateStatus.NO:
            return valid
        metadata_raw = load_metadata(path, **tkwargs)
        if "/Config" not in metadata_raw:
            return CandidateStatus.NO
        if "MTNG" not in metadata_raw["/Config"]:
            return CandidateStatus.NO
        return CandidateStatus.YES
__init__(*args, **kwargs)

Initialize an MTNGArepoCatalog object.

Parameters:

args: list, default ()
kwargs: dict, default {}

Source code in src/scida/customs/arepo/MTNG/dataset.py, lines 133-146
def __init__(self, *args, **kwargs):
    """
    Initialize an MTNGArepoCatalog object.

    Parameters
    ----------
    args: list
    kwargs: dict
    """
    kwargs["iscatalog"] = True
    if "fileprefix" not in kwargs:
        kwargs["fileprefix"] = "fof_subhalo_tab"
    kwargs["choose_prefix"] = True
    super().__init__(*args, **kwargs)
validate_path(path, *args, **kwargs) classmethod

Validate a path as a candidate for the MTNG-Arepo catalog class.

Parameters:

path: Union[str, os.PathLike], required
    Path to validate.
args: list, default ()
kwargs: dict, default {}

Returns:

CandidateStatus
    Whether the path is a candidate for this dataset class.

Source code in src/scida/customs/arepo/MTNG/dataset.py, lines 148-177
@classmethod
def validate_path(
    cls, path: Union[str, os.PathLike], *args, **kwargs
) -> CandidateStatus:
    """
    Validate a path as a candidate for the MTNG-Arepo catalog class.

    Parameters
    ----------
    path: str
        Path to validate.
    args: list
    kwargs: dict

    Returns
    -------
    CandidateStatus
        Whether the path is a candidate for this dataset class.
    """
    tkwargs = dict(fileprefix=cls._fileprefix)
    tkwargs.update(**kwargs)
    valid = super().validate_path(path, *args, **tkwargs)
    if valid == CandidateStatus.NO:
        return valid
    metadata_raw = load_metadata(path, **tkwargs)
    if "/Config" not in metadata_raw:
        return CandidateStatus.NO
    if "MTNG" not in metadata_raw["/Config"]:
        return CandidateStatus.NO
    return CandidateStatus.YES
MTNGArepoSnapshot

Bases: ArepoSnapshot

MTNGArepoSnapshot is a snapshot class for the MTNG project.

Source code in src/scida/customs/arepo/MTNG/dataset.py, lines 13-123
class MTNGArepoSnapshot(ArepoSnapshot):
    """
    MTNGArepoSnapshot is a snapshot class for the MTNG project.
    """

    _fileprefix_catalog = "fof_subhalo_tab"
    _fileprefix = "snapshot_"  # underscore is important!

    def __init__(self, path, chunksize="auto", catalog=None, **kwargs) -> None:
        """
        Initialize an MTNGArepoSnapshot object.

        Parameters
        ----------
        path: str
            Path to the snapshot folder, should contain "output" folder.
        chunksize: int
            Chunksize for the data.
        catalog: str
            Explicitly state catalog path to use.
        kwargs:
            Additional keyword arguments.
        """
        tkwargs = dict(
            fileprefix=self._fileprefix,
            fileprefix_catalog=self._fileprefix_catalog,
            catalog_cls=MTNGArepoCatalog,
        )
        tkwargs.update(**kwargs)

        # in MTNG, we have two kinds of snapshots:
        # 1. regular (prefix: "snapshot_"), contains all particle types
        # 2. mostbound (prefix: "snapshot-prevmostboundonly_"), only contains DM particles
        # Most snapshots are of type 2, but some selected snapshots have type 1 and type 2.

        # attempt to load regular snapshot
        super().__init__(path, chunksize=chunksize, catalog=catalog, **tkwargs)
        # need to add mtng unit peculiarities
        # later unit file takes precedence
        self._defaultunitfiles += ["units/mtng.yaml"]

        if tkwargs["fileprefix"] == "snapshot-prevmostboundonly_":
            # this is a mostbound snapshot, so we are done
            return

        # next, attempt to load mostbound snapshot. This is done by loading into sub-object.
        self.mostbound = None
        tkwargs.update(fileprefix="snapshot-prevmostboundonly_", catalog="none")
        self.mostbound = MTNGArepoSnapshot(path, chunksize=chunksize, **tkwargs)
        # hacky: remove unused containers from mostbound snapshot
        for k in [
            "PartType0",
            "PartType2",
            "PartType3",
            "PartType4",
            "PartType5",
            "Group",
            "Subhalo",
        ]:
            if k in self.mostbound.data:
                del self.mostbound.data[k]
        self.merge_data(self.mostbound, fieldname_suffix="_mostbound")

    @classmethod
    def validate_path(
        cls, path: Union[str, os.PathLike], *args, **kwargs
    ) -> CandidateStatus:
        """
        Validate a path as a candidate for the MTNG-Arepo snapshot class.

        Parameters
        ----------
        path:
            Path to validate.
        args:  list
        kwargs: dict

        Returns
        -------
        CandidateStatus
            Whether the path is a candidate for this dataset class.

        """
        tkwargs = dict(
            fileprefix=cls._fileprefix, fileprefix_catalog=cls._fileprefix_catalog
        )
        tkwargs.update(**kwargs)
        try:
            valid = super().validate_path(path, *args, **tkwargs)
        except ValueError:
            valid = CandidateStatus.NO
            # might raise ValueError in case of partial snap

        if valid == CandidateStatus.NO:
            # check for partial snap
            tkwargs.update(fileprefix="snapshot-prevmostboundonly_")
            try:
                valid = super().validate_path(path, *args, **tkwargs)
            except ValueError:
                valid = CandidateStatus.NO

        if valid == CandidateStatus.NO:
            return valid
        metadata_raw = load_metadata(path, **tkwargs)
        if "/Config" not in metadata_raw:
            return CandidateStatus.NO
        if "MTNG" not in metadata_raw["/Config"]:
            return CandidateStatus.NO
        if "Ngroups_Total" in metadata_raw["/Header"]:
            return CandidateStatus.NO  # this is a catalog
        return CandidateStatus.YES
__init__(path, chunksize='auto', catalog=None, **kwargs)

Initialize an MTNGArepoSnapshot object.

Parameters:

path: str, required
    Path to the snapshot folder, should contain "output" folder.
chunksize: int, default 'auto'
    Chunksize for the data.
catalog: str, default None
    Explicitly state catalog path to use.
kwargs: dict, default {}
    Additional keyword arguments.

Source code in src/scida/customs/arepo/MTNG/dataset.py, lines 21-74
def __init__(self, path, chunksize="auto", catalog=None, **kwargs) -> None:
    """
    Initialize an MTNGArepoSnapshot object.

    Parameters
    ----------
    path: str
        Path to the snapshot folder, should contain "output" folder.
    chunksize: int
        Chunksize for the data.
    catalog: str
        Explicitly state catalog path to use.
    kwargs:
        Additional keyword arguments.
    """
    tkwargs = dict(
        fileprefix=self._fileprefix,
        fileprefix_catalog=self._fileprefix_catalog,
        catalog_cls=MTNGArepoCatalog,
    )
    tkwargs.update(**kwargs)

    # in MTNG, we have two kinds of snapshots:
    # 1. regular (prefix: "snapshot_"), contains all particle types
    # 2. mostbound (prefix: "snapshot-prevmostboundonly_"), only contains DM particles
    # Most snapshots are of type 2, but some selected snapshots have type 1 and type 2.

    # attempt to load regular snapshot
    super().__init__(path, chunksize=chunksize, catalog=catalog, **tkwargs)
    # need to add mtng unit peculiarities
    # later unit file takes precedence
    self._defaultunitfiles += ["units/mtng.yaml"]

    if tkwargs["fileprefix"] == "snapshot-prevmostboundonly_":
        # this is a mostbound snapshot, so we are done
        return

    # next, attempt to load mostbound snapshot. This is done by loading into sub-object.
    self.mostbound = None
    tkwargs.update(fileprefix="snapshot-prevmostboundonly_", catalog="none")
    self.mostbound = MTNGArepoSnapshot(path, chunksize=chunksize, **tkwargs)
    # hacky: remove unused containers from mostbound snapshot
    for k in [
        "PartType0",
        "PartType2",
        "PartType3",
        "PartType4",
        "PartType5",
        "Group",
        "Subhalo",
    ]:
        if k in self.mostbound.data:
            del self.mostbound.data[k]
    self.merge_data(self.mostbound, fieldname_suffix="_mostbound")
validate_path(path, *args, **kwargs) classmethod

Validate a path as a candidate for the MTNG-Arepo snapshot class.

Parameters:

path: Union[str, os.PathLike], required
    Path to validate.
args: list, default ()
kwargs: dict, default {}

Returns:

CandidateStatus
    Whether the path is a candidate for this dataset class.

Source code in src/scida/customs/arepo/MTNG/dataset.py, lines 76-123
@classmethod
def validate_path(
    cls, path: Union[str, os.PathLike], *args, **kwargs
) -> CandidateStatus:
    """
    Validate a path as a candidate for the MTNG-Arepo snapshot class.

    Parameters
    ----------
    path:
        Path to validate.
    args:  list
    kwargs: dict

    Returns
    -------
    CandidateStatus
        Whether the path is a candidate for this dataset class.

    """
    tkwargs = dict(
        fileprefix=cls._fileprefix, fileprefix_catalog=cls._fileprefix_catalog
    )
    tkwargs.update(**kwargs)
    try:
        valid = super().validate_path(path, *args, **tkwargs)
    except ValueError:
        valid = CandidateStatus.NO
        # might raise ValueError in case of partial snap

    if valid == CandidateStatus.NO:
        # check for partial snap
        tkwargs.update(fileprefix="snapshot-prevmostboundonly_")
        try:
            valid = super().validate_path(path, *args, **tkwargs)
        except ValueError:
            valid = CandidateStatus.NO

    if valid == CandidateStatus.NO:
        return valid
    metadata_raw = load_metadata(path, **tkwargs)
    if "/Config" not in metadata_raw:
        return CandidateStatus.NO
    if "MTNG" not in metadata_raw["/Config"]:
        return CandidateStatus.NO
    if "Ngroups_Total" in metadata_raw["/Header"]:
        return CandidateStatus.NO  # this is a catalog
    return CandidateStatus.YES
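
A usage sketch; the path is a placeholder for an MTNG output folder, and class selection happens automatically via load():

from scida.convenience import load

ds = load("path/to/mtng/snapshot")  # resolves to MTNGArepoSnapshot for full snapshots
# for full snapshots, the previous-mostbound companion is merged in with a "_mostbound"
# field name suffix and is also available as ds.mostbound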

TNGcluster

dataset
TNGClusterSelector

Bases: Selector

Selector for TNGClusterSnapshot. Can select for zoomID, which selects a given zoom target. Can specify withfuzz=True to include the "fuzz" particles for a given zoom target. Can specify onlyfuzz=True to only return the "fuzz" particles for a given zoom target.

Source code in src/scida/customs/arepo/TNGcluster/dataset.py, lines 13-170
class TNGClusterSelector(Selector):
    """
    Selector for TNGClusterSnapshot.  Can select for zoomID, which selects a given zoom target.
    Can specify withfuzz=True to include the "fuzz" particles for a given zoom target.
    Can specify onlyfuzz=True to only return the "fuzz" particles for a given zoom target.
    """

    def __init__(self) -> None:
        """
        Initialize the selector.
        """
        super().__init__()
        self.keys = ["zoomID", "withfuzz", "onlyfuzz"]

    def prepare(self, *args, **kwargs) -> None:
        """
        Prepare the selector.

        Parameters
        ----------
        args: list
        kwargs: dict

        Returns
        -------
        None
        """
        snap: TNGClusterSnapshot = args[0]
        zoom_id = kwargs.get("zoomID", None)
        fuzz = kwargs.get("withfuzz", None)
        onlyfuzz = kwargs.get("onlyfuzz", None)
        if zoom_id is None:
            return
        if zoom_id < 0 or zoom_id > (snap.ntargets - 1):
            raise ValueError("zoomID must be in range 0-%i" % (snap.ntargets - 1))

        for p in self.data_backup:
            if p.startswith("PartType"):
                key = "particles"
            elif p == "Group":
                key = "groups"
            elif p == "Subhalo":
                key = "subgroups"
            else:
                continue
            lengths = snap.lengths_zoom[key][zoom_id]
            offsets = snap.offsets_zoom[key][zoom_id]
            length_fuzz = None
            offset_fuzz = None

            if fuzz and key == "particles":  # fuzz files only for particles
                lengths_fuzz = snap.lengths_zoom[key][zoom_id + snap.ntargets]
                offsets_fuzz = snap.offsets_zoom[key][zoom_id + snap.ntargets]

                splt = p.split("PartType")
                pnum = int(splt[1])
                offset_fuzz = offsets_fuzz[pnum]
                length_fuzz = lengths_fuzz[pnum]

            if key == "particles":
                splt = p.split("PartType")
                pnum = int(splt[1])
                offset = offsets[pnum]
                length = lengths[pnum]
            else:
                offset = offsets
                length = lengths

            def get_slicedarr(
                v, offset, length, offset_fuzz, length_fuzz, key, fuzz=False
            ):
                """
                Get a sliced dask array for a given (length, offset) and (length_fuzz, offset_fuzz).

                Parameters
                ----------
                v: da.Array
                    The array to slice.
                offset: int
                length: int
                offset_fuzz: int
                length_fuzz: int
                key: str
                    ?
                fuzz: bool

                Returns
                -------
                da.Array
                    The sliced array.
                """
                arr = v[offset : offset + length]
                if offset_fuzz is not None:
                    arr_fuzz = v[offset_fuzz : offset_fuzz + length_fuzz]
                    if onlyfuzz:
                        arr = arr_fuzz
                    else:
                        arr = np.concatenate([arr, arr_fuzz])
                return arr

            def get_slicedfunc(
                func, offset, length, offset_fuzz, length_fuzz, key, fuzz=False
            ):
                """
                Slice a functions output for a given (length, offset) and (length_fuzz, offset_fuzz).

                Parameters
                ----------
                func: callable
                offset: int
                length: int
                offset_fuzz: int
                length_fuzz: int
                key: str
                fuzz: bool

                Returns
                -------
                callable
                    The sliced function.
                """

                def newfunc(
                    arrs, o=offset, ln=length, of=offset_fuzz, lnf=length_fuzz, **kwargs
                ):
                    arr_all = func(arrs, **kwargs)
                    arr = arr_all[o : o + ln]
                    if of is None:
                        return arr
                    arr_fuzz = arr_all[of : of + lnf]
                    if onlyfuzz:
                        return arr_fuzz
                    else:
                        return np.concatenate([arr, arr_fuzz])

                return newfunc

            # need to evaluate without recipes first
            for k, v in self.data_backup[p].items(withrecipes=False):
                self.data[p][k] = get_slicedarr(
                    v, offset, length, offset_fuzz, length_fuzz, key, fuzz
                )

            for k, v in self.data_backup[p].items(
                withfields=False, withrecipes=True, evaluate=False
            ):
                if not isinstance(v, FieldRecipe):
                    continue  # already evaluated, no need to port recipe (?)
                rcp: FieldRecipe = v
                func = get_slicedfunc(
                    v.func, offset, length, offset_fuzz, length_fuzz, key, fuzz
                )
                newrcp = DerivedFieldRecipe(rcp.name, func)
                newrcp.type = rcp.type
                newrcp.description = rcp.description
                newrcp.units = rcp.units
                self.data[p][k] = newrcp
        snap.data = self.data
__init__()

Initialize the selector.

Source code in src/scida/customs/arepo/TNGcluster/dataset.py, lines 20-25
def __init__(self) -> None:
    """
    Initialize the selector.
    """
    super().__init__()
    self.keys = ["zoomID", "withfuzz", "onlyfuzz"]
prepare(*args, **kwargs)

Prepare the selector.

Parameters:

args: list, default ()
kwargs: dict, default {}

Returns:

None

Source code in src/scida/customs/arepo/TNGcluster/dataset.py, lines 27-170
def prepare(self, *args, **kwargs) -> None:
    """
    Prepare the selector.

    Parameters
    ----------
    args: list
    kwargs: dict

    Returns
    -------
    None
    """
    snap: TNGClusterSnapshot = args[0]
    zoom_id = kwargs.get("zoomID", None)
    fuzz = kwargs.get("withfuzz", None)
    onlyfuzz = kwargs.get("onlyfuzz", None)
    if zoom_id is None:
        return
    if zoom_id < 0 or zoom_id > (snap.ntargets - 1):
        raise ValueError("zoomID must be in range 0-%i" % (snap.ntargets - 1))

    for p in self.data_backup:
        if p.startswith("PartType"):
            key = "particles"
        elif p == "Group":
            key = "groups"
        elif p == "Subhalo":
            key = "subgroups"
        else:
            continue
        lengths = snap.lengths_zoom[key][zoom_id]
        offsets = snap.offsets_zoom[key][zoom_id]
        length_fuzz = None
        offset_fuzz = None

        if fuzz and key == "particles":  # fuzz files only for particles
            lengths_fuzz = snap.lengths_zoom[key][zoom_id + snap.ntargets]
            offsets_fuzz = snap.offsets_zoom[key][zoom_id + snap.ntargets]

            splt = p.split("PartType")
            pnum = int(splt[1])
            offset_fuzz = offsets_fuzz[pnum]
            length_fuzz = lengths_fuzz[pnum]

        if key == "particles":
            splt = p.split("PartType")
            pnum = int(splt[1])
            offset = offsets[pnum]
            length = lengths[pnum]
        else:
            offset = offsets
            length = lengths

        def get_slicedarr(
            v, offset, length, offset_fuzz, length_fuzz, key, fuzz=False
        ):
            """
            Get a sliced dask array for a given (length, offset) and (length_fuzz, offset_fuzz).

            Parameters
            ----------
            v: da.Array
                The array to slice.
            offset: int
            length: int
            offset_fuzz: int
            length_fuzz: int
            key: str
                ?
            fuzz: bool

            Returns
            -------
            da.Array
                The sliced array.
            """
            arr = v[offset : offset + length]
            if offset_fuzz is not None:
                arr_fuzz = v[offset_fuzz : offset_fuzz + length_fuzz]
                if onlyfuzz:
                    arr = arr_fuzz
                else:
                    arr = np.concatenate([arr, arr_fuzz])
            return arr

        def get_slicedfunc(
            func, offset, length, offset_fuzz, length_fuzz, key, fuzz=False
        ):
            """
            Slice a functions output for a given (length, offset) and (length_fuzz, offset_fuzz).

            Parameters
            ----------
            func: callable
            offset: int
            length: int
            offset_fuzz: int
            length_fuzz: int
            key: str
            fuzz: bool

            Returns
            -------
            callable
                The sliced function.
            """

            def newfunc(
                arrs, o=offset, ln=length, of=offset_fuzz, lnf=length_fuzz, **kwargs
            ):
                arr_all = func(arrs, **kwargs)
                arr = arr_all[o : o + ln]
                if of is None:
                    return arr
                arr_fuzz = arr_all[of : of + lnf]
                if onlyfuzz:
                    return arr_fuzz
                else:
                    return np.concatenate([arr, arr_fuzz])

            return newfunc

        # need to evaluate without recipes first
        for k, v in self.data_backup[p].items(withrecipes=False):
            self.data[p][k] = get_slicedarr(
                v, offset, length, offset_fuzz, length_fuzz, key, fuzz
            )

        for k, v in self.data_backup[p].items(
            withfields=False, withrecipes=True, evaluate=False
        ):
            if not isinstance(v, FieldRecipe):
                continue  # already evaluated, no need to port recipe (?)
            rcp: FieldRecipe = v
            func = get_slicedfunc(
                v.func, offset, length, offset_fuzz, length_fuzz, key, fuzz
            )
            newrcp = DerivedFieldRecipe(rcp.name, func)
            newrcp.type = rcp.type
            newrcp.description = rcp.description
            newrcp.units = rcp.units
            self.data[p][k] = newrcp
    snap.data = self.data
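
A usage sketch of the selector keys, assuming they are forwarded through TNGClusterSnapshot.return_data (the path is a placeholder):

from scida.convenience import load

ds = load("path/to/TNG-Cluster/snapshot")       # resolves to TNGClusterSnapshot
data = ds.return_data(zoomID=3, withfuzz=True)  # particles of zoom target 3, including its "fuzz"
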
TNGClusterSnapshot

Bases: ArepoSnapshot

Dataset class for the TNG-Cluster simulation.

Source code in src/scida/customs/arepo/TNGcluster/dataset.py, lines 173-295
class TNGClusterSnapshot(ArepoSnapshot):
    """
    Dataset class for the TNG-Cluster simulation.
    """

    _fileprefix_catalog = "fof_subhalo_tab_"
    _fileprefix = "snap_"
    ntargets = 352

    def __init__(self, *args, **kwargs):
        """
        Initialize a TNGClusterSnapshot object.

        Parameters
        ----------
        args
        kwargs
        """
        super().__init__(*args, **kwargs)

        # we can get the offsets from the header as our load routine concats the various values from different files
        # these offsets can be used to select a given halo.
        # see: https://tng-project.org/w/index.php/TNG-Cluster
        # each zoom-target has two entries i and i+N, where N=352 is the number of zoom targets.
        # the first file contains the particles that were contained in the original low-res run
        # the second file contains all other remaining particles in a given zoom target
        def len_to_offsets(lengths):
            """
            From the particle count (field length), get the offset of all zoom targets.

            Parameters
            ----------
            lengths

            Returns
            -------
            np.ndarray
            """
            lengths = np.array(lengths)
            shp = len(lengths.shape)
            n = lengths.shape[-1]
            if shp == 1:
                res = np.concatenate(
                    [np.zeros(1, dtype=np.int64), np.cumsum(lengths.astype(np.int64))]
                )[:-1]
            else:
                res = np.cumsum(
                    np.vstack([np.zeros(n, dtype=np.int64), lengths.astype(np.int64)]),
                    axis=0,
                )[:-1]
            return res

        self.lengths_zoom = dict(particles=self.header["NumPart_ThisFile"])
        self.offsets_zoom = dict(
            particles=len_to_offsets(self.lengths_zoom["particles"])
        )

        if hasattr(self, "catalog") and self.catalog is not None:
            self.lengths_zoom["groups"] = self.catalog.header["Ngroups_ThisFile"]
            self.offsets_zoom["groups"] = len_to_offsets(self.lengths_zoom["groups"])
            self.lengths_zoom["subgroups"] = self.catalog.header["Nsubgroups_ThisFile"]
            self.offsets_zoom["subgroups"] = len_to_offsets(
                self.lengths_zoom["subgroups"]
            )

    @TNGClusterSelector()
    def return_data(self) -> FieldContainer:
        """
        Return the data object.

        Returns
        -------
        FieldContainer
            The data object.
        """
        return super().return_data()

    @classmethod
    def validate_path(
        cls, path: Union[str, os.PathLike], *args, **kwargs
    ) -> CandidateStatus:
        """
        Validate a path as a candidate for TNG-Cluster snapshot class.

        Parameters
        ----------
        path: str
            Path to validate.
        args: list
        kwargs: dict

        Returns
        -------
        CandidateStatus
            Whether the path is a candidate for this simulation class.
        """

        tkwargs = dict(
            fileprefix=cls._fileprefix, fileprefix_catalog=cls._fileprefix_catalog
        )
        tkwargs.update(**kwargs)
        valid = super().validate_path(path, *args, **tkwargs)
        if valid == CandidateStatus.NO:
            return valid
        metadata_raw = load_metadata(path, **tkwargs)

        matchingattrs = True

        parameters = metadata_raw["/Parameters"]
        if "InitCondFile" in parameters:
            matchingattrs &= parameters["InitCondFile"] == "various"
        else:
            return CandidateStatus.NO
        header = metadata_raw["/Header"]
        matchingattrs &= header["BoxSize"] == 680000.0
        matchingattrs &= header["NumPart_Total"][1] == 1944529344
        matchingattrs &= header["NumPart_Total"][2] == 586952200

        if matchingattrs:
            valid = CandidateStatus.YES
        else:
            valid = CandidateStatus.NO
        return valid
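
To see what the len_to_offsets helper above does, here is a minimal, self-contained sketch with invented per-file particle counts (the numbers are illustrative only):

import numpy as np

def len_to_offsets(lengths):
    # shifted cumulative sum: offsets[i] is the index of the first
    # particle of file i in the concatenated particle arrays
    lengths = np.asarray(lengths, dtype=np.int64)
    return np.concatenate([np.zeros(1, dtype=np.int64), np.cumsum(lengths)])[:-1]

lengths = [100, 250, 50]            # invented per-file particle counts
offsets = len_to_offsets(lengths)
print(offsets)                      # [  0 100 350]
# particles of file i live in the slice [offsets[i], offsets[i] + lengths[i])
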
__init__(*args, **kwargs)

Initialize a TNGClusterSnapshot object.

Parameters:

Name Type Description Default
args
()
kwargs
{}
Source code in src/scida/customs/arepo/TNGcluster/dataset.py
def __init__(self, *args, **kwargs):
    """
    Initialize a TNGClusterSnapshot object.

    Parameters
    ----------
    args
    kwargs
    """
    super().__init__(*args, **kwargs)

    # we can get the offsets from the header as our load routine concatenates the values from the different files
    # these offsets can be used to select a given halo.
    # see: https://tng-project.org/w/index.php/TNG-Cluster
    # each zoom target has two entries i and i+N, where N=352 is the number of zoom targets:
    # the first file contains the particles that were already present in the original low-res run,
    # the second file contains all remaining particles of a given zoom target.
    def len_to_offsets(lengths):
        """
        From the per-file particle counts (field lengths), compute the offsets of all zoom targets.

        Parameters
        ----------
        lengths

        Returns
        -------
        np.ndarray
        """
        lengths = np.array(lengths)
        shp = len(lengths.shape)
        n = lengths.shape[-1]
        if shp == 1:
            res = np.concatenate(
                [np.zeros(1, dtype=np.int64), np.cumsum(lengths.astype(np.int64))]
            )[:-1]
        else:
            res = np.cumsum(
                np.vstack([np.zeros(n, dtype=np.int64), lengths.astype(np.int64)]),
                axis=0,
            )[:-1]
        return res

    self.lengths_zoom = dict(particles=self.header["NumPart_ThisFile"])
    self.offsets_zoom = dict(
        particles=len_to_offsets(self.lengths_zoom["particles"])
    )

    if hasattr(self, "catalog") and self.catalog is not None:
        self.lengths_zoom["groups"] = self.catalog.header["Ngroups_ThisFile"]
        self.offsets_zoom["groups"] = len_to_offsets(self.lengths_zoom["groups"])
        self.lengths_zoom["subgroups"] = self.catalog.header["Nsubgroups_ThisFile"]
        self.offsets_zoom["subgroups"] = len_to_offsets(
            self.lengths_zoom["subgroups"]
        )
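
After construction, lengths_zoom and offsets_zoom can be combined to slice out the particles of a single zoom target. The following is a sketch only: it assumes a hypothetical TNG-Cluster snapshot path and that the concatenated header arrays have one row per file and one column per particle type.

from scida import load

ds = load("TNG-Cluster/output/snapdir_099")   # hypothetical path

lengths = ds.lengths_zoom["particles"]        # per-file particle counts
offsets = ds.offsets_zoom["particles"]        # per-file starting indices

# gas (column 0) particles of zoom target 0 stored in its primary file
i, ptype = 0, 0
start, n = int(offsets[i][ptype]), int(lengths[i][ptype])
coords = ds.data["PartType0"]["Coordinates"][start : start + n]
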
return_data()

Return the data object.

Returns:

Type Description
FieldContainer

The data object.

Source code in src/scida/customs/arepo/TNGcluster/dataset.py
@TNGClusterSelector()
def return_data(self) -> FieldContainer:
    """
    Return the data object.

    Returns
    -------
    FieldContainer
        The data object.
    """
    return super().return_data()
validate_path(path, *args, **kwargs) classmethod

Validate a path as a candidate for TNG-Cluster snapshot class.

Parameters:

Name Type Description Default
path Union[str, PathLike]

Path to validate.

required
args
()
kwargs
{}

Returns:

Type Description
CandidateStatus

Whether the path is a candidate for this simulation class.

Source code in src/scida/customs/arepo/TNGcluster/dataset.py
@classmethod
def validate_path(
    cls, path: Union[str, os.PathLike], *args, **kwargs
) -> CandidateStatus:
    """
    Validate a path as a candidate for TNG-Cluster snapshot class.

    Parameters
    ----------
    path: str
        Path to validate.
    args: list
    kwargs: dict

    Returns
    -------
    CandidateStatus
        Whether the path is a candidate for this simulation class.
    """

    tkwargs = dict(
        fileprefix=cls._fileprefix, fileprefix_catalog=cls._fileprefix_catalog
    )
    tkwargs.update(**kwargs)
    valid = super().validate_path(path, *args, **tkwargs)
    if valid == CandidateStatus.NO:
        return valid
    metadata_raw = load_metadata(path, **tkwargs)

    matchingattrs = True

    parameters = metadata_raw["/Parameters"]
    if "InitCondFile" in parameters:
        matchingattrs &= parameters["InitCondFile"] == "various"
    else:
        return CandidateStatus.NO
    header = metadata_raw["/Header"]
    matchingattrs &= header["BoxSize"] == 680000.0
    matchingattrs &= header["NumPart_Total"][1] == 1944529344
    matchingattrs &= header["NumPart_Total"][2] == 586952200

    if matchingattrs:
        valid = CandidateStatus.YES
    else:
        valid = CandidateStatus.NO
    return valid
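
validate_path is what scida's automatic dataset-type detection relies on; it can also be called directly to check a candidate path. A short sketch with a hypothetical directory:

from scida.customs.arepo.TNGcluster.dataset import TNGClusterSnapshot

status = TNGClusterSnapshot.validate_path("TNG-Cluster/output/snapdir_099")
print(status)   # CandidateStatus.YES only for a TNG-Cluster snapshot
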

dataset

ArepoCatalog

Bases: ArepoSnapshot

Dataset class for Arepo group catalogs.

Source code in src/scida/customs/arepo/dataset.py
class ArepoCatalog(ArepoSnapshot):
    """
    Dataset class for Arepo group catalogs.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize an ArepoCatalog object.

        Parameters
        ----------
        args
        kwargs
        """
        kwargs["iscatalog"] = True
        if "fileprefix" not in kwargs:
            kwargs["fileprefix"] = "groups"
        super().__init__(*args, **kwargs)

    @classmethod
    def validate_path(
        cls, path: Union[str, os.PathLike], *args, **kwargs
    ) -> CandidateStatus:
        """
        Validate a path to use for instantiation of this class.

        Parameters
        ----------
        path: str or pathlib.Path
        args:
        kwargs:

        Returns
        -------
        CandidateStatus
        """
        kwargs["fileprefix"] = cls._get_fileprefix(path)
        valid = super().validate_path(path, *args, expect_grp=True, **kwargs)
        return valid
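
A group catalog can also be opened on its own. A minimal sketch, assuming a hypothetical catalog directory of "groups_*"/"fof_subhalo_tab_*" HDF5 files:

from scida.customs.arepo.dataset import ArepoCatalog

cat = ArepoCatalog("output/groups_099")   # hypothetical path
print(list(cat.data.keys()))              # e.g. ["Group", "Subhalo"]
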
__init__(*args, **kwargs)

Initialize an ArepoCatalog object.

Parameters:

Name Type Description Default
args
()
kwargs
{}
Source code in src/scida/customs/arepo/dataset.py
def __init__(self, *args, **kwargs):
    """
    Initialize an ArepoCatalog object.

    Parameters
    ----------
    args
    kwargs
    """
    kwargs["iscatalog"] = True
    if "fileprefix" not in kwargs:
        kwargs["fileprefix"] = "groups"
    super().__init__(*args, **kwargs)
validate_path(path, *args, **kwargs) classmethod

Validate a path to use for instantiation of this class.

Parameters:

Name Type Description Default
path Union[str, PathLike]
required
args
()
kwargs
{}

Returns:

Type Description
CandidateStatus
Source code in src/scida/customs/arepo/dataset.py
@classmethod
def validate_path(
    cls, path: Union[str, os.PathLike], *args, **kwargs
) -> CandidateStatus:
    """
    Validate a path to use for instantiation of this class.

    Parameters
    ----------
    path: str or pathlib.Path
    args:
    kwargs:

    Returns
    -------
    CandidateStatus
    """
    kwargs["fileprefix"] = cls._get_fileprefix(path)
    valid = super().validate_path(path, *args, expect_grp=True, **kwargs)
    return valid
ArepoSnapshot

Bases: SpatialCartesian3DMixin, GadgetStyleSnapshot

Dataset class for Arepo snapshots.

Source code in src/scida/customs/arepo/dataset.py
class ArepoSnapshot(SpatialCartesian3DMixin, GadgetStyleSnapshot):
    """
    Dataset class for Arepo snapshots.
    """

    _fileprefix_catalog = "groups"

    def __init__(self, path, chunksize="auto", catalog=None, **kwargs) -> None:
        """
        Initialize an ArepoSnapshot object.

        Parameters
        ----------
        path: str or pathlib.Path
            Path to snapshot, typically a directory containing multiple hdf5 files.
        chunksize:
            Chunksize to use for dask arrays. Can be "auto" to automatically determine chunksize.
        catalog:
            Path to group catalog. If None, the catalog is searched for in the parent directories.
        kwargs:
            Additional keyword arguments.
        """
        self.iscatalog = kwargs.pop("iscatalog", False)
        self.header = {}
        self.config = {}
        self._defaultunitfiles: List[str] = ["units/gadget_cosmological.yaml"]
        self.parameters = {}
        self._grouplengths = {}
        self._subhalolengths = {}
        # not needed for group catalogs as entries are back-to-back there, we will provide a property for this
        self._subhalooffsets = {}
        self.misc = {}  # for storing misc info
        prfx = kwargs.pop("fileprefix", None)
        if prfx is None:
            prfx = self._get_fileprefix(path)
        super().__init__(path, chunksize=chunksize, fileprefix=prfx, **kwargs)

        self.catalog = catalog
        if not self.iscatalog:
            if self.catalog is None:
                self.discover_catalog()
                # try to discover group catalog in parent directories.
            if self.catalog == "none":
                pass  # this string can be set to explicitly disable catalog
            elif self.catalog:
                catalog_cls = kwargs.get("catalog_cls", None)
                cosmological = False
                if hasattr(self, "_mixins") and "cosmology" in self._mixins:
                    cosmological = True
                self.load_catalog(
                    overwrite_cache=kwargs.get("overwrite_cache", False),
                    catalog_cls=catalog_cls,
                    units=self.withunits,
                    cosmological=cosmological,
                )

        # add aliases
        aliases = dict(
            PartType0=["gas", "baryons"],
            PartType1=["dm", "dark matter"],
            PartType2=["lowres", "lowres dm"],
            PartType3=["tracer", "tracers"],
            PartType4=["stars"],
            PartType5=["bh", "black holes"],
        )
        for k, lst in aliases.items():
            if k not in self.data:
                continue
            for v in lst:
                self.data.add_alias(v, k)

        # set metadata
        self._set_metadata()

        # add some default fields
        self.data.merge(fielddefs)

    def load_catalog(
        self, overwrite_cache=False, units=False, cosmological=False, catalog_cls=None
    ):
        """
        Load the group catalog.

        Parameters
        ----------
        overwrite_cache: bool
            Whether to overwrite an existing cache for the catalog.
        units: bool
            Whether to load the catalog with units.
        cosmological: bool
            Whether to add the cosmology mixin to the catalog class.
        catalog_cls: type
            Class to use for the catalog. If None, the default catalog class is used.

        Returns
        -------
        None
        """
        virtualcache = False  # copy catalog for better performance
        # fileprefix = catalog_kwargs.get("fileprefix", self._fileprefix_catalog)
        prfx = self._get_fileprefix(self.catalog)

        # explicitly need to create unitaware class for catalog as needed
        # TODO: should just be determined from mixins of parent?
        if catalog_cls is None:
            cls = ArepoCatalog
        else:
            cls = catalog_cls
        withunits = units
        mixins = []
        if withunits:
            mixins += [UnitMixin]

        other_mixins = _determine_mixins(path=self.path)
        mixins += other_mixins
        if cosmological and CosmologyMixin not in mixins:
            mixins.append(CosmologyMixin)

        cls = create_datasetclass_with_mixins(cls, mixins)

        ureg = None
        if hasattr(self, "ureg"):
            ureg = self.ureg

        self.catalog = cls(
            self.catalog,
            overwrite_cache=overwrite_cache,
            virtualcache=virtualcache,
            fileprefix=prfx,
            units=self.withunits,
            ureg=ureg,
        )
        if "Redshift" in self.catalog.header and "Redshift" in self.header:
            z_catalog = self.catalog.header["Redshift"]
            z_snap = self.header["Redshift"]
            if not np.isclose(z_catalog, z_snap):
                raise ValueError(
                    "Redshift mismatch between snapshot and catalog: "
                    f"{z_snap:.2f} vs {z_catalog:.2f}"
                )

        # merge data
        self.merge_data(self.catalog)

        # first snapshots often do not have groups
        if "Group" in self.catalog.data:
            ngkeys = self.catalog.data["Group"].keys()
            if len(ngkeys) > 0:
                self.add_catalogIDs()

        # merge hints from snap and catalog
        self.merge_hints(self.catalog)

    @classmethod
    def validate_path(
        cls, path: Union[str, os.PathLike], *args, **kwargs
    ) -> CandidateStatus:
        """
        Validate a path to use for instantiation of this class.

        Parameters
        ----------
        path: str or pathlib.Path
        args:
        kwargs:

        Returns
        -------
        CandidateStatus
        """
        valid = super().validate_path(path, *args, **kwargs)
        if valid.value > CandidateStatus.MAYBE.value:
            valid = CandidateStatus.MAYBE
        else:
            return valid
        # Arepo has no dedicated attribute to identify such runs.
        # let's just query a bunch of attributes that are present for Arepo runs
        metadata_raw = load_metadata(path, **kwargs)
        matchingattrs = True
        matchingattrs &= "Git_commit" in metadata_raw["/Header"]
        # not existent for any arepo run?
        matchingattrs &= "Compactify_Version" not in metadata_raw["/Header"]

        if matchingattrs:
            valid = CandidateStatus.MAYBE

        return valid

    @ArepoSelector()
    def return_data(self) -> FieldContainer:
        """
        Return data object of this snapshot.

        Returns
        -------
        FieldContainer
            The data object of this snapshot.
        """
        return super().return_data()

    def discover_catalog(self):
        """
        Discover the group catalog given the current path

        Returns
        -------
        None
        """
        p = str(self.path)
        # order of candidates matters. For Illustris "groups" must precede "fof_subhalo_tab"
        candidates = [
            p.replace("snapshot", "group"),
            p.replace("snapshot", "groups"),
            p.replace("snapdir", "groups").replace("snap", "groups"),
            p.replace("snapdir", "groups").replace("snap", "fof_subhalo_tab"),
        ]
        for candidate in candidates:
            if not os.path.exists(candidate):
                continue
            if candidate == self.path:
                continue
            self.catalog = candidate
            break

    def register_field(self, parttype: str, name: str = None, construct: bool = False):
        """
        Register a field.

        Parameters
        ----------
        parttype: str
            name of particle type
        name: str
            name of field
        construct: bool
            construct field immediately

        Returns
        -------
        None
        """
        num = part_type_num(parttype)
        if construct:  # TODO: introduce (immediate) construct option later
            raise NotImplementedError
        if num == -1:  # TODO: all particle species
            key = "all"
            raise NotImplementedError
        elif isinstance(num, int):
            key = "PartType" + str(num)
        else:
            key = parttype
        return super().register_field(key, name=name)

    def add_catalogIDs(self) -> None:
        """
        Add field for halo and subgroup IDs for all particle types.

        Returns
        -------
        None
        """
        # TODO: make these delayed objects and properly pass into (delayed?) numba functions:
        # https://docs.dask.org/en/stable/delayed-best-practices.html#avoid-repeatedly-putting-large-inputs-into-delayed-calls

        maxint = np.iinfo(np.int64).max
        self.misc["unboundID"] = maxint

        # Group ID
        if "Group" not in self.data:  # can happen for empty catalogs
            for key in self.data:
                if not (key.startswith("PartType")):
                    continue
                uid = self.data[key]["uid"]
                self.data[key]["GroupID"] = self.misc["unboundID"] * da.ones_like(
                    uid, dtype=np.int64
                )
                self.data[key]["SubhaloID"] = self.misc["unboundID"] * da.ones_like(
                    uid, dtype=np.int64
                )
            return

        glen = self.data["Group"]["GroupLenType"]
        ngrp = glen.shape[0]
        da_halocelloffsets = da.concatenate(
            [
                np.zeros((1, 6), dtype=np.int64),
                da.cumsum(glen, axis=0, dtype=np.int64),
            ]
        )
        # remove last entry to match shape
        self.data["Group"]["GroupOffsetsType"] = da_halocelloffsets[:-1].rechunk(
            glen.chunks
        )
        halocelloffsets = da_halocelloffsets.rechunk(-1)

        index_unbound = self.misc["unboundID"]

        for key in self.data:
            if not (key.startswith("PartType")):
                continue
            num = int(key[-1])
            if "uid" not in self.data[key]:
                continue  # can happen for empty containers
            gidx = self.data[key]["uid"]
            hidx = compute_haloindex(
                gidx, halocelloffsets[:, num], index_unbound=index_unbound
            )
            self.data[key]["GroupID"] = hidx

        # Subhalo ID
        if "Subhalo" not in self.data:  # can happen for empty catalogs
            for key in self.data:
                if not (key.startswith("PartType")):
                    continue
                self.data[key]["SubhaloID"] = -1 * da.ones_like(
                    da[key]["uid"], dtype=np.int64
                )
            return

        shnr_attr = "SubhaloGrNr"
        if shnr_attr not in self.data["Subhalo"]:
            shnr_attr = "SubhaloGroupNr"  # what MTNG does
        if shnr_attr not in self.data["Subhalo"]:
            raise ValueError(
                f"Could not find 'SubhaloGrNr' or 'SubhaloGroupNr' in {self.catalog}"
            )

        subhalogrnr = self.data["Subhalo"][shnr_attr]
        subhalocellcounts = self.data["Subhalo"]["SubhaloLenType"]

        # remove "units" for numba funcs
        if hasattr(subhalogrnr, "magnitude"):
            subhalogrnr = subhalogrnr.magnitude
        if hasattr(subhalocellcounts, "magnitude"):
            subhalocellcounts = subhalocellcounts.magnitude

        grp = self.data["Group"]
        if "GroupFirstSub" not in grp or "GroupNsubs" not in grp:
            # if not provided, we calculate:
            # "GroupFirstSub": First subhalo index for each halo
            # "GroupNsubs": Number of subhalos for each halo
            dlyd = delayed(get_shcounts_shcells)(subhalogrnr, ngrp)
            grp["GroupFirstSub"] = dask.compute(dlyd[1])[0]
            grp["GroupNsubs"] = dask.compute(dlyd[0])[0]

        # remove "units" for numba funcs
        grpfirstsub = grp["GroupFirstSub"]
        if hasattr(grpfirstsub, "magnitude"):
            grpfirstsub = grpfirstsub.magnitude
        grpnsubs = grp["GroupNsubs"]
        if hasattr(grpnsubs, "magnitude"):
            grpnsubs = grpnsubs.magnitude

        for key in self.data:
            if not (key.startswith("PartType")):
                continue
            num = int(key[-1])
            pdata = self.data[key]
            if "uid" not in self.data[key]:
                continue  # can happen for empty containers
            gidx = pdata["uid"]

            # we need to make the other dask arrays delayed,
            # so that map_blocks does not incorrectly infer the output shape from them
            halocelloffsets_dlyd = delayed(halocelloffsets[:, num])
            grpfirstsub_dlyd = delayed(grpfirstsub)
            grpnsubs_dlyd = delayed(grpnsubs)
            subhalocellcounts_dlyd = delayed(subhalocellcounts[:, num])

            sidx = compute_localsubhaloindex(
                gidx,
                halocelloffsets_dlyd,
                grpfirstsub_dlyd,
                grpnsubs_dlyd,
                subhalocellcounts_dlyd,
                index_unbound=index_unbound,
            )

            pdata["LocalSubhaloID"] = sidx

            # reconstruct SubhaloID from Group's GroupFirstSub and LocalSubhaloID
            # should be easier to do it directly, but quicker to write down like this:

            # calculate first subhalo of each halo that a particle belongs to
            self.add_groupquantity_to_particles("GroupFirstSub", parttype=key)
            pdata["SubhaloID"] = pdata["GroupFirstSub"] + pdata["LocalSubhaloID"]
            pdata["SubhaloID"] = da.where(
                pdata["SubhaloID"] == index_unbound, index_unbound, pdata["SubhaloID"]
            )

    @computedecorator
    def map_group_operation(
        self,
        func,
        cpucost_halo=1e4,
        nchunks_min=None,
        chunksize_bytes=None,
        nmax=None,
        idxlist=None,
        objtype="halo",
    ):
        """
        Apply a function to each halo in the catalog.

        Parameters
        ----------
        objtype: str
            Type of object to process. Can be "halo" or "subhalo". Default: "halo"
        idxlist: Optional[np.ndarray]
            List of halo indices to process. If not provided, all halos are processed.
        func: function
            Function to apply to each halo. Must take a dictionary of arrays as input.
        cpucost_halo:
            "CPU cost" of processing a single halo. This is a relative value to the processing time per input particle
            used for calculating the dask chunks. Default: 1e4
        nchunks_min: Optional[int]
            Minimum number of chunks to split the operation into. Default: None
        chunksize_bytes: Optional[int]
        nmax: Optional[int]
            Only process the first nmax halos.

        Returns
        -------
        None
        """
        dfltkwargs = get_kwargs(func)
        fieldnames = dfltkwargs.get("fieldnames", None)
        if fieldnames is None:
            fieldnames = get_args(func)
        parttype = dfltkwargs.get("parttype", "PartType0")
        entry_nbytes_in = np.sum([self.data[parttype][f][0].nbytes for f in fieldnames])
        objtype = grp_type_str(objtype)
        if objtype == "halo":
            lengths = self.get_grouplengths(parttype=parttype)
            offsets = self.get_groupoffsets(parttype=parttype)
        elif objtype == "subhalo":
            lengths = self.get_subhalolengths(parttype=parttype)
            offsets = self.get_subhalooffsets(parttype=parttype)
        else:
            raise ValueError(f"objtype must be 'halo' or 'subhalo', not {objtype}")
        arrdict = self.data[parttype]
        return map_group_operation(
            func,
            offsets,
            lengths,
            arrdict,
            cpucost_halo=cpucost_halo,
            nchunks_min=nchunks_min,
            chunksize_bytes=chunksize_bytes,
            entry_nbytes_in=entry_nbytes_in,
            nmax=nmax,
            idxlist=idxlist,
        )

    def add_groupquantity_to_particles(self, name, parttype="PartType0"):
        """
        Map a quantity from the group catalog to the particles based on a particle's group index.

        Parameters
        ----------
        name: str
            Name of quantity to map
        parttype: str
            Name of particle type

        Returns
        -------
        None
        """
        pdata = self.data[parttype]
        assert (
            name not in pdata
        )  # we simply map the name from Group to Particle for now. Should work (?)
        glen = self.data["Group"]["GroupLenType"]
        da_halocelloffsets = da.concatenate(
            [np.zeros((1, 6), dtype=np.int64), da.cumsum(glen, axis=0)]
        )
        if "GroupOffsetsType" not in self.data["Group"]:
            self.data["Group"]["GroupOffsetsType"] = da_halocelloffsets[:-1].rechunk(
                glen.chunks
            )  # remove last entry to match shape
        halocelloffsets = da_halocelloffsets.compute()

        gidx = pdata["uid"]
        num = int(parttype[-1])
        hquantity = compute_haloquantity(
            gidx, halocelloffsets[:, num], self.data["Group"][name]
        )
        pdata[name] = hquantity

    def get_grouplengths(self, parttype="PartType0"):
        """
        Get the lengths, i.e. the total number of particles, of a given type in all halos.

        Parameters
        ----------
        parttype: str
            Name of particle type

        Returns
        -------
        np.ndarray
        """
        pnum = part_type_num(parttype)
        ptype = "PartType%i" % pnum
        if ptype not in self._grouplengths:
            lengths = self.data["Group"]["GroupLenType"][:, pnum].compute()
            if isinstance(lengths, pint.Quantity):
                lengths = lengths.magnitude
            self._grouplengths[ptype] = lengths
        return self._grouplengths[ptype]

    def get_groupoffsets(self, parttype="PartType0"):
        """
        Get the array index offset of the first particle of a given type in each halo.

        Parameters
        ----------
        parttype: str
            Name of particle type

        Returns
        -------
        np.ndarray
        """
        if parttype not in self._grouplengths:
            # need to calculate group lengths first
            self.get_grouplengths(parttype=parttype)
        return self._groupoffsets[parttype]

    @property
    def _groupoffsets(self):
        lengths = self._grouplengths
        offsets = {
            k: np.concatenate([[0], np.cumsum(v)[:-1]]) for k, v in lengths.items()
        }
        return offsets

    def get_subhalolengths(self, parttype="PartType0"):
        """
        Get the lengths, i.e. the total number of particles, of a given type in all subhalos.

        Parameters
        ----------
        parttype: str

        Returns
        -------
        np.ndarray
        """
        pnum = part_type_num(parttype)
        ptype = "PartType%i" % pnum
        if ptype in self._subhalolengths:
            return self._subhalolengths[ptype]
        lengths = self.data["Subhalo"]["SubhaloLenType"][:, pnum].compute()
        if isinstance(lengths, pint.Quantity):
            lengths = lengths.magnitude
        self._subhalolengths[ptype] = lengths
        return self._subhalolengths[ptype]

    def get_subhalooffsets(self, parttype="PartType0"):
        """
        Get the array index offset of the first particle of a given type in each subhalo.

        Parameters
        ----------
        parttype: str

        Returns
        -------
        np.ndarray
        """

        pnum = part_type_num(parttype)
        ptype = "PartType%i" % pnum
        if ptype in self._subhalooffsets:
            return self._subhalooffsets[ptype]  # use cached result
        goffsets = self.get_groupoffsets(ptype)
        shgrnr = self.data["Subhalo"]["SubhaloGrNr"]
        # calculate the index of the first particle for the central subhalo of each subhalo's parent halo
        shoffset_central = goffsets[shgrnr]

        grpfirstsub = self.data["Group"]["GroupFirstSub"]
        shlens = self.get_subhalolengths(ptype)
        shoffsets = np.concatenate([[0], np.cumsum(shlens)[:-1]])

        # particle offset for the first subhalo of each group that a subhalo belongs to
        shfirstshoffset = shoffsets[grpfirstsub[shgrnr]]

        # "LocalSubhaloOffset": particle offset of each subhalo in the parent group
        shoffset_local = shoffsets - shfirstshoffset

        # "SubhaloOffset": particle offset of each subhalo in the simulation
        offsets = shoffset_central + shoffset_local

        self._subhalooffsets[ptype] = offsets

        return offsets

    def grouped(
        self,
        fields: Union[str, da.Array, List[str], Dict[str, da.Array]] = "",
        parttype="PartType0",
        objtype="halo",
    ):
        """
        Create a GroupAwareOperation object for applying operations to groups.

        Parameters
        ----------
        fields: Union[str, da.Array, List[str], Dict[str, da.Array]]
            Fields to pass to the operation. Can be a string, a dask array, a list of strings or a dictionary of dask arrays.
        parttype: str
            Particle type to operate on.
        objtype: str
            Type of object to operate on. Can be "halo" or "subhalo". Default: "halo"

        Returns
        -------
        GroupAwareOperation
        """
        inputfields = None
        if isinstance(fields, str):
            if fields == "":  # if nothing is specified, we pass all we have.
                arrdict = self.data[parttype]
            else:
                arrdict = dict(field=self.data[parttype][fields])
                inputfields = [fields]
        elif isinstance(fields, da.Array) or isinstance(fields, pint.Quantity):
            arrdict = dict(daskarr=fields)
            inputfields = [fields.name]
        elif isinstance(fields, list):
            arrdict = {k: self.data[parttype][k] for k in fields}
            inputfields = fields
        elif isinstance(fields, dict):
            arrdict = {}
            arrdict.update(**fields)
            inputfields = list(arrdict.keys())
        else:
            raise ValueError("Unknown input type '%s'." % type(fields))
        objtype = grp_type_str(objtype)
        if objtype == "halo":
            offsets = self.get_groupoffsets(parttype=parttype)
            lengths = self.get_grouplengths(parttype=parttype)
        elif objtype == "subhalo":
            offsets = self.get_subhalooffsets(parttype=parttype)
            lengths = self.get_subhalolengths(parttype=parttype)
        else:
            raise ValueError("Unknown object type '%s'." % objtype)

        gop = GroupAwareOperation(
            offsets,
            lengths,
            arrdict,
            inputfields=inputfields,
        )
        return gop
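
The aliases registered in __init__ and the fields added by add_catalogIDs keep typical access patterns short. A sketch, assuming a hypothetical TNG-like snapshot with a discoverable group catalog:

from scida import load

ds = load("output/snapdir_099")           # hypothetical path

# "gas" is an alias for "PartType0" (see the aliases dict in __init__)
masses = ds.data["gas"]["Masses"]
masses0 = ds.data["PartType0"]["Masses"]  # the same field

# host halo/subhalo of every gas cell, added by add_catalogIDs()
groupid = ds.data["gas"]["GroupID"]
subhaloid = ds.data["gas"]["SubhaloID"]
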
__init__(path, chunksize='auto', catalog=None, **kwargs)

Initialize an ArepoSnapshot object.

Parameters:

Name Type Description Default
path

Path to snapshot, typically a directory containing multiple hdf5 files.

required
chunksize

Chunksize to use for dask arrays. Can be "auto" to automatically determine chunksize.

'auto'
catalog

Path to group catalog. If None, the catalog is searched for in the parent directories.

None
kwargs

Additional keyword arguments.

{}
Source code in src/scida/customs/arepo/dataset.py
def __init__(self, path, chunksize="auto", catalog=None, **kwargs) -> None:
    """
    Initialize an ArepoSnapshot object.

    Parameters
    ----------
    path: str or pathlib.Path
        Path to snapshot, typically a directory containing multiple hdf5 files.
    chunksize:
        Chunksize to use for dask arrays. Can be "auto" to automatically determine chunksize.
    catalog:
        Path to group catalog. If None, the catalog is searched for in the parent directories.
    kwargs:
        Additional keyword arguments.
    """
    self.iscatalog = kwargs.pop("iscatalog", False)
    self.header = {}
    self.config = {}
    self._defaultunitfiles: List[str] = ["units/gadget_cosmological.yaml"]
    self.parameters = {}
    self._grouplengths = {}
    self._subhalolengths = {}
    # not needed for group catalogs as entries are back-to-back there, we will provide a property for this
    self._subhalooffsets = {}
    self.misc = {}  # for storing misc info
    prfx = kwargs.pop("fileprefix", None)
    if prfx is None:
        prfx = self._get_fileprefix(path)
    super().__init__(path, chunksize=chunksize, fileprefix=prfx, **kwargs)

    self.catalog = catalog
    if not self.iscatalog:
        if self.catalog is None:
            self.discover_catalog()
            # try to discover group catalog in parent directories.
        if self.catalog == "none":
            pass  # this string can be set to explicitly disable catalog
        elif self.catalog:
            catalog_cls = kwargs.get("catalog_cls", None)
            cosmological = False
            if hasattr(self, "_mixins") and "cosmology" in self._mixins:
                cosmological = True
            self.load_catalog(
                overwrite_cache=kwargs.get("overwrite_cache", False),
                catalog_cls=catalog_cls,
                units=self.withunits,
                cosmological=cosmological,
            )

    # add aliases
    aliases = dict(
        PartType0=["gas", "baryons"],
        PartType1=["dm", "dark matter"],
        PartType2=["lowres", "lowres dm"],
        PartType3=["tracer", "tracers"],
        PartType4=["stars"],
        PartType5=["bh", "black holes"],
    )
    for k, lst in aliases.items():
        if k not in self.data:
            continue
        for v in lst:
            self.data.add_alias(v, k)

    # set metadata
    self._set_metadata()

    # add some default fields
    self.data.merge(fielddefs)
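
The catalog keyword controls how the group catalog is attached: None triggers discover_catalog(), a path is used directly, and the string "none" disables the catalog entirely. A sketch with hypothetical paths:

from scida.customs.arepo.dataset import ArepoSnapshot

# explicit catalog path instead of automatic discovery
ds = ArepoSnapshot("output/snapdir_099", catalog="output/groups_099")

# skip loading a group catalog altogether
ds_nocat = ArepoSnapshot("output/snapdir_099", catalog="none")
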
add_catalogIDs()

Add field for halo and subgroup IDs for all particle types.

Returns:

Type Description
None
Source code in src/scida/customs/arepo/dataset.py
def add_catalogIDs(self) -> None:
    """
    Add field for halo and subgroup IDs for all particle types.

    Returns
    -------
    None
    """
    # TODO: make these delayed objects and properly pass into (delayed?) numba functions:
    # https://docs.dask.org/en/stable/delayed-best-practices.html#avoid-repeatedly-putting-large-inputs-into-delayed-calls

    maxint = np.iinfo(np.int64).max
    self.misc["unboundID"] = maxint

    # Group ID
    if "Group" not in self.data:  # can happen for empty catalogs
        for key in self.data:
            if not (key.startswith("PartType")):
                continue
            uid = self.data[key]["uid"]
            self.data[key]["GroupID"] = self.misc["unboundID"] * da.ones_like(
                uid, dtype=np.int64
            )
            self.data[key]["SubhaloID"] = self.misc["unboundID"] * da.ones_like(
                uid, dtype=np.int64
            )
        return

    glen = self.data["Group"]["GroupLenType"]
    ngrp = glen.shape[0]
    da_halocelloffsets = da.concatenate(
        [
            np.zeros((1, 6), dtype=np.int64),
            da.cumsum(glen, axis=0, dtype=np.int64),
        ]
    )
    # remove last entry to match shape
    self.data["Group"]["GroupOffsetsType"] = da_halocelloffsets[:-1].rechunk(
        glen.chunks
    )
    halocelloffsets = da_halocelloffsets.rechunk(-1)

    index_unbound = self.misc["unboundID"]

    for key in self.data:
        if not (key.startswith("PartType")):
            continue
        num = int(key[-1])
        if "uid" not in self.data[key]:
            continue  # can happen for empty containers
        gidx = self.data[key]["uid"]
        hidx = compute_haloindex(
            gidx, halocelloffsets[:, num], index_unbound=index_unbound
        )
        self.data[key]["GroupID"] = hidx

    # Subhalo ID
    if "Subhalo" not in self.data:  # can happen for empty catalogs
        for key in self.data:
            if not (key.startswith("PartType")):
                continue
            self.data[key]["SubhaloID"] = -1 * da.ones_like(
                da[key]["uid"], dtype=np.int64
            )
        return

    shnr_attr = "SubhaloGrNr"
    if shnr_attr not in self.data["Subhalo"]:
        shnr_attr = "SubhaloGroupNr"  # what MTNG does
    if shnr_attr not in self.data["Subhalo"]:
        raise ValueError(
            f"Could not find 'SubhaloGrNr' or 'SubhaloGroupNr' in {self.catalog}"
        )

    subhalogrnr = self.data["Subhalo"][shnr_attr]
    subhalocellcounts = self.data["Subhalo"]["SubhaloLenType"]

    # remove "units" for numba funcs
    if hasattr(subhalogrnr, "magnitude"):
        subhalogrnr = subhalogrnr.magnitude
    if hasattr(subhalocellcounts, "magnitude"):
        subhalocellcounts = subhalocellcounts.magnitude

    grp = self.data["Group"]
    if "GroupFirstSub" not in grp or "GroupNsubs" not in grp:
        # if not provided, we calculate:
        # "GroupFirstSub": First subhalo index for each halo
        # "GroupNsubs": Number of subhalos for each halo
        dlyd = delayed(get_shcounts_shcells)(subhalogrnr, ngrp)
        grp["GroupFirstSub"] = dask.compute(dlyd[1])[0]
        grp["GroupNsubs"] = dask.compute(dlyd[0])[0]

    # remove "units" for numba funcs
    grpfirstsub = grp["GroupFirstSub"]
    if hasattr(grpfirstsub, "magnitude"):
        grpfirstsub = grpfirstsub.magnitude
    grpnsubs = grp["GroupNsubs"]
    if hasattr(grpnsubs, "magnitude"):
        grpnsubs = grpnsubs.magnitude

    for key in self.data:
        if not (key.startswith("PartType")):
            continue
        num = int(key[-1])
        pdata = self.data[key]
        if "uid" not in self.data[key]:
            continue  # can happen for empty containers
        gidx = pdata["uid"]

        # we need to make the other dask arrays delayed,
        # so that map_blocks does not incorrectly infer the output shape from them
        halocelloffsets_dlyd = delayed(halocelloffsets[:, num])
        grpfirstsub_dlyd = delayed(grpfirstsub)
        grpnsubs_dlyd = delayed(grpnsubs)
        subhalocellcounts_dlyd = delayed(subhalocellcounts[:, num])

        sidx = compute_localsubhaloindex(
            gidx,
            halocelloffsets_dlyd,
            grpfirstsub_dlyd,
            grpnsubs_dlyd,
            subhalocellcounts_dlyd,
            index_unbound=index_unbound,
        )

        pdata["LocalSubhaloID"] = sidx

        # reconstruct SubhaloID from Group's GroupFirstSub and LocalSubhaloID
        # should be easier to do it directly, but quicker to write down like this:

        # calculate first subhalo of each halo that a particle belongs to
        self.add_groupquantity_to_particles("GroupFirstSub", parttype=key)
        pdata["SubhaloID"] = pdata["GroupFirstSub"] + pdata["LocalSubhaloID"]
        pdata["SubhaloID"] = da.where(
            pdata["SubhaloID"] == index_unbound, index_unbound, pdata["SubhaloID"]
        )
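
With GroupID and SubhaloID attached to every particle container, membership selections reduce to dask masks. A sketch with a hypothetical path; units=False is assumed here to keep the reduction free of unit handling:

from scida import load

ds = load("output/snapdir_099", units=False)   # hypothetical path
gas = ds.data["PartType0"]

mask = gas["GroupID"] == 42                    # particles bound to halo 42
halo42_gasmass = gas["Masses"][mask].sum().compute()
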
add_groupquantity_to_particles(name, parttype='PartType0')

Map a quantity from the group catalog to the particles based on a particle's group index.

Parameters:

Name Type Description Default
name

Name of quantity to map

required
parttype

Name of particle type

'PartType0'

Returns:

Type Description
None
Source code in src/scida/customs/arepo/dataset.py
def add_groupquantity_to_particles(self, name, parttype="PartType0"):
    """
    Map a quantity from the group catalog to the particles based on a particle's group index.

    Parameters
    ----------
    name: str
        Name of quantity to map
    parttype: str
        Name of particle type

    Returns
    -------
    None
    """
    pdata = self.data[parttype]
    assert (
        name not in pdata
    )  # we simply map the name from Group to Particle for now. Should work (?)
    glen = self.data["Group"]["GroupLenType"]
    da_halocelloffsets = da.concatenate(
        [np.zeros((1, 6), dtype=np.int64), da.cumsum(glen, axis=0)]
    )
    if "GroupOffsetsType" not in self.data["Group"]:
        self.data["Group"]["GroupOffsetsType"] = da_halocelloffsets[:-1].rechunk(
            glen.chunks
        )  # remove last entry to match shape
    halocelloffsets = da_halocelloffsets.compute()

    gidx = pdata["uid"]
    num = int(parttype[-1])
    hquantity = compute_haloquantity(
        gidx, halocelloffsets[:, num], self.data["Group"][name]
    )
    pdata[name] = hquantity
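
This makes it straightforward to compare particle properties against those of the host halo. A sketch with a hypothetical path; "GroupMass" is an assumed catalog field name:

from scida import load

ds = load("output/snapdir_099")   # hypothetical path

# broadcast the host halo's mass onto every gas cell
ds.add_groupquantity_to_particles("GroupMass", parttype="PartType0")
gm_per_particle = ds.data["PartType0"]["GroupMass"]
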
discover_catalog()

Discover the group catalog given the current path

Returns:

Type Description
None
Source code in src/scida/customs/arepo/dataset.py
def discover_catalog(self):
    """
    Discover the group catalog given the current path

    Returns
    -------
    None
    """
    p = str(self.path)
    # order of candidates matters. For Illustris "groups" must precede "fof_subhalo_tab"
    candidates = [
        p.replace("snapshot", "group"),
        p.replace("snapshot", "groups"),
        p.replace("snapdir", "groups").replace("snap", "groups"),
        p.replace("snapdir", "groups").replace("snap", "fof_subhalo_tab"),
    ]
    for candidate in candidates:
        if not os.path.exists(candidate):
            continue
        if candidate == self.path:
            continue
        self.catalog = candidate
        break
get_grouplengths(parttype='PartType0')

Get the lengths, i.e. the total number of particles, of a given type in all halos.

Parameters:

Name Type Description Default
parttype

Name of particle type

'PartType0'

Returns:

Type Description
ndarray
Source code in src/scida/customs/arepo/dataset.py
def get_grouplengths(self, parttype="PartType0"):
    """
    Get the lengths, i.e. the total number of particles, of a given type in all halos.

    Parameters
    ----------
    parttype: str
        Name of particle type

    Returns
    -------
    np.ndarray
    """
    pnum = part_type_num(parttype)
    ptype = "PartType%i" % pnum
    if ptype not in self._grouplengths:
        lengths = self.data["Group"]["GroupLenType"][:, pnum].compute()
        if isinstance(lengths, pint.Quantity):
            lengths = lengths.magnitude
        self._grouplengths[ptype] = lengths
    return self._grouplengths[ptype]
get_groupoffsets(parttype='PartType0')

Get the array index offset of the first particle of a given type in each halo.

Parameters:

Name Type Description Default
parttype

Name of particle type

'PartType0'

Returns:

Type Description
ndarray
Source code in src/scida/customs/arepo/dataset.py
def get_groupoffsets(self, parttype="PartType0"):
    """
    Get the array index offset of the first particle of a given type in each halo.

    Parameters
    ----------
    parttype: str
        Name of particle type

    Returns
    -------
    np.ndarray
    """
    if parttype not in self._grouplengths:
        # need to calculate group lengths first
        self.get_grouplengths(parttype=parttype)
    return self._groupoffsets[parttype]
get_subhalolengths(parttype='PartType0')

Get the lengths, i.e. the total number of particles, of a given type in all subhalos.

Parameters:

Name Type Description Default
parttype
'PartType0'

Returns:

Type Description
ndarray
Source code in src/scida/customs/arepo/dataset.py
def get_subhalolengths(self, parttype="PartType0"):
    """
    Get the lengths, i.e. the total number of particles, of a given type in all subhalos.

    Parameters
    ----------
    parttype: str

    Returns
    -------
    np.ndarray
    """
    pnum = part_type_num(parttype)
    ptype = "PartType%i" % pnum
    if ptype in self._subhalolengths:
        return self._subhalolengths[ptype]
    lengths = self.data["Subhalo"]["SubhaloLenType"][:, pnum].compute()
    if isinstance(lengths, pint.Quantity):
        lengths = lengths.magnitude
    self._subhalolengths[ptype] = lengths
    return self._subhalolengths[ptype]
get_subhalooffsets(parttype='PartType0')

Get the array index offset of the first particle of a given type in each subhalo.

Parameters:

Name Type Description Default
parttype
'PartType0'

Returns:

Type Description
ndarray
Source code in src/scida/customs/arepo/dataset.py
def get_subhalooffsets(self, parttype="PartType0"):
    """
    Get the array index offset of the first particle of a given type in each subhalo.

    Parameters
    ----------
    parttype: str

    Returns
    -------
    np.ndarray
    """

    pnum = part_type_num(parttype)
    ptype = "PartType%i" % pnum
    if ptype in self._subhalooffsets:
        return self._subhalooffsets[ptype]  # use cached result
    goffsets = self.get_groupoffsets(ptype)
    shgrnr = self.data["Subhalo"]["SubhaloGrNr"]
    # calculate the index of the first particle for the central subhalo of each subhalo's parent halo
    shoffset_central = goffsets[shgrnr]

    grpfirstsub = self.data["Group"]["GroupFirstSub"]
    shlens = self.get_subhalolengths(ptype)
    shoffsets = np.concatenate([[0], np.cumsum(shlens)[:-1]])

    # particle offset for the first subhalo of each group that a subhalo belongs to
    shfirstshoffset = shoffsets[grpfirstsub[shgrnr]]

    # "LocalSubhaloOffset": particle offset of each subhalo in the parent group
    shoffset_local = shoffsets - shfirstshoffset

    # "SubhaloOffset": particle offset of each subhalo in the simulation
    offsets = shoffset_central + shoffset_local

    self._subhalooffsets[ptype] = offsets

    return offsets
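
The offset arithmetic is easiest to follow with toy numbers. The sketch below mirrors the steps above for two invented halos (10 and 6 particles of the requested type) hosting three subhalos:

import numpy as np

goffsets = np.array([0, 10])     # get_groupoffsets(): first particle of each halo
shlens = np.array([5, 3, 4])     # SubhaloLenType column
shgrnr = np.array([0, 0, 1])     # SubhaloGrNr: parent halo of each subhalo
grpfirstsub = np.array([0, 2])   # GroupFirstSub: first subhalo of each halo

shoffsets = np.concatenate([[0], np.cumsum(shlens)[:-1]])   # [0, 5, 8]
shoffset_central = goffsets[shgrnr]                         # [0, 0, 10]
shfirstshoffset = shoffsets[grpfirstsub[shgrnr]]            # [0, 0, 8]
offsets = shoffset_central + (shoffsets - shfirstshoffset)
print(offsets)                                              # [ 0  5 10]
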
grouped(fields='', parttype='PartType0', objtype='halo')

Create a GroupAwareOperation object for applying operations to groups.

Parameters:

Name Type Description Default
fields Union[str, Array, List[str], Dict[str, Array]]

Fields to pass to the operation. Can be a string, a dask array, a list of strings or a dictionary of dask arrays.

''
parttype

Particle type to operate on.

'PartType0'
objtype

Type of object to operate on. Can be "halo" or "subhalo". Default: "halo"

'halo'

Returns:

Type Description
GroupAwareOperation
Source code in src/scida/customs/arepo/dataset.py
def grouped(
    self,
    fields: Union[str, da.Array, List[str], Dict[str, da.Array]] = "",
    parttype="PartType0",
    objtype="halo",
):
    """
    Create a GroupAwareOperation object for applying operations to groups.

    Parameters
    ----------
    fields: Union[str, da.Array, List[str], Dict[str, da.Array]]
        Fields to pass to the operation. Can be a string, a dask array, a list of strings or a dictionary of dask arrays.
    parttype: str
        Particle type to operate on.
    objtype: str
        Type of object to operate on. Can be "halo" or "subhalo". Default: "halo"

    Returns
    -------
    GroupAwareOperation
    """
    inputfields = None
    if isinstance(fields, str):
        if fields == "":  # if nothing is specified, we pass all we have.
            arrdict = self.data[parttype]
        else:
            arrdict = dict(field=self.data[parttype][fields])
            inputfields = [fields]
    elif isinstance(fields, da.Array) or isinstance(fields, pint.Quantity):
        arrdict = dict(daskarr=fields)
        inputfields = [fields.name]
    elif isinstance(fields, list):
        arrdict = {k: self.data[parttype][k] for k in fields}
        inputfields = fields
    elif isinstance(fields, dict):
        arrdict = {}
        arrdict.update(**fields)
        inputfields = list(arrdict.keys())
    else:
        raise ValueError("Unknown input type '%s'." % type(fields))
    objtype = grp_type_str(objtype)
    if objtype == "halo":
        offsets = self.get_groupoffsets(parttype=parttype)
        lengths = self.get_grouplengths(parttype=parttype)
    elif objtype == "subhalo":
        offsets = self.get_subhalooffsets(parttype=parttype)
        lengths = self.get_subhalolengths(parttype=parttype)
    else:
        raise ValueError("Unknown object type '%s'." % objtype)

    gop = GroupAwareOperation(
        offsets,
        lengths,
        arrdict,
        inputfields=inputfields,
    )
    return gop
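
grouped() returns a GroupAwareOperation that chains reductions over halos or subhalos. A sketch with a hypothetical path; sum() and evaluate() are assumed helper methods of GroupAwareOperation:

from scida import load

ds = load("output/snapdir_099")   # hypothetical path

# per-halo total gas mass
halo_gasmass = ds.grouped("Masses", parttype="PartType0").sum().evaluate()

# the same reduction per subhalo
sub_gasmass = ds.grouped("Masses", objtype="subhalo").sum().evaluate()
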
load_catalog(overwrite_cache=False, units=False, cosmological=False, catalog_cls=None)

Load the group catalog.

Parameters:

Name Type Description Default
overwrite_cache

Whether to overwrite an existing cache for the catalog.

False
units

Whether to load the catalog with units.

False
cosmological

Whether to add the cosmology mixin to the catalog class.

False
catalog_cls

Class to use for the catalog. If None, the default catalog class is used.

None

Returns:

Type Description
None
Source code in src/scida/customs/arepo/dataset.py
def load_catalog(
    self, overwrite_cache=False, units=False, cosmological=False, catalog_cls=None
):
    """
    Load the group catalog.

    Parameters
    ----------
    overwrite_cache: bool
        Whether to overwrite an existing cache for the catalog.
    units: bool
        Whether to load the catalog with units.
    cosmological: bool
        Whether to add the cosmology mixin to the catalog class.
    catalog_cls: type
        Class to use for the catalog. If None, the default catalog class is used.

    Returns
    -------
    None
    """
    virtualcache = False  # copy catalog for better performance
    # fileprefix = catalog_kwargs.get("fileprefix", self._fileprefix_catalog)
    prfx = self._get_fileprefix(self.catalog)

    # explicitly need to create unitaware class for catalog as needed
    # TODO: should just be determined from mixins of parent?
    if catalog_cls is None:
        cls = ArepoCatalog
    else:
        cls = catalog_cls
    withunits = units
    mixins = []
    if withunits:
        mixins += [UnitMixin]

    other_mixins = _determine_mixins(path=self.path)
    mixins += other_mixins
    if cosmological and CosmologyMixin not in mixins:
        mixins.append(CosmologyMixin)

    cls = create_datasetclass_with_mixins(cls, mixins)

    ureg = None
    if hasattr(self, "ureg"):
        ureg = self.ureg

    self.catalog = cls(
        self.catalog,
        overwrite_cache=overwrite_cache,
        virtualcache=virtualcache,
        fileprefix=prfx,
        units=self.withunits,
        ureg=ureg,
    )
    if "Redshift" in self.catalog.header and "Redshift" in self.header:
        z_catalog = self.catalog.header["Redshift"]
        z_snap = self.header["Redshift"]
        if not np.isclose(z_catalog, z_snap):
            raise ValueError(
                "Redshift mismatch between snapshot and catalog: "
                f"{z_snap:.2f} vs {z_catalog:.2f}"
            )

    # merge data
    self.merge_data(self.catalog)

    # first snapshots often do not have groups
    if "Group" in self.catalog.data:
        ngkeys = self.catalog.data["Group"].keys()
        if len(ngkeys) > 0:
            self.add_catalogIDs()

    # merge hints from snap and catalog
    self.merge_hints(self.catalog)
map_group_operation(func, cpucost_halo=10000.0, nchunks_min=None, chunksize_bytes=None, nmax=None, idxlist=None, objtype='halo')

Apply a function to each halo in the catalog.

Parameters:

Name Type Description Default
objtype

Type of object to process. Can be "halo" or "subhalo". Default: "halo"

'halo'
idxlist

List of halo indices to process. If not provided, all halos are processed.

None
func

Function to apply to each halo. Must take a dictionary of arrays as input.

required
cpucost_halo

"CPU cost" of processing a single halo. This is a relative value to the processing time per input particle used for calculating the dask chunks. Default: 1e4

10000.0
nchunks_min

Minimum number of chunks to split the operation into. Default: None

None
chunksize_bytes
None
nmax

Only process the first nmax halos.

None

Returns:

Type Description
None
Source code in src/scida/customs/arepo/dataset.py
@computedecorator
def map_group_operation(
    self,
    func,
    cpucost_halo=1e4,
    nchunks_min=None,
    chunksize_bytes=None,
    nmax=None,
    idxlist=None,
    objtype="halo",
):
    """
    Apply a function to each halo in the catalog.

    Parameters
    ----------
    objtype: str
        Type of object to process. Can be "halo" or "subhalo". Default: "halo"
    idxlist: Optional[np.ndarray]
        List of halo indices to process. If not provided, all halos are processed.
    func: function
        Function to apply to each halo. Must take a dictionary of arrays as input.
    cpucost_halo:
        "CPU cost" of processing a single halo. This is a relative value to the processing time per input particle
        used for calculating the dask chunks. Default: 1e4
    nchunks_min: Optional[int]
        Lower bound on the number of chunks used for processing. Default: None
    chunksize_bytes: Optional[int]
    nmax: Optional[int]
        Only process the first nmax halos.

    Returns
    -------
    None
    """
    dfltkwargs = get_kwargs(func)
    fieldnames = dfltkwargs.get("fieldnames", None)
    if fieldnames is None:
        fieldnames = get_args(func)
    parttype = dfltkwargs.get("parttype", "PartType0")
    entry_nbytes_in = np.sum([self.data[parttype][f][0].nbytes for f in fieldnames])
    objtype = grp_type_str(objtype)
    if objtype == "halo":
        lengths = self.get_grouplengths(parttype=parttype)
        offsets = self.get_groupoffsets(parttype=parttype)
    elif objtype == "subhalo":
        lengths = self.get_subhalolengths(parttype=parttype)
        offsets = self.get_subhalooffsets(parttype=parttype)
    else:
        raise ValueError(f"objtype must be 'halo' or 'subhalo', not {objtype}")
    arrdict = self.data[parttype]
    return map_group_operation(
        func,
        offsets,
        lengths,
        arrdict,
        cpucost_halo=cpucost_halo,
        nchunks_min=nchunks_min,
        chunksize_bytes=chunksize_bytes,
        entry_nbytes_in=entry_nbytes_in,
        nmax=nmax,
        idxlist=idxlist,
    )
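A minimal usage sketch (hedged: `snap` stands for an already loaded ArepoSnapshot, and `totmass` is a hypothetical per-halo reduction; the particle field handed to the function is selected by its argument name, here "Masses" of the default "PartType0"):

import numpy as np

def totmass(Masses):
    # summed mass of the (default) PartType0 particles of each halo
    return np.sum(Masses)

result = snap.map_group_operation(totmass, nmax=100)  # restrict to the first 100 halos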
register_field(parttype, name=None, construct=False)

Register a field.

Parameters:

Name Type Description Default
parttype str

name of particle type

required
name str

name of field

None
construct bool

construct field immediately

False

Returns:

Type Description
None
Source code in src/scida/customs/arepo/dataset.py, lines 252-278
def register_field(self, parttype: str, name: str = None, construct: bool = False):
    """
    Register a field.
    Parameters
    ----------
    parttype: str
        name of particle type
    name: str
        name of field
    construct: bool
        construct field immediately

    Returns
    -------
    None
    """
    num = part_type_num(parttype)
    if construct:  # TODO: introduce (immediate) construct option later
        raise NotImplementedError
    if num == -1:  # TODO: all particle species
        key = "all"
        raise NotImplementedError
    elif isinstance(num, int):
        key = "PartType" + str(num)
    else:
        key = parttype
    return super().register_field(key, name=name)
return_data()

Return data object of this snapshot.

Returns:

Type Description
FieldContainer
Source code in src/scida/customs/arepo/dataset.py, lines 217-226
@ArepoSelector()
def return_data(self) -> FieldContainer:
    """
    Return data object of this snapshot.

    Returns
    -------
    FieldContainer
    """
    return super().return_data()
validate_path(path, *args, **kwargs) classmethod

Validate a path to use for instantiation of this class.

Parameters:

Name Type Description Default
path Union[str, PathLike]
required
args
()
kwargs
{}

Returns:

Type Description
CandidateStatus
Source code in src/scida/customs/arepo/dataset.py, lines 182-215
@classmethod
def validate_path(
    cls, path: Union[str, os.PathLike], *args, **kwargs
) -> CandidateStatus:
    """
    Validate a path to use for instantiation of this class.

    Parameters
    ----------
    path: str or pathlib.Path
    args:
    kwargs:

    Returns
    -------
    CandidateStatus
    """
    valid = super().validate_path(path, *args, **kwargs)
    if valid.value > CandidateStatus.MAYBE.value:
        valid = CandidateStatus.MAYBE
    else:
        return valid
    # Arepo has no dedicated attribute to identify such runs.
    # lets just query a bunch of attributes that are present for arepo runs
    metadata_raw = load_metadata(path, **kwargs)
    matchingattrs = True
    matchingattrs &= "Git_commit" in metadata_raw["/Header"]
    # not existent for any arepo run?
    matchingattrs &= "Compactify_Version" not in metadata_raw["/Header"]

    if matchingattrs:
        valid = CandidateStatus.MAYBE

    return valid
ChainOps

Chain operations together.

Source code in src/scida/customs/arepo/dataset.py, lines 727-761
class ChainOps:
    """
    Chain operations together.
    """

    def __init__(self, *funcs):
        """
        Initialize a ChainOps object.

        Parameters
        ----------
        funcs: List[function]
            Functions to chain together.
        """
        self.funcs = funcs
        self.kwargs = get_kwargs(
            funcs[-1]
        )  # so we can pass info from kwargs to map_halo_operation
        if self.kwargs.get("dtype") is None:
            self.kwargs["dtype"] = float

        def chained_call(*args):
            cf = None
            for i, f in enumerate(funcs):
                # first chain element can be multiple fields. treat separately
                if i == 0:
                    cf = f(*args)
                else:
                    cf = f(cf)
            return cf

        self.call = chained_call

    def __call__(self, *args, **kwargs):
        return self.call(*args, **kwargs)
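A minimal sketch of the chaining behavior (import path taken from the source location above; the plain `total` helper is hypothetical): the functions are applied left to right, and only the first one receives the original arguments, while the keyword defaults of the last one are inspected for chunking hints.

import numpy as np
from scida.customs.arepo.dataset import ChainOps

def total(x, dtype=float):
    # final reduction; its keyword defaults are picked up by ChainOps
    return np.sum(x, dtype=dtype)

op = ChainOps(np.abs, total)            # abs first, then total
print(op(np.array([-1.0, 2.0, -3.0])))  # 6.0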
__init__(*funcs)

Initialize a ChainOps object.

Parameters:

Name Type Description Default
funcs

Functions to chain together.

()
Source code in src/scida/customs/arepo/dataset.py, lines 732-758
def __init__(self, *funcs):
    """
    Initialize a ChainOps object.

    Parameters
    ----------
    funcs: List[function]
        Functions to chain together.
    """
    self.funcs = funcs
    self.kwargs = get_kwargs(
        funcs[-1]
    )  # so we can pass info from kwargs to map_halo_operation
    if self.kwargs.get("dtype") is None:
        self.kwargs["dtype"] = float

    def chained_call(*args):
        cf = None
        for i, f in enumerate(funcs):
            # first chain element can be multiple fields. treat separately
            if i == 0:
                cf = f(*args)
            else:
                cf = f(cf)
        return cf

    self.call = chained_call
GroupAwareOperation

Class for applying operations to groups.

Source code in src/scida/customs/arepo/dataset.py, lines 764-954
class GroupAwareOperation:
    """
    Class for applying operations to groups.
    """

    opfuncs = dict(min=np.min, max=np.max, sum=np.sum, half=lambda x: x[::2])
    finalops = {"min", "max", "sum"}
    __slots__ = (
        "arrs",
        "ops",
        "offsets",
        "lengths",
        "final",
        "inputfields",
        "opfuncs_custom",
    )

    def __init__(
        self,
        offsets: NDArray,
        lengths: NDArray,
        arrs: Dict[str, da.Array],
        ops=None,
        inputfields=None,
    ):
        self.offsets = offsets
        self.lengths = lengths
        self.arrs = arrs
        self.opfuncs_custom = {}
        self.final = False
        self.inputfields = inputfields
        if ops is None:
            self.ops = []
        else:
            self.ops = ops

    def chain(self, add_op=None, final=False):
        """
        Chain another operation to this one.

        Parameters
        ----------
        add_op: str or function
            Operation to add. Can be a string (e.g. "min", "max", "sum") or a function.
        final: bool
            Whether this is the final operation in the chain.

        Returns
        -------
        GroupAwareOperation
        """
        if self.final:
            raise ValueError("Cannot chain any additional operation.")
        c = copy.copy(self)
        c.final = final
        c.opfuncs_custom = self.opfuncs_custom
        if add_op is not None:
            if isinstance(add_op, str):
                c.ops.append(add_op)
            elif callable(add_op):
                name = "custom" + str(len(self.opfuncs_custom) + 1)
                c.opfuncs_custom[name] = add_op
                c.ops.append(name)
            else:
                raise ValueError("Unknown operation of type '%s'" % str(type(add_op)))
        return c

    def min(self, field=None):
        """
        Get the minimum value for each group member.
        """
        if field is not None:
            if self.inputfields is not None:
                raise ValueError("Cannot change input field anymore.")
            self.inputfields = [field]
        return self.chain(add_op="min", final=True)

    def max(self, field=None):
        """
        Get the maximum value for each group member.
        """
        if field is not None:
            if self.inputfields is not None:
                raise ValueError("Cannot change input field anymore.")
            self.inputfields = [field]
        return self.chain(add_op="max", final=True)

    def sum(self, field=None):
        """
        Sum the values for each group member.
        """
        if field is not None:
            if self.inputfields is not None:
                raise ValueError("Cannot change input field anymore.")
            self.inputfields = [field]
        return self.chain(add_op="sum", final=True)

    def half(self):
        """
        Halve the number of particles in each group member (keep every second particle). For testing purposes.

        Returns
        -------
        GroupAwareOperation
        """
        return self.chain(add_op="half", final=False)

    def apply(self, func, final=False):
        """
        Apply a passed function.

        Parameters
        ----------
        func: function
            Function to apply.
        final: bool
            Whether this is the final operation in the chain.

        Returns
        -------
        GroupAwareOperation
        """
        return self.chain(add_op=func, final=final)

    def __copy__(self):
        # overwrite method so that copy holds a new ops list.
        c = type(self)(
            self.offsets,
            self.lengths,
            self.arrs,
            ops=list(self.ops),
            inputfields=self.inputfields,
        )
        return c

    def evaluate(self, nmax=None, idxlist=None, compute=True):
        """
        Evaluate the operation.

        Parameters
        ----------
        nmax: Optional[int]
            Maximum number of halos to process.
        idxlist: Optional[np.ndarray]
            List of halo indices to process. If not provided (and nmax is not set), all halos are processed.
        compute: bool
            Whether to compute the result immediately or return a dask object to compute later.

        Returns
        -------

        """
        # TODO: figure out return type
        # final operations: those that can only be at end of chain
        # intermediate operations: those that can only be prior to end of chain
        funcdict = dict()
        funcdict.update(**self.opfuncs)
        funcdict.update(**self.opfuncs_custom)

        func = ChainOps(*[funcdict[k] for k in self.ops])

        fieldnames = list(self.arrs.keys())
        if self.inputfields is None:
            opname = self.ops[0]
            if opname.startswith("custom"):
                dfltkwargs = get_kwargs(self.opfuncs_custom[opname])
                fieldnames = dfltkwargs.get("fieldnames", None)
                if isinstance(fieldnames, str):
                    fieldnames = [fieldnames]
                if fieldnames is None:
                    raise ValueError(
                        "Either pass fields to grouped(fields=...) "
                        "or specify fieldnames=... in applied func."
                    )
            else:
                raise ValueError(
                    "Specify field to operate on in operation or grouped()."
                )

        res = map_group_operation(
            func,
            self.offsets,
            self.lengths,
            self.arrs,
            fieldnames=fieldnames,
            nmax=nmax,
            idxlist=idxlist,
        )
        if compute:
            res = res.compute()
        return res
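A minimal sketch of the intended chaining pattern (hedged: `snap` is an already loaded snapshot, and the `grouped()` entry point referenced by the error messages above is assumed to return a GroupAwareOperation for the given field):

import numpy as np

# sum the "Masses" field over each halo and compute the result
masses_per_halo = snap.grouped("Masses").sum().evaluate()

# a custom final reduction via apply(); compute=False keeps the dask object
op = snap.grouped("Masses").apply(np.max, final=True)
maxmass_lazy = op.evaluate(compute=False)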
apply(func, final=False)

Apply a passed function.

Parameters:

Name Type Description Default
func

Function to apply.

required
final

Whether this is the final operation in the chain.

False

Returns:

Type Description
GroupAwareOperation
Source code in src/scida/customs/arepo/dataset.py, lines 871-886
def apply(self, func, final=False):
    """
    Apply a passed function.

    Parameters
    ----------
    func: function
        Function to apply.
    final: bool
        Whether this is the final operation in the chain.

    Returns
    -------
    GroupAwareOperation
    """
    return self.chain(add_op=func, final=final)
chain(add_op=None, final=False)

Chain another operation to this one.

Parameters:

Name Type Description Default
add_op

Operation to add. Can be a string (e.g. "min", "max", "sum") or a function.

None
final

Whether this is the final operation in the chain.

False

Returns:

Type Description
GroupAwareOperation
Source code in src/scida/customs/arepo/dataset.py, lines 800-829
def chain(self, add_op=None, final=False):
    """
    Chain another operation to this one.

    Parameters
    ----------
    add_op: str or function
        Operation to add. Can be a string (e.g. "min", "max", "sum") or a function.
    final: bool
        Whether this is the final operation in the chain.

    Returns
    -------
    GroupAwareOperation
    """
    if self.final:
        raise ValueError("Cannot chain any additional operation.")
    c = copy.copy(self)
    c.final = final
    c.opfuncs_custom = self.opfuncs_custom
    if add_op is not None:
        if isinstance(add_op, str):
            c.ops.append(add_op)
        elif callable(add_op):
            name = "custom" + str(len(self.opfuncs_custom) + 1)
            c.opfuncs_custom[name] = add_op
            c.ops.append(name)
        else:
            raise ValueError("Unknown operation of type '%s'" % str(type(add_op)))
    return c
evaluate(nmax=None, idxlist=None, compute=True)

Evaluate the operation.

Parameters:

Name Type Description Default
nmax

Maximum number of halos to process.

None
idxlist

List of halo indices to process. If not provided (and nmax is not set), all halos are processed.

None
compute

Whether to compute the result immediately or return a dask object to compute later.

True
Source code in src/scida/customs/arepo/dataset.py, lines 899-954
def evaluate(self, nmax=None, idxlist=None, compute=True):
    """
    Evaluate the operation.

    Parameters
    ----------
    nmax: Optional[int]
        Maximum number of halos to process.
    idxlist: Optional[np.ndarray]
        List of halo indices to process. If not provided (and nmax is not set), all halos are processed.
    compute: bool
        Whether to compute the result immediately or return a dask object to compute later.

    Returns
    -------

    """
    # TODO: figure out return type
    # final operations: those that can only be at end of chain
    # intermediate operations: those that can only be prior to end of chain
    funcdict = dict()
    funcdict.update(**self.opfuncs)
    funcdict.update(**self.opfuncs_custom)

    func = ChainOps(*[funcdict[k] for k in self.ops])

    fieldnames = list(self.arrs.keys())
    if self.inputfields is None:
        opname = self.ops[0]
        if opname.startswith("custom"):
            dfltkwargs = get_kwargs(self.opfuncs_custom[opname])
            fieldnames = dfltkwargs.get("fieldnames", None)
            if isinstance(fieldnames, str):
                fieldnames = [fieldnames]
            if fieldnames is None:
                raise ValueError(
                    "Either pass fields to grouped(fields=...) "
                    "or specify fieldnames=... in applied func."
                )
        else:
            raise ValueError(
                "Specify field to operate on in operation or grouped()."
            )

    res = map_group_operation(
        func,
        self.offsets,
        self.lengths,
        self.arrs,
        fieldnames=fieldnames,
        nmax=nmax,
        idxlist=idxlist,
    )
    if compute:
        res = res.compute()
    return res
half()

Halve the number of particles in each group member (keep every second particle). For testing purposes.

Returns:

Type Description
GroupAwareOperation
Source code in src/scida/customs/arepo/dataset.py, lines 861-869
def half(self):
    """
    Halve the number of particles in each group member (keep every second particle). For testing purposes.

    Returns
    -------
    GroupAwareOperation
    """
    return self.chain(add_op="half", final=False)
max(field=None)

Get the maximum value for each group member.

Source code in src/scida/customs/arepo/dataset.py, lines 841-849
def max(self, field=None):
    """
    Get the maximum value for each group member.
    """
    if field is not None:
        if self.inputfields is not None:
            raise ValueError("Cannot change input field anymore.")
        self.inputfields = [field]
    return self.chain(add_op="max", final=True)
min(field=None)

Get the minimum value for each group member.

Source code in src/scida/customs/arepo/dataset.py, lines 831-839
def min(self, field=None):
    """
    Get the minimum value for each group member.
    """
    if field is not None:
        if self.inputfields is not None:
            raise ValueError("Cannot change input field anymore.")
        self.inputfields = [field]
    return self.chain(add_op="min", final=True)
sum(field=None)

Sum the values for each group member.

Source code in src/scida/customs/arepo/dataset.py, lines 851-859
def sum(self, field=None):
    """
    Sum the values for each group member.
    """
    if field is not None:
        if self.inputfields is not None:
            raise ValueError("Cannot change input field anymore.")
        self.inputfields = [field]
    return self.chain(add_op="sum", final=True)
Temperature(arrs, ureg=None, **kwargs)

Compute gas temperature given (ElectronAbundance,InternalEnergy) in [K].

Source code in src/scida/customs/arepo/dataset.py, lines 1612-1636
@fielddefs.register_field("PartType0")
def Temperature(arrs, ureg=None, **kwargs):
    """Compute gas temperature given (ElectronAbundance,InternalEnergy) in [K]."""
    xh = 0.76
    gamma = 5.0 / 3.0

    m_p = 1.672622e-24  # proton mass [g]
    k_B = 1.380650e-16  # boltzmann constant [erg/K]

    UnitEnergy_over_UnitMass = (
        1e10  # standard unit system (TODO: can obtain from snapshot)
    )
    f = UnitEnergy_over_UnitMass
    if ureg is not None:
        f = 1.0
        m_p = m_p * ureg.g
        k_B = k_B * ureg.erg / ureg.K

    xe = arrs["ElectronAbundance"]
    u_internal = arrs["InternalEnergy"]

    mu = 4 / (1 + 3 * xh + 4 * xh * xe) * m_p
    temp = f * (gamma - 1.0) * u_internal / k_B * mu

    return temp
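Written out, the relation implemented above is (with x_H = 0.76, gamma = 5/3, and the internal energy converted from code units by the factor 1e10 when no unit registry is attached):

    mu = 4 * m_p / (1 + 3 * x_H + 4 * x_H * x_e)
    T  = (gamma - 1) * u_internal * mu / k_B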
compute_haloindex(gidx, halocelloffsets, *args, index_unbound=None)

Computes the halo index for each particle with dask.

Source code in src/scida/customs/arepo/dataset.py, lines 1061-1069
def compute_haloindex(gidx, halocelloffsets, *args, index_unbound=None):
    """Computes the halo index for each particle with dask."""
    return da.map_blocks(
        get_hidx_daskwrap,
        gidx,
        halocelloffsets,
        index_unbound=index_unbound,
        meta=np.array((), dtype=np.int64),
    )
compute_haloquantity(gidx, halocelloffsets, hvals, *args)

Computes a halo quantity for each particle with dask.

Source code in src/scida/customs/arepo/dataset.py, lines 1072-1085
def compute_haloquantity(gidx, halocelloffsets, hvals, *args):
    """Computes a halo quantity for each particle with dask."""
    units = None
    if hasattr(hvals, "units"):
        units = hvals.units
    res = map_blocks(
        get_haloquantity_daskwrap,
        gidx,
        halocelloffsets,
        hvals,
        meta=np.array((), dtype=hvals.dtype),
        output_units=units,
    )
    return res
compute_localsubhaloindex(gidx, halocelloffsets, shnumber, shcounts, shcellcounts, index_unbound=None)

Compute the local subhalo index for each particle with dask. The local subhalo index is the index of the subhalo within each halo, starting at 0 for the central subhalo.

Parameters:

Name Type Description Default
gidx
required
halocelloffsets
required
shnumber
required
shcounts
required
shcellcounts
required
index_unbound
None
Source code in src/scida/customs/arepo/dataset.py, lines 1206-1237
def compute_localsubhaloindex(
    gidx, halocelloffsets, shnumber, shcounts, shcellcounts, index_unbound=None
) -> da.Array:
    """
    Compute the local subhalo index for each particle with dask.
    The local subhalo index is the index of the subhalo within each halo,
    starting at 0 for the central subhalo.

    Parameters
    ----------
    gidx
    halocelloffsets
    shnumber
    shcounts
    shcellcounts
    index_unbound

    Returns
    -------

    """
    res = da.map_blocks(
        get_local_shidx_daskwrap,
        gidx,
        halocelloffsets,
        shnumber,
        shcounts,
        shcellcounts,
        index_unbound=index_unbound,
        meta=np.array((), dtype=np.int64),
    )
    return res
get_hidx(gidx_start, gidx_count, celloffsets, index_unbound=None)

Get halo index of a given cell

Parameters:

Name Type Description Default
gidx_start

The first unique integer ID for the first particle

required
gidx_count

The number of consecutive indices to query, starting at "gidx_start"

required
celloffsets array

An array holding the starting cell offset for each halo. Needs to include the offset after the last halo. The required shape is thus (Nhalo+1,).

required
index_unbound integer

The index to use for unbound particles. If None, the maximum integer value of the dtype is used.

None
Source code in src/scida/customs/arepo/dataset.py, lines 1002-1041
@jit(nopython=True)
def get_hidx(gidx_start, gidx_count, celloffsets, index_unbound=None):
    """Get halo index of a given cell

    Parameters
    ----------
    gidx_start: integer
        The first unique integer ID for the first particle
    gidx_count: integer
        The number of consecutive indices to query, starting at "gidx_start"
    celloffsets : array
        An array holding the starting cell offset for each halo. Needs to include the
        offset after the last halo. The required shape is thus (Nhalo+1,).
    index_unbound : integer, optional
        The index to use for unbound particles. If None, the maximum integer value
        of the dtype is used.
    """
    dtype = np.int64
    if index_unbound is None:
        index_unbound = np.iinfo(dtype).max
    res = index_unbound * np.ones(gidx_count, dtype=dtype)
    # find initial celloffset
    hidx_idx = np.searchsorted(celloffsets, gidx_start, side="right") - 1
    if hidx_idx + 1 >= celloffsets.shape[0]:
        # we are done. Already out of scope of lookup => all unbound gas.
        return res
    celloffset = celloffsets[hidx_idx + 1]
    endid = celloffset - gidx_start
    startid = 0

    # Now iterate through list.
    while startid < gidx_count:
        res[startid:endid] = hidx_idx
        hidx_idx += 1
        startid = endid
        if hidx_idx >= celloffsets.shape[0] - 1:
            break
        count = celloffsets[hidx_idx + 1] - celloffsets[hidx_idx]
        endid = startid + count
    return res
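A small illustration of the mapping (hedged: assumes the function can be imported from its source module and numba is available). With celloffsets = [0, 5, 9, 12], halo 0 owns particles 0-4 and halo 1 owns particles 5-8, so querying four particles starting at index 3 yields halo indices [0, 0, 1, 1]:

import numpy as np
from scida.customs.arepo.dataset import get_hidx

celloffsets = np.array([0, 5, 9, 12])
print(get_hidx(3, 4, celloffsets))  # [0 0 1 1]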
get_localshidx(gidx_start, gidx_count, celloffsets, shnumber, shcounts, shcellcounts, index_unbound=None)

Get the local subhalo index for each particle. This is the subhalo index within each halo group. Particles belonging to the central subhalo have index 0, particles belonging to the first satellite have index 1, etc.

Parameters:

Name Type Description Default
gidx_start int
required
gidx_count int
required
celloffsets NDArray[int64]
required
shnumber
required
shcounts
required
shcellcounts
required
index_unbound

The index to use for unbound particles. If None, the maximum integer value of the dtype is used.

None

Returns:

Type Description
ndarray
Source code in src/scida/customs/arepo/dataset.py, lines 1088-1181
@jit(nopython=True)
def get_localshidx(
    gidx_start: int,
    gidx_count: int,
    celloffsets: NDArray[np.int64],
    shnumber,
    shcounts,
    shcellcounts,
    index_unbound=None,
):
    """
    Get the local subhalo index for each particle. This is the subhalo index within each
    halo group. Particles belonging to the central subhalo have index 0, particles
    belonging to the first satellite have index 1, etc.
    Parameters
    ----------
    gidx_start
    gidx_count
    celloffsets
    shnumber
    shcounts
    shcellcounts
    index_unbound: integer, optional
        The index to use for unbound particles. If None, the maximum integer value
        of the dtype is used.

    Returns
    -------
    np.ndarray
    """
    dtype = np.int32
    if index_unbound is None:
        index_unbound = np.iinfo(dtype).max
    res = index_unbound * np.ones(gidx_count, dtype=dtype)  # fuzz has negative index.

    # find initial Group we are in
    hidx_start_idx = np.searchsorted(celloffsets, gidx_start, side="right") - 1
    if hidx_start_idx + 1 >= celloffsets.shape[0]:
        # we are done. Already out of scope of lookup => all unbound gas.
        return res
    celloffset = celloffsets[hidx_start_idx + 1]
    endid = celloffset - gidx_start
    startid = 0

    # find initial subhalo we are in
    hidx = hidx_start_idx
    shcumsum = np.zeros(shcounts[hidx] + 1, dtype=np.int64)
    shcumsum[1:] = np.cumsum(
        shcellcounts[shnumber[hidx] : shnumber[hidx] + shcounts[hidx]]
    )  # collect halo's subhalo offsets
    shcumsum += celloffsets[hidx_start_idx]
    sidx_start_idx: int = int(np.searchsorted(shcumsum, gidx_start, side="right") - 1)
    if sidx_start_idx < shcounts[hidx]:
        endid = shcumsum[sidx_start_idx + 1] - gidx_start

    # Now iterate through list.
    cont = True
    while cont and (startid < gidx_count):
        res[startid:endid] = (
            sidx_start_idx if sidx_start_idx + 1 < shcumsum.shape[0] else -1
        )
        sidx_start_idx += 1
        if sidx_start_idx < shcounts[hidx_start_idx]:
            # we prepare to fill the next available subhalo for current halo
            count = shcumsum[sidx_start_idx + 1] - shcumsum[sidx_start_idx]
            startid = endid
        else:
            # we need to find the next halo to start filling its subhalos
            dbgcount = 0
            while dbgcount < 100:  # find next halo with >0 subhalos
                hidx_start_idx += 1
                if hidx_start_idx >= shcounts.shape[0]:
                    cont = False
                    break
                if shcounts[hidx_start_idx] > 0:
                    break
                dbgcount += 1
            hidx = hidx_start_idx
            if hidx_start_idx >= celloffsets.shape[0] - 1:
                startid = gidx_count
            else:
                count = celloffsets[hidx_start_idx + 1] - celloffsets[hidx_start_idx]
                if hidx < shcounts.shape[0]:
                    shcumsum = np.zeros(shcounts[hidx] + 1, dtype=np.int64)
                    shcumsum[1:] = np.cumsum(
                        shcellcounts[shnumber[hidx] : shnumber[hidx] + shcounts[hidx]]
                    )
                    shcumsum += celloffsets[hidx_start_idx]
                    sidx_start_idx = 0
                    if sidx_start_idx < shcounts[hidx]:
                        count = shcumsum[sidx_start_idx + 1] - shcumsum[sidx_start_idx]
                    startid = celloffsets[hidx_start_idx] - gidx_start
        endid = startid + count
    return res
get_shcounts_shcells(SubhaloGrNr, hlength)

Return the number of subhalos and the index of the first subhalo for each halo.

Parameters:

Name Type Description Default
SubhaloGrNr

The group identifier that each subhalo belongs to respectively

required
hlength

The number of halos in the snapshot

required

Returns:

Name Type Description
shcounts ndarray

The number of subhalos per halo

shnumber ndarray

The index of the first subhalo per halo

Source code in src/scida/customs/arepo/dataset.py, lines 1240-1272
@jit(nopython=True)
def get_shcounts_shcells(SubhaloGrNr, hlength):
    """
    Return the number of subhalos and the index of the first subhalo for each halo.

    Parameters
    ----------
    SubhaloGrNr: np.ndarray
        The group identifier that each subhalo belongs to respectively
    hlength: int
        The number of halos in the snapshot

    Returns
    -------
    shcounts: np.ndarray
        The number of subhalos per halo
    shnumber: np.ndarray
        The index of the first subhalo per halo
    """
    shcounts = np.zeros(hlength, dtype=np.int32)  # number of subhalos per halo
    shnumber = np.zeros(hlength, dtype=np.int32)  # index of first subhalo per halo
    i = 0
    hid_old = 0
    while i < SubhaloGrNr.shape[0]:
        hid = SubhaloGrNr[i]
        if hid == hid_old:
            shcounts[hid] += 1
        else:
            shnumber[hid] = i
            shcounts[hid] += 1
            hid_old = hid
        i += 1
    return shcounts, shnumber
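A small worked example (hedged: assumes the function can be imported from its source module). Subhalos 0-1 belong to halo 0 and subhalos 2-4 belong to halo 2, while halos 1 and 3 host no subhalos:

import numpy as np
from scida.customs.arepo.dataset import get_shcounts_shcells

SubhaloGrNr = np.array([0, 0, 2, 2, 2])
shcounts, shnumber = get_shcounts_shcells(SubhaloGrNr, 4)
print(shcounts)  # [2 0 3 0]  subhalos per halo
print(shnumber)  # [0 0 2 0]  index of first subhalo per halo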
map_group_operation(func, offsets, lengths, arrdict, cpucost_halo=10000.0, nchunks_min=None, chunksize_bytes=None, entry_nbytes_in=4, fieldnames=None, nmax=None, idxlist=None)

Map a function to all halos in a halo catalog.

Parameters:

Name Type Description Default
idxlist Optional[ndarray]

Only process the halos with these indices.

None
nmax Optional[int]

Only process the first nmax halos.

None
func
required
offsets

Offset of each group in the particle catalog.

required
lengths

Number of particles per halo.

required
arrdict
required
cpucost_halo
10000.0
nchunks_min Optional[int]

Lower bound on the number of chunks used for processing.

None
chunksize_bytes Optional[int]
None
entry_nbytes_in Optional[int]
4
fieldnames Optional[List[str]]
None

Returns:

Type Description
da.Array
Source code in src/scida/customs/arepo/dataset.py, lines 1371-1597
def map_group_operation(
    func,
    offsets,
    lengths,
    arrdict,
    cpucost_halo=1e4,
    nchunks_min: Optional[int] = None,
    chunksize_bytes: Optional[int] = None,
    entry_nbytes_in: Optional[int] = 4,
    fieldnames: Optional[List[str]] = None,
    nmax: Optional[int] = None,
    idxlist: Optional[np.ndarray] = None,
) -> da.Array:
    """
    Map a function to all halos in a halo catalog.
    Parameters
    ----------
    idxlist: Optional[np.ndarray]
        Only process the halos with these indices.
    nmax: Optional[int]
        Only process the first nmax halos.
    func
    offsets: np.ndarray
        Offset of each group in the particle catalog.
    lengths: np.ndarray
        Number of particles per halo.
    arrdict
    cpucost_halo
    nchunks_min: Optional[int]
        Lower bound on the number of chunks used for processing.
    chunksize_bytes
    entry_nbytes_in
    fieldnames

    Returns
    -------
    da.Array
    """
    if isinstance(func, ChainOps):
        dfltkwargs = func.kwargs
    else:
        dfltkwargs = get_kwargs(func)
    if fieldnames is None:
        fieldnames = dfltkwargs.get("fieldnames", None)
    if fieldnames is None:
        fieldnames = get_args(func)
    units = dfltkwargs.get("units", None)
    shape = dfltkwargs.get("shape", None)
    dtype = dfltkwargs.get("dtype", "float64")
    fill_value = dfltkwargs.get("fill_value", 0)

    if idxlist is not None and nmax is not None:
        raise ValueError("Cannot specify both idxlist and nmax.")

    lengths_all = lengths
    offsets_all = offsets
    if len(lengths) == len(offsets):
        # extend the offsets array by one entry holding the end offset past the last halo.
        offsets_all = np.concatenate([offsets_all, [offsets_all[-1] + lengths[-1]]])

    if nmax is not None:
        lengths = lengths[:nmax]
        offsets = offsets[:nmax]

    if idxlist is not None:
        # make sure idxlist is sorted and unique
        if not np.all(np.diff(idxlist) > 0):
            raise ValueError("idxlist must be sorted and unique.")
        # make sure idxlist is within range
        if np.min(idxlist) < 0 or np.max(idxlist) >= lengths.shape[0]:
            raise ValueError(
                "idxlist elements must be in [%i, %i), but covers range [%i, %i]."
                % (0, lengths.shape[0], np.min(idxlist), np.max(idxlist))
            )
        offsets = offsets[idxlist]
        lengths = lengths[idxlist]

    if len(lengths) == len(offsets):
        # extend the offsets array by one entry holding the end offset past the last halo.
        offsets = np.concatenate([offsets, [offsets[-1] + lengths[-1]]])

    # shape/units inference
    infer_shape = shape is None or (isinstance(shape, str) and shape == "auto")
    infer_units = units is None
    infer = infer_shape or infer_units
    if infer:
        # attempt to determine shape.
        if infer_shape:
            log.debug(
                "No shape specified. Attempting to determine shape of func output."
            )
        if infer_units:
            log.debug(
                "No units specified. Attempting to determine units of func output."
            )
        arrs = [arrdict[f][:1].compute() for f in fieldnames]
        # remove units if present
        # arrs = [arr.magnitude if hasattr(arr, "magnitude") else arr for arr in arrs]
        # arrs = [arr.magnitude for arr in arrs]
        dummyres = None
        try:
            dummyres = func(*arrs)
        except Exception as e:  # noqa
            log.warning("Exception during shape/unit inference: %s." % str(e))
        if dummyres is not None:
            if infer_units and hasattr(dummyres, "units"):
                units = dummyres.units
            log.debug("Shape inference: %s." % str(shape))
        if infer_units and dummyres is None:
            units_present = any([hasattr(arr, "units") for arr in arrs])
            if units_present:
                log.warning("Exception during unit inference. Assuming no units.")
        if dummyres is None and infer_shape:
            # due to https://github.com/hgrecco/pint/issues/1037 innocent np.array operations on unit scalars can fail.
            # we can still attempt to infer shape by removing units prior to calling func.
            arrs = [arr.magnitude if hasattr(arr, "magnitude") else arr for arr in arrs]
            try:
                dummyres = func(*arrs)
            except Exception as e:  # noqa
                # no more logging needed here
                pass
        if dummyres is not None and infer_shape:
            if np.isscalar(dummyres):
                shape = (1,)
            else:
                shape = dummyres.shape
        if infer_shape and dummyres is None and shape is None:
            log.warning("Exception during shape inference. Using shape (1,).")
            shape = ()
    # unit inference

    # Determine chunkedges automatically
    # TODO: very messy and inefficient routine. improve some time.
    # TODO: Set entry_bytes_out
    nbytes_dtype_out = 4  # TODO: hardcode 4 byte output dtype as estimate for now
    entry_nbytes_out = nbytes_dtype_out * np.product(shape)

    # list_chunkedges refers to bounds of index intervals to be processed together
    # if idxlist is specified, then these indices do not have to refer to group indices.
    # if idxlist is given, we enforce that particle data is contiguous
    # by putting each idx from idxlist into its own chunk.
    # in the future, we should optimize this
    if idxlist is not None:
        list_chunkedges = [[idx, idx + 1] for idx in np.arange(len(idxlist))]
    else:
        list_chunkedges = map_group_operation_get_chunkedges(
            lengths,
            entry_nbytes_in,
            entry_nbytes_out,
            cpucost_halo=cpucost_halo,
            nchunks_min=nchunks_min,
            chunksize_bytes=chunksize_bytes,
        )

    minentry = offsets[0]
    maxentry = offsets[-1]  # the last particle that needs to be processed

    # chunks specify the number of groups in each chunk
    chunks = [tuple(np.diff(list_chunkedges, axis=1).flatten())]
    # need to add chunk information for additional output axes if needed
    new_axis = None
    if isinstance(shape, tuple) and shape != (1,):
        chunks += [(s,) for s in shape]
        new_axis = np.arange(1, len(shape) + 1).tolist()

    # slcoffsets = [offsets[chunkedge[0]] for chunkedge in list_chunkedges]
    # the actual length of relevant data in each chunk
    slclengths = [
        offsets[chunkedge[1]] - offsets[chunkedge[0]] for chunkedge in list_chunkedges
    ]
    if idxlist is not None:
        # the chunk length to be fed into map_blocks
        tmplist = np.concatenate([idxlist, [len(lengths_all)]])
        slclengths_map = [
            offsets_all[tmplist[chunkedge[1]]] - offsets_all[tmplist[chunkedge[0]]]
            for chunkedge in list_chunkedges
        ]
        slcoffsets_map = [
            offsets_all[tmplist[chunkedge[0]]] for chunkedge in list_chunkedges
        ]
        slclengths_map[0] = slcoffsets_map[0]
        slcoffsets_map[0] = 0
    else:
        slclengths_map = slclengths

    slcs = [slice(chunkedge[0], chunkedge[1]) for chunkedge in list_chunkedges]
    offsets_in_chunks = [offsets[slc] - offsets[slc.start] for slc in slcs]
    lengths_in_chunks = [lengths[slc] for slc in slcs]
    d_oic = delayed(offsets_in_chunks)
    d_hic = delayed(lengths_in_chunks)

    arrs = [arrdict[f][minentry:maxentry] for f in fieldnames]
    for i, arr in enumerate(arrs):
        arrchunks = ((tuple(slclengths)),)
        if len(arr.shape) > 1:
            arrchunks = arrchunks + (arr.shape[1:],)
        arrs[i] = arr.rechunk(chunks=arrchunks)
    arrdims = np.array([len(arr.shape) for arr in arrs])

    assert np.all(arrdims == arrdims[0])  # Cannot handle different input dims for now

    drop_axis = []
    if arrdims[0] > 1:
        drop_axis = np.arange(1, arrdims[0])

    if dtype is None:
        raise ValueError(
            "dtype must be specified, dask will not be able to automatically determine this here."
        )

    calc = map_blocks(
        wrap_func_scalar,
        func,
        d_oic,
        d_hic,
        *arrs,
        dtype=dtype,
        chunks=chunks,
        new_axis=new_axis,
        drop_axis=drop_axis,
        func_output_shape=shape,
        func_output_dtype=dtype,
        fill_value=fill_value,
        output_units=units,
    )

    return calc
map_group_operation_get_chunkedges(lengths, entry_nbytes_in, entry_nbytes_out, cpucost_halo=1.0, nchunks_min=None, chunksize_bytes=None)

Compute the chunking of a halo operation.

Parameters:

Name Type Description Default
lengths

The number of particles per halo.

required
entry_nbytes_in
required
entry_nbytes_out
required
cpucost_halo
1.0
nchunks_min
None
chunksize_bytes
None

Returns:

Type Description
np.ndarray
Source code in src/scida/customs/arepo/dataset.py, lines 1304-1368
def map_group_operation_get_chunkedges(
    lengths,
    entry_nbytes_in,
    entry_nbytes_out,
    cpucost_halo=1.0,
    nchunks_min=None,
    chunksize_bytes=None,
):
    """
    Compute the chunking of a halo operation.

    Parameters
    ----------
    lengths: np.ndarray
        The number of particles per halo.
    entry_nbytes_in
    entry_nbytes_out
    cpucost_halo
    nchunks_min
    chunksize_bytes

    Returns
    -------
    np.ndarray
        The chunk edges as [start, end) pairs of group indices.
    """
    cpucost_particle = 1.0  # we only care about ratio, so keep particle cost fixed.
    cost = cpucost_particle * lengths + cpucost_halo
    sumcost = cost.cumsum()

    # let's allow a maximal chunksize of 16 times the dask default setting for an individual array [here: multiple]
    if chunksize_bytes is None:
        chunksize_bytes = 16 * parse_humansize(dask.config.get("array.chunk-size"))
    cost_memory = entry_nbytes_in * lengths + entry_nbytes_out

    if not np.max(cost_memory) < chunksize_bytes:
        raise ValueError(
            "Some halo requires more memory than allowed (%i allowed, %i requested). Consider overriding "
            "chunksize_bytes." % (chunksize_bytes, np.max(cost_memory))
        )

    nchunks = int(np.ceil(np.sum(cost_memory) / chunksize_bytes))
    nchunks = int(np.ceil(1.3 * nchunks))  # fudge factor
    if nchunks_min is not None:
        nchunks = max(nchunks_min, nchunks)
    targetcost = sumcost[-1] / nchunks  # chunk target cost = total cost / nchunks

    arr = np.diff(sumcost % targetcost)  # find whenever exceeding modulo target cost
    idx = [0] + list(np.where(arr < 0)[0] + 1)
    if idx[-1] != sumcost.shape[0]:
        idx.append(sumcost.shape[0])
    list_chunkedges = []
    for i in range(len(idx) - 1):
        list_chunkedges.append([idx[i], idx[i + 1]])

    list_chunkedges = np.asarray(
        memorycost_limiter(cost_memory, cost, list_chunkedges, chunksize_bytes)
    )

    # make sure we did not lose any halos.
    assert np.all(
        ~(list_chunkedges.flatten()[2:-1:2] - list_chunkedges.flatten()[1:-1:2]).astype(
            bool
        )
    )
    return list_chunkedges
memorycost_limiter(cost_memory, cost_cpu, list_chunkedges, cost_memory_max)

If a chunk is too memory-expensive, split it into operations of roughly equal CPU expense.

Source code in src/scida/customs/arepo/dataset.py, lines 1275-1301
def memorycost_limiter(cost_memory, cost_cpu, list_chunkedges, cost_memory_max):
    """If a chunk too memory expensive, split into equal cpu expense operations."""
    list_chunkedges_new = []
    for chunkedges in list_chunkedges:
        slc = slice(*chunkedges)
        totcost_mem = np.sum(cost_memory[slc])
        list_chunkedges_new.append(chunkedges)
        if totcost_mem > cost_memory_max:
            sumcost = cost_cpu[slc].cumsum()
            sumcost /= sumcost[-1]
            idx = slc.start + np.argmin(np.abs(sumcost - 0.5))
            if idx == chunkedges[0]:
                idx += 1
            elif idx == chunkedges[-1]:
                idx -= 1
            chunkedges1 = [chunkedges[0], idx]
            chunkedges2 = [idx, chunkedges[1]]
            if idx == chunkedges[0] or idx == chunkedges[1]:
                raise ValueError("This should not happen.")
            list_chunkedges_new.pop()
            list_chunkedges_new += memorycost_limiter(
                cost_memory, cost_cpu, [chunkedges1], cost_memory_max
            )
            list_chunkedges_new += memorycost_limiter(
                cost_memory, cost_cpu, [chunkedges2], cost_memory_max
            )
    return list_chunkedges_new
wrap_func_scalar(func, offsets_in_chunks, lengths_in_chunks, *arrs, block_info=None, block_id=None, func_output_shape=(1), func_output_dtype='float64', fill_value=0)

Wrapper for applying a function to each halo in the passed chunk.

Parameters:

Name Type Description Default
func
required
offsets_in_chunks
required
lengths_in_chunks
required
arrs
()
block_info
None
block_id
None
func_output_shape
(1)
func_output_dtype
'float64'
fill_value
0
Source code in src/scida/customs/arepo/dataset.py, lines 957-999
def wrap_func_scalar(
    func,
    offsets_in_chunks,
    lengths_in_chunks,
    *arrs,
    block_info=None,
    block_id=None,
    func_output_shape=(1,),
    func_output_dtype="float64",
    fill_value=0,
):
    """
    Wrapper for applying a function to each halo in the passed chunk.
    Parameters
    ----------
    func
    offsets_in_chunks
    lengths_in_chunks
    arrs
    block_info
    block_id
    func_output_shape
    func_output_dtype
    fill_value

    Returns
    -------

    """
    offsets = offsets_in_chunks[block_id[0]]
    lengths = lengths_in_chunks[block_id[0]]

    res = []
    for i, length in enumerate(lengths):
        o = offsets[i]
        if length == 0:
            res.append(fill_value * np.ones(func_output_shape, dtype=func_output_dtype))
            if func_output_shape == (1,):
                res[-1] = res[-1].item()
            continue
        arrchunks = [arr[o : o + length] for arr in arrs]
        res.append(func(*arrchunks))
    return np.array(res)

helpers

Helper functions for arepo snapshots/simulations.

grp_type_str(gtype)

Mapping between common group names and numeric group types.

Source code in src/scida/customs/arepo/helpers.py, lines 4-10
def grp_type_str(gtype):
    """Mapping between common group names and numeric group types."""
    if str(gtype).lower() in ["group", "groups", "halo", "halos"]:
        return "halo"
    if str(gtype).lower() in ["subgroup", "subgroups", "subhalo", "subhalos"]:
        return "subhalo"
    raise ValueError("Unknown group type: %s" % gtype)
part_type_num(ptype)

Mapping between common particle names and numeric particle types.

Source code in src/scida/customs/arepo/helpers.py, lines 13-34
def part_type_num(ptype):
    """Mapping between common particle names and numeric particle types."""
    ptype = str(ptype).replace("PartType", "")
    if ptype.isdigit():
        return int(ptype)

    if str(ptype).lower() in ["gas", "cells"]:
        return 0
    if str(ptype).lower() in ["dm", "darkmatter"]:
        return 1
    if str(ptype).lower() in ["dmlowres"]:
        return 2  # only zoom simulations, not present in full periodic boxes
    if str(ptype).lower() in ["tracer", "tracers", "tracermc", "trmc"]:
        return 3
    if str(ptype).lower() in ["star", "stars", "stellar"]:
        return 4  # only those with GFM_StellarFormationTime>0
    if str(ptype).lower() in ["wind"]:
        return 4  # only those with GFM_StellarFormationTime<0
    if str(ptype).lower() in ["bh", "bhs", "blackhole", "blackholes", "black"]:
        return 5
    if str(ptype).lower() in ["all"]:
        return -1

selector

Selector for ArepoSnapshot

ArepoSelector

Bases: Selector

Selector for ArepoSnapshot. Can select for haloID, subhaloID, and unbound particles.

Source code in src/scida/customs/arepo/selector.py, lines 14-122
class ArepoSelector(Selector):
    """Selector for ArepoSnapshot.
    Can select for haloID, subhaloID, and unbound particles."""

    def __init__(self) -> None:
        """
        Initialize the selector.
        """
        super().__init__()
        self.keys = ["haloID", "subhaloID", "unbound"]

    def prepare(self, *args, **kwargs) -> None:
        if all([kwargs.get(k, None) is None for k in self.keys]):
            return  # no specific selection, thus just return
        snap: ArepoSnapshot = args[0]
        halo_id = kwargs.get("haloID", None)
        subhalo_id = kwargs.get("subhaloID", None)
        unbound = kwargs.get("unbound", None)

        if halo_id is not None and subhalo_id is not None:
            raise ValueError("Cannot select for haloID and subhaloID at the same time.")

        if unbound is True and (halo_id is not None or subhalo_id is not None):
            raise ValueError(
                "Cannot select haloID/subhaloID and unbound particles at the same time."
            )

        if snap.catalog is None:
            raise ValueError("Cannot select for haloID without catalog loaded.")

        # select for halo
        idx = subhalo_id if subhalo_id is not None else halo_id
        objtype = "subhalo" if subhalo_id is not None else "halo"
        if idx is not None:
            self.select_group(snap, idx, objtype=objtype)
        elif unbound is True:
            self.select_unbound(snap)

    def select_unbound(self, snap):
        """
        Select unbound particles.

        Parameters
        ----------
        snap: ArepoSnapshot

        Returns
        -------
        None
        """
        lengths = self.data_backup["Group"]["GroupLenType"][-1, :].compute()
        offsets = self.data_backup["Group"]["GroupOffsetsType"][-1, :].compute()
        # for unbound gas, we start after the last halo particles
        offsets = offsets + lengths
        for p in self.data_backup:
            splt = p.split("PartType")
            if len(splt) == 1:
                for k, v in self.data_backup[p].items():
                    self.data[p][k] = v
            else:
                pnum = int(splt[1])
                offset = offsets[pnum]
                if hasattr(offset, "magnitude"):  # hack for issue 59
                    offset = offset.magnitude
                for k, v in self.data_backup[p].items():
                    self.data[p][k] = v[offset:-1]
        snap.data = self.data

    def select_group(self, snap, idx, objtype="Group"):
        """
        Select particles for given group/subhalo index.

        Parameters
        ----------
        snap: ArepoSnapshot
        idx: int
        objtype: str

        Returns
        -------
        None
        """
        # TODO: test whether works for multiple groups via idx list
        objtype = grp_type_str(objtype)
        if objtype == "halo":
            lengths = self.data_backup["Group"]["GroupLenType"][idx, :].compute()
            offsets = self.data_backup["Group"]["GroupOffsetsType"][idx, :].compute()
        elif objtype == "subhalo":
            lengths = {i: snap.get_subhalolengths(i)[idx] for i in range(6)}
            offsets = {i: snap.get_subhalooffsets(i)[idx] for i in range(6)}
        else:
            raise ValueError("Unknown object type: %s" % objtype)

        for p in self.data_backup:
            splt = p.split("PartType")
            if len(splt) == 1:
                for k, v in self.data_backup[p].items():
                    self.data[p][k] = v
            else:
                pnum = int(splt[1])
                offset = offsets[pnum]
                length = lengths[pnum]
                if hasattr(offset, "magnitude"):  # hack for issue 59
                    offset = offset.magnitude
                if hasattr(length, "magnitude"):
                    length = length.magnitude
                for k, v in self.data_backup[p].items():
                    self.data[p][k] = v[offset : offset + length]
        snap.data = self.data
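A minimal sketch of how the selector is typically triggered (hedged: assumes the selection keywords are forwarded through the decorated return_data of an already loaded ArepoSnapshot `snap` with its catalog attached):

gas_halo = snap.return_data(haloID=42)     # particles bound to halo 42
gas_sub = snap.return_data(subhaloID=10)   # particles of subhalo 10
gas_fuzz = snap.return_data(unbound=True)  # unbound ("fuzz") particles only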
__init__()

Initialize the selector.

Source code in src/scida/customs/arepo/selector.py, lines 18-23
def __init__(self) -> None:
    """
    Initialize the selector.
    """
    super().__init__()
    self.keys = ["haloID", "subhaloID", "unbound"]
select_group(snap, idx, objtype='Group')

Select particles for given group/subhalo index.

Parameters:

Name Type Description Default
snap
required
idx
required
objtype
'Group'

Returns:

Type Description
None
Source code in src/scida/customs/arepo/selector.py, lines 82-122
def select_group(self, snap, idx, objtype="Group"):
    """
    Select particles for given group/subhalo index.

    Parameters
    ----------
    snap: ArepoSnapshot
    idx: int
    objtype: str

    Returns
    -------
    None
    """
    # TODO: test whether works for multiple groups via idx list
    objtype = grp_type_str(objtype)
    if objtype == "halo":
        lengths = self.data_backup["Group"]["GroupLenType"][idx, :].compute()
        offsets = self.data_backup["Group"]["GroupOffsetsType"][idx, :].compute()
    elif objtype == "subhalo":
        lengths = {i: snap.get_subhalolengths(i)[idx] for i in range(6)}
        offsets = {i: snap.get_subhalooffsets(i)[idx] for i in range(6)}
    else:
        raise ValueError("Unknown object type: %s" % objtype)

    for p in self.data_backup:
        splt = p.split("PartType")
        if len(splt) == 1:
            for k, v in self.data_backup[p].items():
                self.data[p][k] = v
        else:
            pnum = int(splt[1])
            offset = offsets[pnum]
            length = lengths[pnum]
            if hasattr(offset, "magnitude"):  # hack for issue 59
                offset = offset.magnitude
            if hasattr(length, "magnitude"):
                length = length.magnitude
            for k, v in self.data_backup[p].items():
                self.data[p][k] = v[offset : offset + length]
    snap.data = self.data
select_unbound(snap)

Select unbound particles.

Parameters:

Name Type Description Default
snap
required

Returns:

Type Description
None
Source code in src/scida/customs/arepo/selector.py, lines 52-80
def select_unbound(self, snap):
    """
    Select unbound particles.

    Parameters
    ----------
    snap: ArepoSnapshot

    Returns
    -------
    None
    """
    lengths = self.data_backup["Group"]["GroupLenType"][-1, :].compute()
    offsets = self.data_backup["Group"]["GroupOffsetsType"][-1, :].compute()
    # for unbound gas, we start after the last halo particles
    offsets = offsets + lengths
    for p in self.data_backup:
        splt = p.split("PartType")
        if len(splt) == 1:
            for k, v in self.data_backup[p].items():
                self.data[p][k] = v
        else:
            pnum = int(splt[1])
            offset = offsets[pnum]
            if hasattr(offset, "magnitude"):  # hack for issue 59
                offset = offset.magnitude
            for k, v in self.data_backup[p].items():
                self.data[p][k] = v[offset:-1]
    snap.data = self.data
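
Continuing the sketch above, and again assuming the selector keys are exposed through return_data(), the unbound particles (those after the last halo's slice) would be requested via the unbound key:

unbound = snap.return_data(unbound=True)   # particles past the last halo's slice
print(unbound["PartType0"]["Masses"])
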

series

Contains Series class for Arepo simulations.

ArepoSimulation

Bases: GadgetStyleSimulation

A series representing an Arepo simulation.

Source code in src/scida/customs/arepo/series.py
class ArepoSimulation(GadgetStyleSimulation):
    """A series representing an Arepo simulation."""

    def __init__(self, path, lazy=True, async_caching=False, **interface_kwargs):
        """
        Initialize an ArepoSimulation object.

        Parameters
        ----------
        path: str
            Path to the simulation folder, should contain "output" folder.
        lazy: bool
            Whether to load data files lazily.
        interface_kwargs: dict
            Additional keyword arguments passed to the interface.
        """
        # choose parent folder as path if we are passed "output" dir
        p = pathlib.Path(path)
        if p.name == "output":
            path = str(p.parent)
        prefix_dict = dict(paths="snapdir", gpaths="group")
        arg_dict = dict(gpaths="catalog")
        super().__init__(
            path,
            prefix_dict=prefix_dict,
            arg_dict=arg_dict,
            lazy=lazy,
            **interface_kwargs
        )

    @classmethod
    def validate_path(cls, path, *args, **kwargs) -> CandidateStatus:
        """
        Validate a path as a candidate for this simulation class.

        Parameters
        ----------
        path: str
            Path to validate.
        args: list
            Additional positional arguments.
        kwargs:
            Additional keyword arguments.

        Returns
        -------
        CandidateStatus
            Whether the path is a candidate for this simulation class.
        """
        valid = CandidateStatus.NO
        if not os.path.isdir(path):
            return CandidateStatus.NO
        fns = os.listdir(path)
        if "gizmo_parameters.txt" in fns:
            return CandidateStatus.NO
        sprefixs = ["snapdir", "snapshot"]
        opath = path
        if "output" in fns:
            opath = join(path, "output")
        folders = os.listdir(opath)
        folders = [f for f in folders if os.path.isdir(join(opath, f))]
        if any([f.startswith(k) for f in folders for k in sprefixs]):
            valid = CandidateStatus.MAYBE
        return valid
__init__(path, lazy=True, async_caching=False, **interface_kwargs)

Initialize an ArepoSimulation object.

Parameters:

Name Type Description Default
path

Path to the simulation folder, should contain "output" folder.

required
lazy

Whether to load data files lazily.

True
interface_kwargs

Additional keyword arguments passed to the interface.

{}
Source code in src/scida/customs/arepo/series.py
def __init__(self, path, lazy=True, async_caching=False, **interface_kwargs):
    """
    Initialize an ArepoSimulation object.

    Parameters
    ----------
    path: str
        Path to the simulation folder, should contain "output" folder.
    lazy: bool
        Whether to load data files lazily.
    interface_kwargs: dict
        Additional keyword arguments passed to the interface.
    """
    # choose parent folder as path if we are passed "output" dir
    p = pathlib.Path(path)
    if p.name == "output":
        path = str(p.parent)
    prefix_dict = dict(paths="snapdir", gpaths="group")
    arg_dict = dict(gpaths="catalog")
    super().__init__(
        path,
        prefix_dict=prefix_dict,
        arg_dict=arg_dict,
        lazy=lazy,
        **interface_kwargs
    )
validate_path(path, *args, **kwargs) classmethod

Validate a path as a candidate for this simulation class.

Parameters:

Name Type Description Default
path

Path to validate.

required
args

Additional positional arguments.

()
kwargs

Additional keyword arguments.

{}

Returns:

Type Description
CandidateStatus

Whether the path is a candidate for this simulation class.

Source code in src/scida/customs/arepo/series.py
@classmethod
def validate_path(cls, path, *args, **kwargs) -> CandidateStatus:
    """
    Validate a path as a candidate for this simulation class.

    Parameters
    ----------
    path: str
        Path to validate.
    args: list
        Additional positional arguments.
    kwargs:
        Additional keyword arguments.

    Returns
    -------
    CandidateStatus
        Whether the path is a candidate for this simulation class.
    """
    valid = CandidateStatus.NO
    if not os.path.isdir(path):
        return CandidateStatus.NO
    fns = os.listdir(path)
    if "gizmo_parameters.txt" in fns:
        return CandidateStatus.NO
    sprefixs = ["snapdir", "snapshot"]
    opath = path
    if "output" in fns:
        opath = join(path, "output")
    folders = os.listdir(opath)
    folders = [f for f in folders if os.path.isdir(join(opath, f))]
    if any([f.startswith(k) for f in folders for k in sprefixs]):
        valid = CandidateStatus.MAYBE
    return valid
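
A minimal construction sketch; the path is hypothetical and should contain the "output" folder described above:

from scida.customs.arepo.series import ArepoSimulation

sim = ArepoSimulation("/path/to/simulation", lazy=True)
print(sim.name)   # folder name, set by GadgetStyleSimulation.__init__
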

gadgetstyle

dataset

Defines the GadgetStyleSnapshot class, mostly used for deriving subclasses for related codes/simulations.

GadgetStyleSnapshot

Bases: Dataset

A dataset representing a Gadget-style snapshot.

Source code in src/scida/customs/gadgetstyle/dataset.py
class GadgetStyleSnapshot(Dataset):
    """A dataset representing a Gadget-style snapshot."""

    def __init__(self, path, chunksize="auto", virtualcache=True, **kwargs) -> None:
        """We define gadget-style snapshots as nbody/hydrodynamical simulation snapshots that follow
        the common /PartType0, /PartType1 grouping scheme."""
        self.boxsize = np.nan
        super().__init__(path, chunksize=chunksize, virtualcache=virtualcache, **kwargs)

        defaultattributes = ["config", "header", "parameters"]
        for k in self._metadata_raw:
            name = k.strip("/").lower()
            if name in defaultattributes:
                self.__dict__[name] = self._metadata_raw[k]
                if "BoxSize" in self.__dict__[name]:
                    self.boxsize = self.__dict__[name]["BoxSize"]
                elif "Boxsize" in self.__dict__[name]:
                    self.boxsize = self.__dict__[name]["Boxsize"]

        sanity_check = kwargs.get("sanity_check", False)
        key_nparts = "NumPart_Total"
        key_nparts_hw = "NumPart_Total_HighWord"
        if sanity_check and key_nparts in self.header and key_nparts_hw in self.header:
            nparts = self.header[key_nparts_hw] * 2**32 + self.header[key_nparts]
            for i, n in enumerate(nparts):
                pkey = "PartType%i" % i
                if pkey in self.data:
                    pdata = self.data[pkey]
                    fkey = next(iter(pdata.keys()))
                    nparts_loaded = pdata[fkey].shape[0]
                    if nparts_loaded != n:
                        raise ValueError(
                            "Number of particles in header (%i) does not match number of particles loaded (%i) "
                            "for particle type %i" % (n, nparts_loaded, i)
                        )

    @classmethod
    def _get_fileprefix(cls, path: Union[str, os.PathLike], **kwargs) -> str:
        """
        Get the fileprefix used to identify files belonging to given dataset.
        Parameters
        ----------
        path: str, os.PathLike
            path to check
        kwargs

        Returns
        -------
        str
        """
        if os.path.isfile(path):
            return ""  # nothing to do, we have a single file, not a directory
        # order matters: groups will be taken before fof_subhalo, requires py>3.7 for dict order
        prfxs = ["groups", "fof_subhalo", "snap"]
        prfxs_prfx_sim = dict.fromkeys(prfxs)
        files = sorted(os.listdir(path))
        prfxs_lst = []
        for fn in files:
            s = re.search(r"^(\w*)_(\d*)", fn)
            if s is not None:
                prfxs_lst.append(s.group(1))
        prfxs_lst = [p for s in prfxs_prfx_sim for p in prfxs_lst if p.startswith(s)]
        prfxs = dict.fromkeys(prfxs_lst)
        prfxs = list(prfxs.keys())
        if len(prfxs) > 1:
            log.debug("We have more than one prefix avail: %s" % prfxs)
        elif len(prfxs) == 0:
            return ""
        if set(prfxs) == {"groups", "fof_subhalo_tab"}:
            return "groups"  # "groups" over "fof_subhalo_tab"
        return prfxs[0]

    @classmethod
    def validate_path(
        cls, path: Union[str, os.PathLike], *args, expect_grp=False, **kwargs
    ) -> CandidateStatus:
        """
        Check if path is valid for this interface.
        Parameters
        ----------
        path: str, os.PathLike
            path to check
        args
        kwargs

        Returns
        -------
        bool
        """
        path = str(path)
        possibly_valid = CandidateStatus.NO
        iszarr = path.rstrip("/").endswith(".zarr")
        if path.endswith(".hdf5") or iszarr:
            possibly_valid = CandidateStatus.MAYBE
        if os.path.isdir(path):
            files = os.listdir(path)
            sufxs = [f.split(".")[-1] for f in files]
            if not iszarr and len(set(sufxs)) > 1:
                possibly_valid = CandidateStatus.NO
            if sufxs[0] == "hdf5":
                possibly_valid = CandidateStatus.MAYBE
        if possibly_valid != CandidateStatus.NO:
            metadata_raw = load_metadata(path, **kwargs)
            # need some silly combination of attributes to be sure
            if all([k in metadata_raw for k in ["/Header"]]):
                # identifying snapshot or group catalog
                is_snap = all(
                    [
                        k in metadata_raw["/Header"]
                        for k in ["NumPart_ThisFile", "NumPart_Total"]
                    ]
                )
                is_grp = all(
                    [
                        k in metadata_raw["/Header"]
                        for k in ["Ngroups_ThisFile", "Ngroups_Total"]
                    ]
                )
                if is_grp:
                    return CandidateStatus.MAYBE
                if is_snap and not expect_grp:
                    return CandidateStatus.MAYBE
        return CandidateStatus.NO

    def register_field(self, parttype, name=None, description=""):
        """
        Register a field for a given particle type by returning through decorator.

        Parameters
        ----------
        parttype: Optional[Union[str, List[str]]]
            Particle type name to register with. If None, register for the base field container.
        name: Optional[str]
            Name of the field to register.
        description: Optional[str]
            Description of the field to register.

        Returns
        -------
        callable

        """
        res = self.data.register_field(parttype, name=name, description=description)
        return res

    def merge_data(
        self, secondobj, fieldname_suffix="", root_group: Optional[str] = None
    ):
        """
        Merge data from other snapshot into self.data.

        Parameters
        ----------
        secondobj: GadgetStyleSnapshot
        fieldname_suffix: str
        root_group: Optional[str]

        Returns
        -------
        None

        """
        data = self.data
        if root_group is not None:
            if root_group not in data._containers:
                data.add_container(root_group)
            data = self.data[root_group]
        for k in secondobj.data:
            key = k + fieldname_suffix
            if key not in data:
                data[key] = secondobj.data[k]
            else:
                log.debug("Not overwriting field '%s' during merge_data." % key)
            secondobj.data.fieldrecipes_kwargs["snap"] = self

    def merge_hints(self, secondobj):
        """
        Merge hints from other snapshot into self.hints.

        Parameters
        ----------
        secondobj: GadgetStyleSnapshot
            Other snapshot to merge hints from.

        Returns
        -------
        None
        """
        # merge hints from snap and catalog
        for h in secondobj.hints:
            if h not in self.hints:
                self.hints[h] = secondobj.hints[h]
            elif isinstance(self.hints[h], dict):
                # merge dicts
                for k in secondobj.hints[h]:
                    if k not in self.hints[h]:
                        self.hints[h][k] = secondobj.hints[h][k]
            else:
                pass  # nothing to do; we do not overwrite with catalog props

    @classmethod
    def _clean_metadata_from_raw(cls, rawmetadata):
        """
        Set metadata from raw metadata.
        """
        metadata = dict()
        if "/Header" in rawmetadata:
            header = rawmetadata["/Header"]
            if "Redshift" in header:
                metadata["redshift"] = float(header["Redshift"])
                metadata["z"] = metadata["redshift"]
            if "BoxSize" in header:
                # can be scalar or array
                metadata["boxsize"] = header["BoxSize"]
            if "Time" in header:
                metadata["time"] = float(header["Time"])
                metadata["t"] = metadata["time"]
        return metadata

    def _set_metadata(self):
        """
        Set metadata from header and config.
        """
        md = self._clean_metadata_from_raw(self._metadata_raw)
        self.metadata = md
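
A minimal sketch of opening a snapshot through this class directly; the path is hypothetical, and in practice scida's top-level load() usually selects a suitable subclass for you:

from scida.customs.gadgetstyle.dataset import GadgetStyleSnapshot

snap = GadgetStyleSnapshot("/path/to/snapdir_099", chunksize="auto", virtualcache=True)
print(snap.boxsize)                              # taken from the Header/Parameters attributes
coords = snap.data["PartType0"]["Coordinates"]   # lazily loaded dask array
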
__init__(path, chunksize='auto', virtualcache=True, **kwargs)

We define gadget-style snapshots as nbody/hydrodynamical simulation snapshots that follow the common /PartType0, /PartType1 grouping scheme.

Source code in src/scida/customs/gadgetstyle/dataset.py
def __init__(self, path, chunksize="auto", virtualcache=True, **kwargs) -> None:
    """We define gadget-style snapshots as nbody/hydrodynamical simulation snapshots that follow
    the common /PartType0, /PartType1 grouping scheme."""
    self.boxsize = np.nan
    super().__init__(path, chunksize=chunksize, virtualcache=virtualcache, **kwargs)

    defaultattributes = ["config", "header", "parameters"]
    for k in self._metadata_raw:
        name = k.strip("/").lower()
        if name in defaultattributes:
            self.__dict__[name] = self._metadata_raw[k]
            if "BoxSize" in self.__dict__[name]:
                self.boxsize = self.__dict__[name]["BoxSize"]
            elif "Boxsize" in self.__dict__[name]:
                self.boxsize = self.__dict__[name]["Boxsize"]

    sanity_check = kwargs.get("sanity_check", False)
    key_nparts = "NumPart_Total"
    key_nparts_hw = "NumPart_Total_HighWord"
    if sanity_check and key_nparts in self.header and key_nparts_hw in self.header:
        nparts = self.header[key_nparts_hw] * 2**32 + self.header[key_nparts]
        for i, n in enumerate(nparts):
            pkey = "PartType%i" % i
            if pkey in self.data:
                pdata = self.data[pkey]
                fkey = next(iter(pdata.keys()))
                nparts_loaded = pdata[fkey].shape[0]
                if nparts_loaded != n:
                    raise ValueError(
                        "Number of particles in header (%i) does not match number of particles loaded (%i) "
                        "for particle type %i" % (n, nparts_loaded, i)
                    )
merge_data(secondobj, fieldname_suffix='', root_group=None)

Merge data from other snapshot into self.data.

Parameters:

Name Type Description Default
secondobj
required
fieldname_suffix
''
root_group Optional[str]
None

Returns:

Type Description
None
Source code in src/scida/customs/gadgetstyle/dataset.py
def merge_data(
    self, secondobj, fieldname_suffix="", root_group: Optional[str] = None
):
    """
    Merge data from other snapshot into self.data.

    Parameters
    ----------
    secondobj: GadgetStyleSnapshot
    fieldname_suffix: str
    root_group: Optional[str]

    Returns
    -------
    None

    """
    data = self.data
    if root_group is not None:
        if root_group not in data._containers:
            data.add_container(root_group)
        data = self.data[root_group]
    for k in secondobj.data:
        key = k + fieldname_suffix
        if key not in data:
            data[key] = secondobj.data[k]
        else:
            log.debug("Not overwriting field '%s' during merge_data." % key)
        secondobj.data.fieldrecipes_kwargs["snap"] = self
merge_hints(secondobj)

Merge hints from other snapshot into self.hints.

Parameters:

Name Type Description Default
secondobj

Other snapshot to merge hints from.

required

Returns:

Type Description
None
Source code in src/scida/customs/gadgetstyle/dataset.py
def merge_hints(self, secondobj):
    """
    Merge hints from other snapshot into self.hints.

    Parameters
    ----------
    secondobj: GadgetStyleSnapshot
        Other snapshot to merge hints from.

    Returns
    -------
    None
    """
    # merge hints from snap and catalog
    for h in secondobj.hints:
        if h not in self.hints:
            self.hints[h] = secondobj.hints[h]
        elif isinstance(self.hints[h], dict):
            # merge dicts
            for k in secondobj.hints[h]:
                if k not in self.hints[h]:
                    self.hints[h][k] = secondobj.hints[h][k]
        else:
            pass  # nothing to do; we do not overwrite with catalog props
register_field(parttype, name=None, description='')

Register a field for a given particle type by returning through decorator.

Parameters:

Name Type Description Default
parttype

Particle type name to register with. If None, register for the base field container.

required
name

Name of the field to register.

None
description

Description of the field to register.

''

Returns:

Type Description
callable
Source code in src/scida/customs/gadgetstyle/dataset.py
def register_field(self, parttype, name=None, description=""):
    """
    Register a field for a given particle type by returning through decorator.

    Parameters
    ----------
    parttype: Optional[Union[str, List[str]]]
        Particle type name to register with. If None, register for the base field container.
    name: Optional[str]
        Name of the field to register.
    description: Optional[str]
        Description of the field to register.

    Returns
    -------
    callable

    """
    res = self.data.register_field(parttype, name=name, description=description)
    return res
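
Since register_field simply forwards to FieldContainer.register_field, it is used as a decorator: the decorated function receives the field container as its first argument and returns a dask array (see FieldContainer._instantiate_field further below). A sketch, continuing the snapshot opened above; the particle type and field names are illustrative:

@snap.register_field("PartType0", name="KineticEnergy", description="0.5 * m * v^2")
def _kinetic_energy(gas, **kwargs):
    vel = gas["Velocities"]
    return 0.5 * gas["Masses"] * (vel**2).sum(axis=1)

ekin = snap.data["PartType0"]["KineticEnergy"]   # the recipe is evaluated on first access
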
validate_path(path, *args, expect_grp=False, **kwargs) classmethod

Check if path is valid for this interface.

Parameters:

Name Type Description Default
path Union[str, PathLike]

path to check

required
args
()
kwargs
{}

Returns:

Type Description
bool
Source code in src/scida/customs/gadgetstyle/dataset.py
@classmethod
def validate_path(
    cls, path: Union[str, os.PathLike], *args, expect_grp=False, **kwargs
) -> CandidateStatus:
    """
    Check if path is valid for this interface.
    Parameters
    ----------
    path: str, os.PathLike
        path to check
    args
    kwargs

    Returns
    -------
    bool
    """
    path = str(path)
    possibly_valid = CandidateStatus.NO
    iszarr = path.rstrip("/").endswith(".zarr")
    if path.endswith(".hdf5") or iszarr:
        possibly_valid = CandidateStatus.MAYBE
    if os.path.isdir(path):
        files = os.listdir(path)
        sufxs = [f.split(".")[-1] for f in files]
        if not iszarr and len(set(sufxs)) > 1:
            possibly_valid = CandidateStatus.NO
        if sufxs[0] == "hdf5":
            possibly_valid = CandidateStatus.MAYBE
    if possibly_valid != CandidateStatus.NO:
        metadata_raw = load_metadata(path, **kwargs)
        # need some silly combination of attributes to be sure
        if all([k in metadata_raw for k in ["/Header"]]):
            # identifying snapshot or group catalog
            is_snap = all(
                [
                    k in metadata_raw["/Header"]
                    for k in ["NumPart_ThisFile", "NumPart_Total"]
                ]
            )
            is_grp = all(
                [
                    k in metadata_raw["/Header"]
                    for k in ["Ngroups_ThisFile", "Ngroups_Total"]
                ]
            )
            if is_grp:
                return CandidateStatus.MAYBE
            if is_snap and not expect_grp:
                return CandidateStatus.MAYBE
    return CandidateStatus.NO

series

Defines a series representing a Gadget-style simulation.

GadgetStyleSimulation

Bases: DatasetSeries

A series representing a Gadget-style simulation.

Source code in src/scida/customs/gadgetstyle/series.py
class GadgetStyleSimulation(DatasetSeries):
    """A series representing a Gadget-style simulation."""

    def __init__(
        self,
        path,
        prefix_dict: Optional[Dict] = None,
        subpath_dict: Optional[Dict] = None,
        arg_dict: Optional[Dict] = None,
        lazy=True,
        **interface_kwargs
    ):
        """
        Initialize a GadgetStyleSimulation object.

        Parameters
        ----------
        path: str
            Path to the simulation folder, should contain "output" folder.
        prefix_dict: dict
        subpath_dict: dict
        arg_dict: dict
        lazy: bool
        interface_kwargs: dict
        """
        self.path = path
        self.name = os.path.basename(path)
        if prefix_dict is None:
            prefix_dict = dict()
        if subpath_dict is None:
            subpath_dict = dict()
        if arg_dict is None:
            arg_dict = dict()
        p = Path(path)
        if not (p.exists()):
            raise ValueError("Specified path '%s' does not exist." % path)
        paths_dict = dict()
        keys = []
        for d in [prefix_dict, subpath_dict, arg_dict]:
            keys.extend(list(d.keys()))
        keys = set(keys)
        for k in keys:
            subpath = subpath_dict.get(k, "output")
            sp = p / subpath

            prefix = _get_snapshotfolder_prefix(sp)
            prefix = prefix_dict.get(k, prefix)
            if not sp.exists():
                if k != "paths":
                    continue  # do not require optional sources
                raise ValueError("Specified path '%s' does not exist." % (p / subpath))
            fns = os.listdir(sp)
            prfxs = set([f.split("_")[0] for f in fns if f.startswith(prefix)])
            if len(prfxs) == 0:
                raise ValueError(
                    "Could not find any files with prefix '%s' in '%s'." % (prefix, sp)
                )
            prfx = prfxs.pop()

            paths = sorted([p for p in sp.glob(prfx + "_*")])
            # sometimes there are backup folders with different suffix, exclude those.
            paths = [
                p
                for p in paths
                if str(p).split("_")[-1].isdigit() or str(p).endswith(".hdf5")
            ]
            paths_dict[k] = paths

        # make sure we have the same amount of paths respectively
        length = None
        for k in paths_dict.keys():
            paths = paths_dict[k]
            if length is None:
                length = len(paths)
            else:
                assert length == len(paths)

        paths = paths_dict.pop("paths", None)
        if paths is None:
            raise ValueError("Could not find any snapshot paths.")
        p = paths[0]
        cls = _determine_type(p)[1][0]

        mixins = _determine_mixins(path=p)
        cls = create_datasetclass_with_mixins(cls, mixins)

        kwargs = {arg_dict.get(k, "catalog"): paths_dict[k] for k in paths_dict.keys()}
        kwargs.update(**interface_kwargs)

        super().__init__(paths, datasetclass=cls, lazy=lazy, **kwargs)
__init__(path, prefix_dict=None, subpath_dict=None, arg_dict=None, lazy=True, **interface_kwargs)

Initialize a GadgetStyleSimulation object.

Parameters:

Name Type Description Default
path

Path to the simulation folder, should contain "output" folder.

required
prefix_dict Optional[Dict]
None
subpath_dict Optional[Dict]
None
arg_dict Optional[Dict]
None
lazy
True
interface_kwargs
{}
Source code in src/scida/customs/gadgetstyle/series.py
def __init__(
    self,
    path,
    prefix_dict: Optional[Dict] = None,
    subpath_dict: Optional[Dict] = None,
    arg_dict: Optional[Dict] = None,
    lazy=True,
    **interface_kwargs
):
    """
    Initialize a GadgetStyleSimulation object.

    Parameters
    ----------
    path: str
        Path to the simulation folder, should contain "output" folder.
    prefix_dict: dict
    subpath_dict: dict
    arg_dict: dict
    lazy: bool
    interface_kwargs: dict
    """
    self.path = path
    self.name = os.path.basename(path)
    if prefix_dict is None:
        prefix_dict = dict()
    if subpath_dict is None:
        subpath_dict = dict()
    if arg_dict is None:
        arg_dict = dict()
    p = Path(path)
    if not (p.exists()):
        raise ValueError("Specified path '%s' does not exist." % path)
    paths_dict = dict()
    keys = []
    for d in [prefix_dict, subpath_dict, arg_dict]:
        keys.extend(list(d.keys()))
    keys = set(keys)
    for k in keys:
        subpath = subpath_dict.get(k, "output")
        sp = p / subpath

        prefix = _get_snapshotfolder_prefix(sp)
        prefix = prefix_dict.get(k, prefix)
        if not sp.exists():
            if k != "paths":
                continue  # do not require optional sources
            raise ValueError("Specified path '%s' does not exist." % (p / subpath))
        fns = os.listdir(sp)
        prfxs = set([f.split("_")[0] for f in fns if f.startswith(prefix)])
        if len(prfxs) == 0:
            raise ValueError(
                "Could not find any files with prefix '%s' in '%s'." % (prefix, sp)
            )
        prfx = prfxs.pop()

        paths = sorted([p for p in sp.glob(prfx + "_*")])
        # sometimes there are backup folders with different suffix, exclude those.
        paths = [
            p
            for p in paths
            if str(p).split("_")[-1].isdigit() or str(p).endswith(".hdf5")
        ]
        paths_dict[k] = paths

    # make sure we have the same amount of paths respectively
    length = None
    for k in paths_dict.keys():
        paths = paths_dict[k]
        if length is None:
            length = len(paths)
        else:
            assert length == len(paths)

    paths = paths_dict.pop("paths", None)
    if paths is None:
        raise ValueError("Could not find any snapshot paths.")
    p = paths[0]
    cls = _determine_type(p)[1][0]

    mixins = _determine_mixins(path=p)
    cls = create_datasetclass_with_mixins(cls, mixins)

    kwargs = {arg_dict.get(k, "catalog"): paths_dict[k] for k in paths_dict.keys()}
    kwargs.update(**interface_kwargs)

    super().__init__(paths, datasetclass=cls, lazy=lazy, **kwargs)
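
The three dicts control how output folders are discovered and routed; a sketch mirroring what ArepoSimulation above does (the path is hypothetical): "snapdir_*" entries become the main dataset paths, while "group_*" entries are passed to each dataset as its "catalog" keyword.

from scida.customs.gadgetstyle.series import GadgetStyleSimulation

sim = GadgetStyleSimulation(
    "/path/to/simulation",
    prefix_dict=dict(paths="snapdir", gpaths="group"),
    arg_dict=dict(gpaths="catalog"),
)
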

discovertypes

Functionality to determine the dataset or dataseries type of a given path.

CandidateStatus

Bases: Enum

Enum to indicate our confidence in a candidate.

Source code in src/scida/discovertypes.py
class CandidateStatus(Enum):
    """
    Enum to indicate our confidence in a candidate.
    """

    # TODO: Rethink how to use MAYBE/YES information.
    NO = 0  # definitely not a candidate
    MAYBE = 1  # not sure yet
    YES = 2  # yes, this is a candidate
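
A simplified, illustrative sketch of how the validate_path classmethods in this module use the enum:

from scida.discovertypes import CandidateStatus

def validate_path(path) -> CandidateStatus:
    if not str(path).endswith(".hdf5"):
        return CandidateStatus.NO      # definitely not this type
    return CandidateStatus.MAYBE       # plausible, metadata needed to confirm
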

fields

DerivedFieldRecipe

Bases: FieldRecipe

Recipe for a derived field.

Source code in src/scida/fields.py
class DerivedFieldRecipe(FieldRecipe):
    """
    Recipe for a derived field.
    """

    def __init__(self, name, func, description="", units=None):
        """See FieldRecipe for parameters."""
        super().__init__(
            name,
            func=func,
            description=description,
            units=units,
            ftype=FieldType.DERIVED,
        )

__init__(name, func, description='', units=None)

See FieldRecipe for parameters.

Source code in src/scida/fields.py
def __init__(self, name, func, description="", units=None):
    """See FieldRecipe for parameters."""
    super().__init__(
        name,
        func=func,
        description=description,
        units=units,
        ftype=FieldType.DERIVED,
    )
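
Recipes are normally created for you by FieldContainer.register_field (see below), but they can also be built directly; a small sketch with an illustrative field:

from scida.fields import DerivedFieldRecipe

def _double_mass(container, **kwargs):
    return 2 * container["Masses"]

recipe = DerivedFieldRecipe("DoubleMass", _double_mass, description="Twice the particle mass")
# Assigning a recipe to a FieldContainer stores it as a recipe rather than a field,
# so it is only evaluated on first access (see FieldContainer.__setitem__).
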

FieldContainer

Bases: MutableMapping

A mutable collection of fields. Attempts to construct fields from derived field recipes if needed.

Source code in src/scida/fields.py
class FieldContainer(MutableMapping):
    """A mutable collection of fields. Attempt to construct from derived fields recipes
    if needed."""

    def __init__(
        self,
        *args,
        fieldrecipes_kwargs=None,
        containers=None,
        aliases=None,
        withunits=False,
        ureg=None,
        parent: Optional[FieldContainer] = None,
        name: Optional[str] = None,
        **kwargs,
    ):
        """
        Construct a FieldContainer.

        Parameters
        ----------
        args
        fieldrecipes_kwargs: dict
            default kwargs used for field recipes
        containers: List[FieldContainer, str]
            list of containers to add. FieldContainers in the list will be deep copied.
            If a list element is a string, a new FieldContainer with the given name will be created.
        aliases
        withunits
        ureg
        parent: Optional[FieldContainer]
            parent container
        kwargs
        """
        if aliases is None:
            aliases = {}
        if fieldrecipes_kwargs is None:
            fieldrecipes_kwargs = {}
        self.aliases = aliases
        self.name = name
        self._fields: Dict[str, da.Array] = {}
        self._fields.update(*args, **kwargs)
        self._fieldrecipes = {}
        self._fieldlength = None
        self.fieldrecipes_kwargs = fieldrecipes_kwargs
        self.withunits = withunits
        self._ureg: Optional[pint.UnitRegistry] = ureg
        self._containers: Dict[
            str, FieldContainer
        ] = dict()  # other containers as subgroups
        if containers is not None:
            for k in containers:
                self.add_container(k, deep=True)
        self.internals = ["uid"]  # names of internal fields/groups
        self.parent = parent

    def set_ureg(self, ureg=None, discover=True):
        """
        Set the unit registry.

        Parameters
        ----------
        ureg: pint.UnitRegistry
            Unit registry.
        discover: bool
            Attempt to discover unit registry from fields.

        Returns
        -------

        """
        if ureg is None and not discover:
            raise ValueError("Need to specify ureg or set discover=True.")
        if ureg is None and discover:
            keys = self.keys(withgroups=False, withrecipes=False, withinternal=True)
            for k in keys:
                if hasattr(self[k], "units"):
                    if isinstance(self[k].units, pint.Unit):
                        ureg = self[k].units._REGISTRY
        self._ureg = ureg

    def get_ureg(self, discover=True):
        """
        Get the unit registry.

        Returns
        -------

        """
        if self._ureg is None and discover:
            self.set_ureg(discover=True)
        return self._ureg

    def copy_skeleton(self) -> FieldContainer:
        """
        Copy the skeleton of the container (i.e., only the containers, not the fields).

        Returns
        -------
        FieldContainer
        """
        res = FieldContainer()
        for k, cntr in self._containers.items():
            res[k] = cntr.copy_skeleton()
        return res

    def info(self, level=0, name: Optional[str] = None) -> str:
        """
        Return a string representation of the object.

        Parameters
        ----------
        level: int
            Level in case of nested containers.
        name:
            Name of the container.

        Returns
        -------
        str
        """
        rep = ""
        length = self.fieldlength
        count = self.fieldcount
        if name is None:
            name = self.name
        ncontainers = len(self._containers)
        statstrs = []
        if length is not None and length > 0:
            statstrs.append("fields: %i" % count)
            statstrs.append("entries: %i" % length)
        if ncontainers > 0:
            statstrs.append("containers: %i" % ncontainers)
        if len(statstrs) > 0:
            statstr = ", ".join(statstrs)
            rep += sprint((level + 1) * "+", name, "(%s)" % statstr)
        for k in sorted(self._containers.keys()):
            v = self._containers[k]
            rep += v.info(level=level + 1)
        return rep

    def merge(self, collection: FieldContainer, overwrite: bool = True):
        """
        Merge another FieldContainer into this one.

        Parameters
        ----------
        collection: FieldContainer
            Container to merge.
        overwrite: bool
            Overwrite existing fields if true.

        Returns
        -------

        """
        if not isinstance(collection, FieldContainer):
            raise TypeError("Can only merge FieldContainers.")
        # TODO: support nested containers
        for k in collection._containers:
            if k not in self._containers:
                continue
            if overwrite:
                c1 = self._containers[k]
                c2 = collection._containers[k]
            else:
                c1 = collection._containers[k]
                c2 = self._containers[k]
            c1._fields.update(**c2._fields)
            c1._fieldrecipes.update(**c2._fieldrecipes)

    @property
    def fieldcount(self):
        """
        Return the number of fields.

        Returns
        -------
        int
        """
        rcps = set(self._fieldrecipes)
        flds = set([k for k in self._fields if k not in self.internals])
        ntot = len(rcps | flds)
        return ntot

    @property
    def fieldlength(self):
        """
        Try to infer the number of entries for the fields in this container.
        If all fields have the same length, return this length. Otherwise, return None.

        Returns
        -------
        Optional[int]
        """
        if self._fieldlength is not None:
            return self._fieldlength
        fvals = self._fields.values()
        itr = iter(fvals)
        if len(fvals) == 0:
            # can we infer from recipes?
            if len(self._fieldrecipes) > 0:
                # get first recipe
                name = next(iter(self._fieldrecipes.keys()))
                first = self._getitem(name, evaluate_recipe=True)
            else:
                return None
        else:
            first = next(itr)
        if all(first.shape[0] == v.shape[0] for v in self._fields.values()):
            self._fieldlength = first.shape[0]
            return self._fieldlength
        else:
            return None

    def keys(
        self,
        withgroups: bool = True,
        withrecipes: bool = True,
        withinternal: bool = False,
        withfields: bool = True,
    ):
        """
        Return a list of keys in the container.

        Parameters
        ----------
        withgroups: bool
            Include sub-containers.
        withrecipes: bool
            Include recipes (i.e. not yet instantiated fields).
        withinternal: bool
            Include internal fields.
        withfields: bool
            Include fields.

        Returns
        -------

        """
        fieldkeys = []
        recipekeys = []
        if withfields:
            fieldkeys = list(self._fields.keys())
            if not withinternal:
                for ikey in self.internals:
                    if ikey in fieldkeys:
                        fieldkeys.remove(ikey)
        if withrecipes:
            recipekeys = self._fieldrecipes.keys()
        fieldkeys = list(set(fieldkeys) | set(recipekeys))
        if withgroups:
            groupkeys = self._containers.keys()
            fieldkeys = list(set(fieldkeys) | set(groupkeys))
        return sorted(fieldkeys)

    def items(self, withrecipes=True, withfields=True, evaluate=True):
        """
        Return a list of tuples for keys/values in the container.

        Parameters
        ----------
        withrecipes: bool
            Whether to include recipes.
        withfields: bool
            Whether to include fields.
        evaluate: bool
            Whether to evaluate recipes.

        Returns
        -------
        list

        """
        return (
            (k, self._getitem(k, evaluate_recipe=evaluate))
            for k in self.keys(withrecipes=withrecipes, withfields=withfields)
        )

    def values(self, evaluate=True):
        """
        Return fields/recipes of the container.

        Parameters
        ----------
        evaluate: bool
            Whether to evaluate recipes.

        Returns
        -------
        list

        """
        return (self._getitem(k, evaluate_recipe=evaluate) for k in self.keys())

    def register_field(
        self,
        containernames=None,
        name: Optional[str] = None,
        description="",
        units=None,
    ):
        """
        Decorator to register a field recipe.

        Parameters
        ----------
        containernames: Optional[Union[str, List[str]]]
            Name of the sub-container(s) to register to, or "all" for all, or None for self.
        name: Optional[str]
            Name of the field. If None, the function name is used.
        description: str
            Description of the field.
        units: Optional[Union[pint.Unit, str]]
            Units of the field.

        Returns
        -------
        callable

        """
        # we only construct field upon first call to it (default)
        # if to_containers, we register to the respective children containers
        containers = []
        if isinstance(containernames, list):
            containers = [self._containers[c] for c in containernames]
        elif containernames == "all":
            containers = self._containers.values()
        elif containernames is None:
            containers = [self]
        elif isinstance(containernames, str):  # just a single container as a string?
            containers.append(self._containers[containernames])
        else:
            raise ValueError("Unknown type.")

        def decorator(func, name=name, description=description, units=units):
            """
            Decorator to register a field recipe.
            """
            if name is None:
                name = func.__name__
            for container in containers:
                drvfields = container._fieldrecipes
                drvfields[name] = DerivedFieldRecipe(
                    name, func, description=description, units=units
                )
            return func

        return decorator

    def __setitem__(self, key, value):
        if key in self.aliases:
            key = self.aliases[key]
        if isinstance(value, FieldContainer):
            self._containers[key] = value
        elif isinstance(value, DerivedFieldRecipe):
            self._fieldrecipes[key] = value
        else:
            self._fields[key] = value

    def __getitem__(self, key):
        return self._getitem(key)

    def __iter__(self):
        return iter(self.keys())

    def __repr__(self) -> str:
        """
        Return a string representation of the object.
        Returns
        -------
        str
        """
        txt = ""
        txt += "FieldContainer[containers=%s, fields=%s]" % (
            len(self._containers),
            self.fieldcount,
        )
        return txt

    @property
    def dataframe(self):
        """
        Return a dask dataframe of the fields in this container.

        Returns
        -------
        dd.DataFrame

        """
        return self.get_dataframe()

    def get_dataframe(self, fields=None):
        """
        Return a dask dataframe of the fields in this container.

        Parameters
        ----------
        fields: Optional[List[str]]
            List of fields to include. If None, include all.

        Returns
        -------
        dd.DataFrame
        """
        dss = {}
        if fields is None:
            fields = self.keys()
        for k in fields:
            idim = None
            if k not in self.keys():
                # could still be an index into a 2D dataset
                i = -1
                while k[i:].isnumeric():
                    i += -1
                i += 1
                if i == 0:
                    raise ValueError("Field '%s' not found" % k)
                idim = int(k[i:])
                k = k.split(k[i:])[0]
            v = self[k]
            assert v.ndim <= 2  # cannot support more than 2 here...
            if idim is not None:
                if v.ndim <= 1:
                    raise ValueError("No second dimensional index for %s" % k)
                if idim >= v.shape[1]:
                    raise ValueError(
                        "Second dimensional index %i not defined for %s" % (idim, k)
                    )

            if v.ndim > 1:
                for i in range(v.shape[1]):
                    if idim is None or idim == i:
                        dss[k + str(i)] = v[:, i]
            else:
                dss[k] = v
        dfs = []
        for k, v in dss.items():
            if isinstance(v, pint.Quantity):
                # pint quantities not supported yet in dd, so remove for now
                v = v.magnitude
            dfs.append(dd.from_dask_array(v, columns=[k]))
        ddf = dd.concat(dfs, axis=1)
        return ddf

    def add_alias(self, alias, name):
        """
        Add an alias for a field.

        Parameters
        ----------
        alias: str
            Alias name
        name: str
            Field name

        Returns
        -------
        None

        """
        self.aliases[alias] = name

    def remove_container(self, key):
        """
        Remove a sub-container.

        Parameters
        ----------
        key: str
            Name of the sub-container.

        Returns
        -------
        None
        """
        if key in self._containers:
            del self._containers[key]
        else:
            raise KeyError("Unknown container '%s'" % key)

    def add_container(self, key, deep=False, **kwargs):
        """
        Add a sub-container.

        Parameters
        ----------
        key: str, FieldContainer
        deep: bool
            If True, make a deep copy of the container.
        kwargs: dict
            keyword arguments for the FieldContainer constructor.

        Returns
        -------
        None
        """
        if isinstance(key, str):
            # create a new container with given name
            tkwargs = dict(**kwargs)
            if "name" not in tkwargs:
                tkwargs["name"] = key
            self._containers[key] = FieldContainer(
                fieldrecipes_kwargs=self.fieldrecipes_kwargs,
                withunits=self.withunits,
                ureg=self.get_ureg(),
                parent=self,
                **tkwargs,
            )
        elif isinstance(key, FieldContainer):
            # now we do a shallow or deep copy
            name = kwargs.pop("name", key.name)
            if deep:
                self._containers[name] = key.copy()
            else:
                self._containers[name] = key
        else:
            raise ValueError("Unknown type.")

    def copy(self):
        """
        Perform a copy of the FieldContainer. Sub-containers are copied recursively;
        the field dicts are copied, but the underlying arrays are shared.

        Returns
        -------
        FieldContainer
        """
        instance = self.__class__()
        instance._fields = self._fields.copy()
        instance._fieldrecipes = self._fieldrecipes.copy()
        instance.aliases = self.aliases.copy()
        instance.fieldrecipes_kwargs = self.fieldrecipes_kwargs.copy()
        instance.withunits = self.withunits
        instance._ureg = self._ureg
        instance.internals = self.internals.copy()
        instance.parent = self.parent
        for k, v in self._containers.items():
            instance.add_container(v.copy(), deep=True, name=k)

        return instance

    def _getitem(
        self, key, force_derived=False, update_dict=True, evaluate_recipe=True
    ):
        """
        Get an item from the container.

        Parameters
        ----------
        key: str
        force_derived: bool
            Use the derived field description over instantiated fields.
        update_dict: bool
            Update the dictionary of instantiated fields.
        evaluate_recipe: bool
            Evaluate the recipe.

        Returns
        -------
        da.Array

        """
        if key in self.aliases:
            key = self.aliases[key]
        if key in self._containers:
            return self._containers[key]
        if key in self._fields and not force_derived:
            return self._fields[key]
        else:
            if key in self._fieldrecipes:
                if not evaluate_recipe:
                    return self._fieldrecipes[key]
                field = self._instantiate_field(key)
                if update_dict:
                    self._fields[key] = field
                return field
            else:
                raise KeyError("Unknown field '%s'" % key)

    def _instantiate_field(self, key):
        """
        Instantiate a field from a recipe, i.e. create its dask array.

        Parameters
        ----------
        key: str
            Name of the field.

        Returns
        -------
        da.Array
        """
        func = self._fieldrecipes[key].func
        units = self._fieldrecipes[key].units
        accept_kwargs = inspect.getfullargspec(func).varkw is not None
        func_kwargs = get_kwargs(func)
        dkwargs = self.fieldrecipes_kwargs
        ureg = None
        if "ureg" not in dkwargs:
            ureg = self.get_ureg()
            dkwargs["ureg"] = ureg
        # first, we overwrite all optional arguments with class instance defaults where func kwarg is None
        kwargs = {
            k: dkwargs[k]
            for k in (
                set(dkwargs) & set([k for k, v in func_kwargs.items() if v is None])
            )
        }
        # next, we add all optional arguments if func is accepting **kwargs and varname not yet in signature
        if accept_kwargs:
            kwargs.update(
                **{
                    k: v
                    for k, v in dkwargs.items()
                    if k not in inspect.getfullargspec(func).args
                }
            )
        # finally, instantiate field
        field = func(self, **kwargs)
        if self.withunits and units is not None:
            if not hasattr(field, "units"):
                field = field * units
            else:
                has_reg1 = hasattr(field.units, "_REGISTRY")
                has_reg2 = hasattr(units, "_REGISTRY")
                has_regs = has_reg1 and has_reg2
                if has_regs:
                    if field.units._REGISTRY == units._REGISTRY:
                        if field.units != units:
                            # if unit is present, but unit from metadata is unknown,
                            # we stick with the former
                            if not (
                                hasattr(units, "units")
                                and str(units.units) == "unknown"
                            ):
                                try:
                                    field = field.to(units)
                                except pint.errors.DimensionalityError as e:
                                    print(e)
                                    raise ValueError(
                                        "Field '%s' units '%s' do not match '%s'"
                                        % (key, field.units, units)
                                    )
                    else:
                        # this should not happen. TODO: figure out when this happens
                        logging.warning(
                            "Unit registries of field '%s' do not match. container registry."
                            % key
                        )
        return field

    def __delitem__(self, key):
        if key in self._fieldrecipes:
            del self._fieldrecipes[key]
        if key in self._containers:
            del self._containers[key]
        elif key in self._fields:
            del self._fields[key]
        else:
            raise KeyError("Unknown key '%s'" % key)

    def __len__(self):
        return len(self.keys())

    def get(self, key, value=None, allow_derived=True, force_derived=False):
        """
        Get a field.

        Parameters
        ----------
        key: str
        value: da.Array
        allow_derived: bool
            Allow derived fields.
        force_derived: bool
            Use the derived field description over instantiated fields.

        Returns
        -------
        da.Array
        """
        if key in self._fieldrecipes and not allow_derived:
            raise KeyError("Field '%s' is derived (allow_derived=False)" % key)
        else:
            try:
                return self._getitem(
                    key, force_derived=force_derived, update_dict=False
                )
            except KeyError:
                return value

dataframe property

Return a dask dataframe of the fields in this container.

Returns:

Type Description
DataFrame
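
Continuing the container built above: 2D fields are split into one column per dimension (e.g. "Coordinates0", "Coordinates1", ...), and pint quantities are reduced to their magnitudes. A short sketch using the underlying get_dataframe():

df = fc.get_dataframe(fields=["Masses", "Coordinates0"])
print(df.columns)   # one column per requested field; evaluation stays lazy
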

fieldcount property

Return the number of fields.

Returns:

Type Description
int

fieldlength property

Try to infer the number of entries for the fields in this container. If all fields have the same length, return this length. Otherwise, return None.

Returns:

Type Description
Optional[int]

__init__(*args, fieldrecipes_kwargs=None, containers=None, aliases=None, withunits=False, ureg=None, parent=None, name=None, **kwargs)

Construct a FieldContainer.

Parameters:

Name Type Description Default
args
()
fieldrecipes_kwargs

default kwargs used for field recipes

None
containers

list of containers to add. FieldContainers in the list will be deep copied. If a list element is a string, a new FieldContainer with the given name will be created.

None
aliases
None
withunits
False
ureg
None
parent Optional[FieldContainer]

parent container

None
kwargs
{}
Source code in src/scida/fields.py
def __init__(
    self,
    *args,
    fieldrecipes_kwargs=None,
    containers=None,
    aliases=None,
    withunits=False,
    ureg=None,
    parent: Optional[FieldContainer] = None,
    name: Optional[str] = None,
    **kwargs,
):
    """
    Construct a FieldContainer.

    Parameters
    ----------
    args
    fieldrecipes_kwargs: dict
        default kwargs used for field recipes
    containers: List[FieldContainer, str]
        list of containers to add. FieldContainers in the list will be deep copied.
        If a list element is a string, a new FieldContainer with the given name will be created.
    aliases
    withunits
    ureg
    parent: Optional[FieldContainer]
        parent container
    kwargs
    """
    if aliases is None:
        aliases = {}
    if fieldrecipes_kwargs is None:
        fieldrecipes_kwargs = {}
    self.aliases = aliases
    self.name = name
    self._fields: Dict[str, da.Array] = {}
    self._fields.update(*args, **kwargs)
    self._fieldrecipes = {}
    self._fieldlength = None
    self.fieldrecipes_kwargs = fieldrecipes_kwargs
    self.withunits = withunits
    self._ureg: Optional[pint.UnitRegistry] = ureg
    self._containers: Dict[
        str, FieldContainer
    ] = dict()  # other containers as subgroups
    if containers is not None:
        for k in containers:
            self.add_container(k, deep=True)
    self.internals = ["uid"]  # names of internal fields/groups
    self.parent = parent

__repr__()

Return a string representation of the object.

Returns:

Type Description
str
Source code in src/scida/fields.py
def __repr__(self) -> str:
    """
    Return a string representation of the object.
    Returns
    -------
    str
    """
    txt = ""
    txt += "FieldContainer[containers=%s, fields=%s]" % (
        len(self._containers),
        self.fieldcount,
    )
    return txt

add_alias(alias, name)

Add an alias for a field.

Parameters:

Name Type Description Default
alias

Alias name

required
name

Field name

required

Returns:

Type Description
None
Source code in src/scida/fields.py
def add_alias(self, alias, name):
    """
    Add an alias for a field.

    Parameters
    ----------
    alias: str
        Alias name
    name: str
        Field name

    Returns
    -------
    None

    """
    self.aliases[alias] = name
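
For illustration, a minimal usage sketch. The import path, field name and alias below are assumptions for demonstration; resolving the alias on item access is implied by the aliases mapping but not shown on this page.

import dask.array as da
from scida.fields import FieldContainer  # import path assumed from the source location above

fc = FieldContainer()
fc["Coordinates"] = da.zeros((10, 3))  # hypothetical field
fc.add_alias("pos", "Coordinates")
# "pos" is now expected to resolve to the same array as "Coordinates" on item access.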

add_container(key, deep=False, **kwargs)

Add a sub-container.

Parameters:

Name Type Description Default
key
required
deep

If True, make a deep copy of the container.

False
kwargs

keyword arguments for the FieldContainer constructor.

{}

Returns:

Type Description
None
Source code in src/scida/fields.py
def add_container(self, key, deep=False, **kwargs):
    """
    Add a sub-container.

    Parameters
    ----------
    key: str, FieldContainer
    deep: bool
        If True, make a deep copy of the container.
    kwargs: dict
        keyword arguments for the FieldContainer constructor.

    Returns
    -------
    None
    """
    if isinstance(key, str):
        # create a new container with given name
        tkwargs = dict(**kwargs)
        if "name" not in tkwargs:
            tkwargs["name"] = key
        self._containers[key] = FieldContainer(
            fieldrecipes_kwargs=self.fieldrecipes_kwargs,
            withunits=self.withunits,
            ureg=self.get_ureg(),
            parent=self,
            **tkwargs,
        )
    elif isinstance(key, FieldContainer):
        # now we do a shallow or deep copy
        name = kwargs.pop("name", key.name)
        if deep:
            self._containers[name] = key.copy()
        else:
            self._containers[name] = key
    else:
        raise ValueError("Unknown type.")

copy()

Perform a deep (?) copy of the FieldContainer.

Returns:

Type Description
FieldContainer
Source code in src/scida/fields.py
def copy(self):
    """
    Perform a deep (?) copy of the FieldContainer.

    Returns
    -------
    FieldContainer
    """
    instance = self.__class__()
    instance._fields = self._fields.copy()
    instance._fieldrecipes = self._fieldrecipes.copy()
    instance.aliases = self.aliases.copy()
    instance.fieldrecipes_kwargs = self.fieldrecipes_kwargs.copy()
    instance.withunits = self.withunits
    instance._ureg = self._ureg
    instance.internals = self.internals.copy()
    instance.parent = self.parent
    for k, v in self._containers.items():
        instance.add_container(v.copy(), deep=True, name=k)

    return instance

copy_skeleton()

Copy the skeleton of the container (i.e., only the containers, not the fields).

Returns:

Type Description
FieldContainer
Source code in src/scida/fields.py
def copy_skeleton(self) -> FieldContainer:
    """
    Copy the skeleton of the container (i.e., only the containers, not the fields).

    Returns
    -------
    FieldContainer
    """
    res = FieldContainer()
    for k, cntr in self._containers.items():
        res[k] = cntr.copy_skeleton()
    return res

get(key, value=None, allow_derived=True, force_derived=False)

Get a field.

Parameters:

Name Type Description Default
key
required
value
None
allow_derived

Allow derived fields.

True
force_derived

Use the derived field description over instantiated fields.

False

Returns:

Type Description
Array
Source code in src/scida/fields.py
def get(self, key, value=None, allow_derived=True, force_derived=False):
    """
    Get a field.

    Parameters
    ----------
    key: str
    value: da.Array
    allow_derived: bool
        Allow derived fields.
    force_derived: bool
        Use the derived field description over instantiated fields.

    Returns
    -------
    da.Array
    """
    if key in self._fieldrecipes and not allow_derived:
        raise KeyError("Field '%s' is derived (allow_derived=False)" % key)
    else:
        try:
            return self._getitem(
                key, force_derived=force_derived, update_dict=False
            )
        except KeyError:
            return value
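
A short usage sketch mirroring dict.get() semantics; the field names are hypothetical. With allow_derived=False, a key that refers to a derived-field recipe raises KeyError instead of being evaluated.

import dask.array as da
from scida.fields import FieldContainer  # import path assumed

fc = FieldContainer()
fc["Masses"] = da.ones(10)              # hypothetical field
masses = fc.get("Masses")               # evaluates the field (or its recipe) if present
missing = fc.get("Density", value=0.0)  # absent key: returns the fallback, mirroring dict.get()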

get_dataframe(fields=None)

Return a dask dataframe of the fields in this container.

Parameters:

Name Type Description Default
fields

List of fields to include. If None, include all.

None

Returns:

Type Description
DataFrame
Source code in src/scida/fields.py
def get_dataframe(self, fields=None):
    """
    Return a dask dataframe of the fields in this container.

    Parameters
    ----------
    fields: Optional[List[str]]
        List of fields to include. If None, include all.

    Returns
    -------
    dd.DataFrame
    """
    dss = {}
    if fields is None:
        fields = self.keys()
    for k in fields:
        idim = None
        if k not in self.keys():
            # could still be an index into a 2D dataset
            i = -1
            while k[i:].isnumeric():
                i += -1
            i += 1
            if i == 0:
                raise ValueError("Field '%s' not found" % k)
            idim = int(k[i:])
            k = k.split(k[i:])[0]
        v = self[k]
        assert v.ndim <= 2  # cannot support more than 2 here...
        if idim is not None:
            if v.ndim <= 1:
                raise ValueError("No second dimensional index for %s" % k)
            if idim >= v.shape[1]:
                raise ValueError(
                    "Second dimensional index %i not defined for %s" % (idim, k)
                )

        if v.ndim > 1:
            for i in range(v.shape[1]):
                if idim is None or idim == i:
                    dss[k + str(i)] = v[:, i]
        else:
            dss[k] = v
    dfs = []
    for k, v in dss.items():
        if isinstance(v, pint.Quantity):
            # pint quantities not supported yet in dd, so remove for now
            v = v.magnitude
        dfs.append(dd.from_dask_array(v, columns=[k]))
    ddf = dd.concat(dfs, axis=1)
    return ddf
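
A hedged sketch of the column convention implemented above: multi-dimensional fields are split into one column per component, and a trailing integer in a requested name selects a single component. The field names are hypothetical.

import dask.array as da
from scida.fields import FieldContainer  # import path assumed

fc = FieldContainer()
fc["Coordinates"] = da.zeros((10, 3))
fc["Masses"] = da.ones(10)
df = fc.get_dataframe()                           # columns: Coordinates0, Coordinates1, Coordinates2, Masses
df_y = fc.get_dataframe(fields=["Coordinates1"])  # only the second component of "Coordinates"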

get_ureg(discover=True)

Get the unit registry.

Source code in src/scida/fields.py
def get_ureg(self, discover=True):
    """
    Get the unit registry.

    Returns
    -------

    """
    if self._ureg is None and discover:
        self.set_ureg(discover=True)
    return self._ureg

info(level=0, name=None)

Return a string representation of the object.

Parameters:

Name Type Description Default
level

Level in case of nested containers.

0
name Optional[str]

Name of the container.

None

Returns:

Type Description
str
Source code in src/scida/fields.py
def info(self, level=0, name: Optional[str] = None) -> str:
    """
    Return a string representation of the object.

    Parameters
    ----------
    level: int
        Level in case of nested containers.
    name:
        Name of the container.

    Returns
    -------
    str
    """
    rep = ""
    length = self.fieldlength
    count = self.fieldcount
    if name is None:
        name = self.name
    ncontainers = len(self._containers)
    statstrs = []
    if length is not None and length > 0:
        statstrs.append("fields: %i" % count)
        statstrs.append("entries: %i" % length)
    if ncontainers > 0:
        statstrs.append("containers: %i" % ncontainers)
    if len(statstrs) > 0:
        statstr = ", ".join(statstrs)
        rep += sprint((level + 1) * "+", name, "(%s)" % statstr)
    for k in sorted(self._containers.keys()):
        v = self._containers[k]
        rep += v.info(level=level + 1)
    return rep

items(withrecipes=True, withfields=True, evaluate=True)

Return a list of tuples for keys/values in the container.

Parameters:

Name Type Description Default
withrecipes

Whether to include recipes.

True
withfields

Whether to include fields.

True
evaluate

Whether to evaluate recipes.

True

Returns:

Type Description
list
Source code in src/scida/fields.py
def items(self, withrecipes=True, withfields=True, evaluate=True):
    """
    Return a list of tuples for keys/values in the container.

    Parameters
    ----------
    withrecipes: bool
        Whether to include recipes.
    withfields: bool
        Whether to include fields.
    evaluate: bool
        Whether to evaluate recipes.

    Returns
    -------
    list

    """
    return (
        (k, self._getitem(k, evaluate_recipe=evaluate))
        for k in self.keys(withrecipes=withrecipes, withfields=withfields)
    )

keys(withgroups=True, withrecipes=True, withinternal=False, withfields=True)

Return a list of keys in the container.

Parameters:

Name Type Description Default
withgroups bool

Include sub-containers.

True
withrecipes bool

Include recipes (i.e. not yet instantiated fields).

True
withinternal bool

Include internal fields.

False
withfields bool

Include fields.

True
Source code in src/scida/fields.py
def keys(
    self,
    withgroups: bool = True,
    withrecipes: bool = True,
    withinternal: bool = False,
    withfields: bool = True,
):
    """
    Return a list of keys in the container.

    Parameters
    ----------
    withgroups: bool
        Include sub-containers.
    withrecipes: bool
        Include recipes (i.e. not yet instantiated fields).
    withinternal: bool
        Include internal fields.
    withfields: bool
        Include fields.

    Returns
    -------

    """
    fieldkeys = []
    recipekeys = []
    if withfields:
        fieldkeys = list(self._fields.keys())
        if not withinternal:
            for ikey in self.internals:
                if ikey in fieldkeys:
                    fieldkeys.remove(ikey)
    if withrecipes:
        recipekeys = self._fieldrecipes.keys()
    fieldkeys = list(set(fieldkeys) | set(recipekeys))
    if withgroups:
        groupkeys = self._containers.keys()
        fieldkeys = list(set(fieldkeys) | set(groupkeys))
    return sorted(fieldkeys)
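
The flags compose as in this short sketch (the container name is hypothetical):

from scida.fields import FieldContainer  # import path assumed

fc = FieldContainer(containers=["PartType0"])
fc.keys()                   # fields, recipes and sub-containers (sorted); internals like "uid" excluded
fc.keys(withgroups=False)   # omit sub-containers
fc.keys(withrecipes=False)  # only fields that have already been instantiated
fc.keys(withinternal=True)  # additionally list internal fields such as "uid"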

merge(collection, overwrite=True)

Merge another FieldContainer into this one.

Parameters:

Name Type Description Default
collection FieldContainer

Container to merge.

required
overwrite bool

Overwrite existing fields if true.

True
Source code in src/scida/fields.py
def merge(self, collection: FieldContainer, overwrite: bool = True):
    """
    Merge another FieldContainer into this one.

    Parameters
    ----------
    collection: FieldContainer
        Container to merge.
    overwrite: bool
        Overwrite existing fields if true.

    Returns
    -------

    """
    if not isinstance(collection, FieldContainer):
        raise TypeError("Can only merge FieldContainers.")
    # TODO: support nested containers
    for k in collection._containers:
        if k not in self._containers:
            continue
        if overwrite:
            c1 = self._containers[k]
            c2 = collection._containers[k]
        else:
            c1 = collection._containers[k]
            c2 = self._containers[k]
        c1._fields.update(**c2._fields)
        c1._fieldrecipes.update(**c2._fieldrecipes)
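
A hedged sketch: only sub-containers present in both objects are merged, and overwrite decides which side wins on clashing keys. The names are hypothetical, and item access returning the sub-container is assumed.

import dask.array as da
from scida.fields import FieldContainer  # import path assumed

a = FieldContainer(containers=["PartType0"])
b = FieldContainer(containers=["PartType0"])
b["PartType0"]["Masses"] = da.ones(10)  # assumes item access returns the sub-container
a.merge(b)
"Masses" in a["PartType0"].keys()       # True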

register_field(containernames=None, name=None, description='', units=None)

Decorator to register a field recipe.

Parameters:

Name Type Description Default
containernames

Name of the sub-container(s) to register to, or "all" for all, or None for self.

None
name Optional[str]

Name of the field. If None, the function name is used.

None
description

Description of the field.

''
units

Units of the field.

None

Returns:

Type Description
callable
Source code in src/scida/fields.py
def register_field(
    self,
    containernames=None,
    name: Optional[str] = None,
    description="",
    units=None,
):
    """
    Decorator to register a field recipe.

    Parameters
    ----------
    containernames: Optional[Union[str, List[str]]]
        Name of the sub-container(s) to register to, or "all" for all, or None for self.
    name: Optional[str]
        Name of the field. If None, the function name is used.
    description: str
        Description of the field.
    units: Optional[Union[pint.Unit, str]]
        Units of the field.

    Returns
    -------
    callable

    """
    # we only construct field upon first call to it (default)
    # if to_containers, we register to the respective children containers
    containers = []
    if isinstance(containernames, list):
        containers = [self._containers[c] for c in containernames]
    elif containernames == "all":
        containers = self._containers.values()
    elif containernames is None:
        containers = [self]
    elif isinstance(containernames, str):  # just a single container as a string?
        containers.append(self._containers[containernames])
    else:
        raise ValueError("Unknown type.")

    def decorator(func, name=name, description=description, units=units):
        """
        Decorator to register a field recipe.
        """
        if name is None:
            name = func.__name__
        for container in containers:
            drvfields = container._fieldrecipes
            drvfields[name] = DerivedFieldRecipe(
                name, func, description=description, units=units
            )
        return func

    return decorator
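
A hedged usage sketch of the decorator. The call signature of the recipe function (here assumed to receive the target container plus the fieldrecipes_kwargs as keyword arguments) and the field names are assumptions for illustration, not taken from this page.

import dask.array as da
from scida.fields import FieldContainer  # import path assumed

fc = FieldContainer(containers=["PartType0"])
fc["PartType0"]["Masses"] = da.ones(10)  # assumes item access returns the sub-container

@fc.register_field("PartType0", description="twice the particle mass")
def DoubleMass(container, **kwargs):
    # assumed signature: the recipe receives the container it was registered to
    return 2.0 * container["Masses"]

fc["PartType0"]["DoubleMass"]  # the recipe is only evaluated on first access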

remove_container(key)

Remove a sub-container.

Parameters:

Name Type Description Default
key

Name of the sub-container.

required

Returns:

Type Description
None
Source code in src/scida/fields.py
def remove_container(self, key):
    """
    Remove a sub-container.

    Parameters
    ----------
    key: str
        Name of the sub-container.

    Returns
    -------
    None
    """
    if key in self._containers:
        del self._containers[key]
    else:
        raise KeyError("Unknown container '%s'" % key)

set_ureg(ureg=None, discover=True)

Set the unit registry.

Parameters:

Name Type Description Default
ureg

Unit registry.

None
discover

Attempt to discover unit registry from fields.

True
Source code in src/scida/fields.py
def set_ureg(self, ureg=None, discover=True):
    """
    Set the unit registry.

    Parameters
    ----------
    ureg: pint.UnitRegistry
        Unit registry.
    discover: bool
        Attempt to discover unit registry from fields.

    Returns
    -------

    """
    if ureg is None and not discover:
        raise ValueError("Need to specify ureg or set discover=True.")
    if ureg is None and discover:
        keys = self.keys(withgroups=False, withrecipes=False, withinternal=True)
        for k in keys:
            if hasattr(self[k], "units"):
                if isinstance(self[k].units, pint.Unit):
                    ureg = self[k].units._REGISTRY
    self._ureg = ureg

values(evaluate=True)

Return fields/recipes of the container.

Parameters:

Name Type Description Default
evaluate

Whether to evaluate recipes.

True

Returns:

Type Description
list
Source code in src/scida/fields.py
def values(self, evaluate=True):
    """
    Return fields/recipes of the container.

    Parameters
    ----------
    evaluate: bool
        Whether to evaluate recipes.

    Returns
    -------
    list

    """
    return (self._getitem(k, evaluate_recipe=evaluate) for k in self.keys())

FieldRecipe

Bases: object

Recipe for a field.

Source code in src/scida/fields.py
class FieldRecipe(object):
    """
    Recipe for a field.
    """

    def __init__(
        self, name, func=None, arr=None, description="", units=None, ftype=FieldType.IO
    ):
        """
        Recipes for a field. Either specify a function or an array.

        Parameters
        ----------
        name: str
            Name of the field.
        func: Optional[callable]
            Function to construct array of the field.
        arr: Optional[da.Array]
            Array to construct the field.
        description: str
            Description of the field.
        units: Optional[Union[pint.Unit, str]]
            Units of the field.
        ftype: FieldType
            Type of the field.
        """
        if func is None and arr is None:
            raise ValueError("Need to specify either func or arr.")
        self.type = ftype
        self.name = name
        self.description = description
        self.units = units
        self.func = func
        self.arr = arr

__init__(name, func=None, arr=None, description='', units=None, ftype=FieldType.IO)

Recipes for a field. Either specify a function or an array.

Parameters:

Name Type Description Default
name

Name of the field.

required
func

Function to construct array of the field.

None
arr

Array to construct the field.

None
description

Description of the field.

''
units

Units of the field.

None
ftype

Type of the field.

IO
Source code in src/scida/fields.py
def __init__(
    self, name, func=None, arr=None, description="", units=None, ftype=FieldType.IO
):
    """
    Recipes for a field. Either specify a function or an array.

    Parameters
    ----------
    name: str
        Name of the field.
    func: Optional[callable]
        Function to construct array of the field.
    arr: Optional[da.Array]
        Array to construct the field.
    description: str
        Description of the field.
    units: Optional[Union[pint.Unit, str]]
        Units of the field.
    ftype: FieldType
        Type of the field.
    """
    if func is None and arr is None:
        raise ValueError("Need to specify either func or arr.")
    self.type = ftype
    self.name = name
    self.description = description
    self.units = units
    self.func = func
    self.arr = arr

FieldType

Bases: Enum

Enum for field types.

Source code in src/scida/fields.py
class FieldType(Enum):
    """
    Enum for field types.
    """

    INTERNAL = 1  # for internal use only
    IO = 2  # from disk
    DERIVED = 3  # derived from other fields

walk_container(cntr, path='', handler_field=None, handler_group=None, withrecipes=False)

Recursively walk a container and call handlers on fields and groups.

Parameters:

Name Type Description Default
cntr

Container to walk.

required
path

relative path in hierarchy to this container

''
handler_field

Function to call on fields.

None
handler_group

Function to call on subcontainers.

None
withrecipes

Include recipes.

False

Returns:

Type Description
None
Source code in src/scida/fields.py
def walk_container(
    cntr, path="", handler_field=None, handler_group=None, withrecipes=False
):
    """
    Recursively walk a container and call handlers on fields and groups.

    Parameters
    ----------
    cntr: FieldContainer
        Container to walk.
    path: str
        relative path in hierarchy to this container
    handler_field: callable
        Function to call on fields.
    handler_group: callable
        Function to call on subcontainers.
    withrecipes: bool
        Include recipes.

    Returns
    -------
    None
    """
    keykwargs = dict(withgroups=True, withrecipes=withrecipes)
    for ck in cntr.keys(**keykwargs):
        # we do not want to instantiate entry from recipe by calling cntr[ck] here
        entry = cntr[ck]
        newpath = path + "/" + ck
        if isinstance(entry, FieldContainer):
            if handler_group is not None:
                handler_group(entry, newpath)
            walk_container(
                entry,
                newpath,
                handler_field,
                handler_group,
                withrecipes=withrecipes,
            )
        else:
            if handler_field is not None:
                handler_field(entry, newpath, parent=cntr)
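
A short sketch collecting the hierarchy paths of all fields; the handler signatures follow the calls in the function body above, while the container and field names are hypothetical.

import dask.array as da
from scida.fields import FieldContainer, walk_container  # import path assumed

fc = FieldContainer(containers=["PartType0"])
fc["PartType0"]["Masses"] = da.ones(10)  # assumes item access returns the sub-container

paths = []

def collect(entry, path, parent=None):
    paths.append(path)

walk_container(fc, handler_field=collect)
# paths now holds entries such as "/PartType0/Masses"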

helpers_hdf5

Helper functions for hdf5 and zarr file processing.

create_mergedhdf5file(fn, files, max_workers=None, virtual=True, groupwise_shape=False)

Creates a virtual hdf5 file from a list of given files. Virtual by default.

Parameters:

Name Type Description Default
fn

file to write to

required
files

files to merge

required
max_workers

parallel workers to process files

None
virtual

whether to create linked ("virtual") dataset on disk (otherwise copy)

True
groupwise_shape

whether to require shapes to be the same within a group

False

Returns:

Type Description
None
Source code in src/scida/helpers_hdf5.py
def create_mergedhdf5file(
    fn, files, max_workers=None, virtual=True, groupwise_shape=False
):
    """
    Creates a virtual hdf5 file from a list of given files. Virtual by default.

    Parameters
    ----------
    fn: str
        file to write to
    files: list
        files to merge
    max_workers: int
        parallel workers to process files
    virtual: bool
        whether to create linked ("virtual") dataset on disk (otherwise copy)
    groupwise_shape: bool
        whether to require shapes to be the same within a group

    Returns
    -------
    None
    """
    if max_workers is None:
        # read from config
        config = get_config()
        max_workers = config.get("nthreads", 16)
    # first obtain all datasets and groups
    trees = [{} for i in range(len(files))]

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        result = executor.map(walk_hdf5file, files, trees)
    result = list(result)

    groups = set([item for r in result for item in r["groups"]])
    datasets = set([item[0] for r in result for item in r["datasets"]])

    def todct(lst):
        """helper func"""
        return {item[0]: (item[1], item[2]) for item in lst["datasets"]}

    dcts = [todct(lst) for lst in result]
    shapes = OrderedDict((d, {}) for d in datasets)

    def shps(i, k, s):
        """helper func"""
        return shapes[k].update({i: s[0]}) if s is not None else None

    [shps(i, k, dct.get(k)) for k in datasets for i, dct in enumerate(dcts)]
    dtypes = {}

    def dtps(k, s):
        """helper func"""
        return dtypes.update({k: s[1]}) if s is not None else None

    [dtps(k, dct.get(k)) for k in datasets for i, dct in enumerate(dcts)]

    # get shapes of the respective chunks
    chunks = {}
    for field in sorted(shapes.keys()):
        chunks[field] = [[k, shapes[field][k][0]] for k in shapes[field]]
    groupchunks = {}

    # assert that all datasets in a given group have the same chunks.
    for group in sorted(groups):
        if group == "/":
            group = ""  # this is needed to have consistent levels for 0th level
        groupfields = [f for f in shapes.keys() if f.startswith(group)]
        groupfields = [f for f in groupfields if f.count("/") - 1 == group.count("/")]
        groupfields = sorted(groupfields)
        if len(groupfields) == 0:
            continue
        arr0 = chunks[groupfields[0]]
        for field in groupfields[1:]:
            arr = np.array(chunks[field])
            if groupwise_shape and not np.array_equal(arr0, arr):
                raise ValueError("Requiring same shape (see 'groupwise_shape' flag)")
            # then save the chunking information for this group
            groupchunks[field] = arr0

    # next fill merger file
    with h5py.File(fn, "w", libver="latest") as hf:
        # create groups
        for group in sorted(groups):
            if group == "/":
                continue  # nothing to do.
            hf.create_group(group)
            groupfields = [
                field
                for field in shapes.keys()
                if field.startswith(group) and field.count("/") - 1 == group.count("/")
            ]
            if len(groupfields) == 0:
                continue

            # fill fields
            if virtual:
                # for virtual datasets, iterate over all fields and concat each file to virtual dataset
                for field in groupfields:
                    totentries = np.array([k[1] for k in chunks[field]]).sum()
                    newshape = (totentries,) + shapes[field][next(iter(shapes[field]))][
                        1:
                    ]

                    # create virtual sources
                    vsources = []
                    for k in shapes[field]:
                        vsources.append(
                            h5py.VirtualSource(
                                files[k],
                                name=field,
                                shape=shapes[field][k],
                                dtype=dtypes[field],
                            )
                        )
                    layout = h5py.VirtualLayout(
                        shape=tuple(newshape), dtype=dtypes[field]
                    )

                    # fill virtual dataset
                    offset = 0
                    for vsource in vsources:
                        length = vsource.shape[0]
                        layout[offset : offset + length] = vsource
                        offset += length
                    assert (
                        newshape[0] == offset
                    )  # make sure we filled the array up fully.
                    hf.create_virtual_dataset(field, layout)
            else:  # copied dataset. For performance, we iterate differently: Loop over each file's fields
                for field in groupfields:
                    totentries = np.array([k[1] for k in chunks[field]]).sum()
                    extrashapes = shapes[field][next(iter(shapes[field]))][1:]
                    newshape = (totentries,) + extrashapes
                    hf.create_dataset(field, shape=newshape, dtype=dtypes[field])
                counters = {field: 0 for field in groupfields}
                for k, fl in enumerate(files):
                    with h5py.File(fl) as hf_load:
                        for field in groupfields:
                            n = shapes[field].get(k, [0, 0])[0]
                            if n == 0:
                                continue
                            offset = counters[field]
                            hf[field][offset : offset + n] = hf_load[field]
                            counters[field] = offset + n

        # save information regarding chunks
        grp = hf.create_group("_chunks")
        for k, v in groupchunks.items():
            grp.attrs[k] = v

        # write the attributes
        # find attributes that change across data sets
        attrs_key_lists = [
            list(v["attrs"].keys()) for v in result
        ]  # attribute paths for each file
        attrspaths_all = set().union(*attrs_key_lists)
        attrspaths_intersec = set(attrspaths_all).intersection(*attrs_key_lists)
        attrspath_diff = attrspaths_all.difference(attrspaths_intersec)
        if attrspaths_all != attrspaths_intersec:
            # if the difference only stems from missing datasets (and their associated attrs), that's fine
            if not attrspath_diff.issubset(datasets):
                raise NotImplementedError(
                    "Some attribute paths not present in each partial data file."
                )
        # check for common key+values across all files
        attrs_same = {}
        attrs_differ = {}

        nfiles = len(files)

        for apath in sorted(attrspaths_all):
            attrs_same[apath] = {}
            attrs_differ[apath] = {}
            attrsnames = set().union(
                *[
                    result[i]["attrs"][apath]
                    for i in range(nfiles)
                    if apath in result[i]["attrs"]
                ]
            )
            for k in attrsnames:
                # we ignore apaths and k existing in some files.
                attrvallist = [
                    result[i]["attrs"][apath][k]
                    for i in range(nfiles)
                    if apath in result[i]["attrs"] and k in result[i]["attrs"][apath]
                ]
                attrval0 = attrvallist[0]
                if isinstance(attrval0, np.ndarray):
                    if not (np.all([np.array_equal(attrval0, v) for v in attrvallist])):
                        log.debug("%s: %s has different values." % (apath, k))
                        attrs_differ[apath][k] = np.stack(attrvallist)
                        continue
                else:
                    same = len(set(attrvallist)) == 1
                    if isinstance(attrval0, np.floating):
                        # for floats we do not require binary equality
                        # (we had some incident...)
                        same = np.allclose(attrval0, attrvallist)
                    if not same:
                        log.debug("%s: %s has different values." % (apath, k))
                        attrs_differ[apath][k] = np.array(attrvallist)
                        continue
                attrs_same[apath][k] = attrval0
        for apath in attrspaths_all:
            for k, v in attrs_same.get(apath, {}).items():
                hf[apath].attrs[k] = v
            for k, v in attrs_differ.get(apath, {}).items():
                hf[apath].attrs[k] = v
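
A hedged usage sketch; the file names are hypothetical and the chunk files are expected to share the same group/dataset layout.

from scida.helpers_hdf5 import create_mergedhdf5file  # import path assumed

chunkfiles = ["snap.0.hdf5", "snap.1.hdf5"]  # hypothetical per-chunk files
create_mergedhdf5file("snap_merged.hdf5", chunkfiles, virtual=True)
# datasets are concatenated along their first axis; per-file lengths are stored under the "_chunks" group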

get_dtype(obj)

Get the data type of given h5py.Dataset or zarr.Array object

Parameters:

Name Type Description Default
obj

object to get dtype from

required

Returns:

Name Type Description
dtype dtype

dtype of the object

Source code in src/scida/helpers_hdf5.py
def get_dtype(obj):
    """
    Get the data type of given h5py.Dataset or zarr.Array object

    Parameters
    ----------
    obj: h5py.Dataset or zarr.Array
        object to get dtype from

    Returns
    -------
    dtype: numpy.dtype
        dtype of the object
    """
    if isinstance(obj, h5py.Dataset):
        try:
            dtype = obj.dtype
        except TypeError as e:
            msg = "data type '<u6' not understood"
            if msg == e.__str__():
                # MTNG defines 6 byte unsigned integers, which are not supported by h5py
                # could not figure out how to query the type in h5py other than via the reported error.
                # (any call to .dtype will try to resolve "<u6" to a numpy dtype, which fails)
                # we just handle this as 64 bit unsigned integer
                dtype = np.uint64
            else:
                raise e
        return dtype
    elif isinstance(obj, zarr.Array):
        return obj.dtype
    else:
        return None

walk_group(obj, tree, get_attrs=False, scalar_to_attr=True)

Walks through a h5py.Group or zarr.Group object and fills the tree dictionary with information about the datasets and groups.

Parameters:

Name Type Description Default
obj

object to walk through

required
tree

dictionary to fill recursively

required
get_attrs

whether to get attributes of each object

False
scalar_to_attr

whether to convert scalar datasets to attributes

True
Source code in src/scida/helpers_hdf5.py
def walk_group(obj, tree, get_attrs=False, scalar_to_attr=True):
    """
    Walks through a h5py.Group or zarr.Group object and fills the tree dictionary with
    information about the datasets and groups.

    Parameters
    ----------
    obj: h5py.Group or zarr.Group
        object to walk through
    tree: dict
        dictionary to fill recursively
    get_attrs: bool
        whether to get attributes of each object
    scalar_to_attr: bool
        whether to convert scalar datasets to attributes

    Returns
    -------

    """
    if len(tree) == 0:
        tree.update(**dict(attrs={}, groups=[], datasets=[]))
    if get_attrs and len(obj.attrs) > 0:
        tree["attrs"][obj.name] = dict(obj.attrs)
    if isinstance(obj, (h5py.Dataset, zarr.Array)):
        dtype = get_dtype(obj)
        tree["datasets"].append([obj.name, obj.shape, dtype])
        if scalar_to_attr and len(obj.shape) == 0:
            tree["attrs"][obj.name] = obj[()]
    elif isinstance(obj, (h5py.Group, zarr.Group)):
        tree["groups"].append(obj.name)  # continue the walk
        for o in obj.values():
            walk_group(o, tree, get_attrs=get_attrs)

walk_hdf5file(fn, tree, get_attrs=True)

Walks through a hdf5 file and fills the tree dictionary with information about the datasets and groups.

Parameters:

Name Type Description Default
fn

file path to hdf5 file to walk through

required
tree

dictionary to fill recursively

required
get_attrs

whether to get attributes of each object

True

Returns:

Name Type Description
tree dict

filled dictionary

Source code in src/scida/helpers_hdf5.py
def walk_hdf5file(fn, tree, get_attrs=True):
    """
    Walks through a hdf5 file and fills the tree dictionary with
    information about the datasets and groups.

    Parameters
    ----------
    fn: str
        file path to hdf5 file to walk through
    tree: dict
        dictionary to fill recursively
    get_attrs: bool
        whether to get attributes of each object

    Returns
    -------
    tree: dict
        filled dictionary
    """
    with h5py.File(fn, "r") as hf:
        walk_group(hf, tree, get_attrs=get_attrs)
    return tree
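
The tree dictionary is filled in place (and returned), as in this sketch with a hypothetical file name:

from scida.helpers_hdf5 import walk_hdf5file  # import path assumed

tree = {}
walk_hdf5file("snap_merged.hdf5", tree)
tree["groups"]    # list of group paths
tree["datasets"]  # list of [name, shape, dtype] entries
tree["attrs"]     # attributes per object path (scalar datasets are promoted to attributes)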

walk_zarrfile(fn, tree, get_attrs=True)

Walks through a zarr file and fills the tree dictionary with information about the datasets and groups.

Parameters:

Name Type Description Default
fn

file path to zarr file to walk through

required
tree

dictionary to fill recursively

required
get_attrs

whether to get attributes of each object

True

Returns:

Name Type Description
tree dict

filled dictionary

Source code in src/scida/helpers_hdf5.py
def walk_zarrfile(fn, tree, get_attrs=True):
    """
    Walks through a zarr file and fills the tree dictionary with
    information about the datasets and groups.

    Parameters
    ----------
    fn: str
        file path to zarr file to walk through
    tree: dict
        dictionary to fill recursively
    get_attrs: bool
        whether to get attributes of each object

    Returns
    -------
    tree: dict
        filled dictionary
    """
    with zarr.open(fn) as z:
        walk_group(z, tree, get_attrs=get_attrs)
    return tree

helpers_misc

RecursiveNamespace

Bases: SimpleNamespace

A SimpleNamespace that can be created recursively from a dict

Source code in src/scida/helpers_misc.py
class RecursiveNamespace(types.SimpleNamespace):
    """A SimpleNamespace that can be created recursively from a dict"""

    def __init__(self, **kwargs):
        """
        Create a SimpleNamespace recursively
        """
        super().__init__(**kwargs)
        self.__dict__.update({k: self.__elt(v) for k, v in kwargs.items()})

    def __elt(self, elt):
        """
        Recurse into elt to create leaf namespace objects.
        """
        if isinstance(elt, dict):
            return type(self)(**elt)
        if type(elt) in (list, tuple):
            return [self.__elt(i) for i in elt]
        return elt
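
A minimal sketch of the recursion: nested dicts become nested namespaces, and lists/tuples are converted element-wise. The keys are hypothetical.

from scida.helpers_misc import RecursiveNamespace  # import path assumed

ns = RecursiveNamespace(**{"header": {"boxsize": 100.0}, "files": [{"n": 1}, {"n": 2}]})
ns.header.boxsize  # 100.0
ns.files[0].n      # 1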

__elt(elt)

Recurse into elt to create leaf namespace objects.

Source code in src/scida/helpers_misc.py
def __elt(self, elt):
    """
    Recurse into elt to create leaf namespace objects.
    """
    if isinstance(elt, dict):
        return type(self)(**elt)
    if type(elt) in (list, tuple):
        return [self.__elt(i) for i in elt]
    return elt

__init__(**kwargs)

Create a SimpleNamespace recursively

Source code in src/scida/helpers_misc.py
def __init__(self, **kwargs):
    """
    Create a SimpleNamespace recursively
    """
    super().__init__(**kwargs)
    self.__dict__.update({k: self.__elt(v) for k, v in kwargs.items()})

computedecorator(func)

Decorator introducing a compute keyword to evaluate dask array returns.

Parameters:

Name Type Description Default
func

Function to decorate

required

Returns:

Type Description
callable

Decorated function

Source code in src/scida/helpers_misc.py
def computedecorator(func):
    """
    Decorator introducing a compute keyword to evaluate dask array returns.

    Parameters
    ----------
    func: callable
        Function to decorate

    Returns
    -------
    callable
        Decorated function
    """

    def wrapper(*args, compute=False, **kwargs):
        res = func(*args, **kwargs)
        if compute:
            return res.compute()
        else:
            return res

    return wrapper
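
A short sketch of the added keyword; the decorated function is a hypothetical example.

import dask.array as da
from scida.helpers_misc import computedecorator  # import path assumed

@computedecorator
def total(arr):
    return arr.sum()

lazy = total(da.ones(10))                 # still a lazy dask result
value = total(da.ones(10), compute=True)  # evaluated immediately (10.0)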

get_args(func)

Get the positional arguments of a function.

Parameters:

Name Type Description Default
func
required

Returns:

Type Description
list

Positional arguments of the function

Source code in src/scida/helpers_misc.py
def get_args(func):
    """
    Get the positional arguments of a function.

    Parameters
    ----------
    func: callable

    Returns
    -------
    list
        Positional arguments of the function
    """
    signature = inspect.signature(func)
    return [
        k
        for k, v in signature.parameters.items()
        if v.default is inspect.Parameter.empty
    ]

get_kwargs(func)

Get the keyword arguments of a function.

Parameters:

Name Type Description Default
func

Function to get keyword arguments from

required

Returns:

Type Description
dict

Keyword arguments of the function

Source code in src/scida/helpers_misc.py
def get_kwargs(func):
    """
    Get the keyword arguments of a function.

    Parameters
    ----------
    func: callable
        Function to get keyword arguments from

    Returns
    -------
    dict
        Keyword arguments of the function
    """
    signature = inspect.signature(func)
    return {
        k: v.default
        for k, v in signature.parameters.items()
        if v.default is not inspect.Parameter.empty
    }

hash_path(path)

Hash a path to a fixed length string.

Parameters:

Name Type Description Default
path
required

Returns:

Type Description
str
Source code in src/scida/helpers_misc.py
def hash_path(path):
    """
    Hash a path to a fixed length string.

    Parameters
    ----------
    path: str or pathlib.Path

    Returns
    -------
    str
    """
    sha = hashlib.sha256()
    sha.update(str(path).strip("/ ").encode())
    return sha.hexdigest()[:16]

make_serializable(v)

Make a value JSON serializable.

Parameters:

Name Type Description Default
v

Object to make JSON serializable

required

Returns:

Type Description
any

JSON serializable object

Source code in src/scida/helpers_misc.py
def make_serializable(v):
    """
    Make a value JSON serializable.

    Parameters
    ----------
    v: any
        Object to make JSON serializable

    Returns
    -------
    any
        JSON serializable object
    """

    # Attributes need to be JSON serializable. No numpy types allowed.
    if isinstance(v, np.ndarray):
        v = v.tolist()
    if isinstance(v, np.generic):
        v = v.item()
    if isinstance(v, bytes):
        v = v.decode("utf-8")
    return v

map_blocks(func, *args, name=None, token=None, dtype=None, chunks=None, drop_axis=None, new_axis=None, enforce_ndim=False, meta=None, output_units=None, **kwargs)

map_blocks with units

Source code in src/scida/helpers_misc.py
def map_blocks(
    func,
    *args,
    name=None,
    token=None,
    dtype=None,
    chunks=None,
    drop_axis=None,
    new_axis=None,
    enforce_ndim=False,
    meta=None,
    output_units=None,
    **kwargs,
):
    """map_blocks with units"""
    da_kwargs = dict(
        name=name,
        token=token,
        dtype=dtype,
        chunks=chunks,
        drop_axis=drop_axis,
        new_axis=new_axis,
        enforce_ndim=enforce_ndim,
        meta=meta,
    )
    res = da.map_blocks(
        func,
        *args,
        **da_kwargs,
        **kwargs,
    )
    if output_units is not None:
        if hasattr(res, "magnitude"):
            log.info("map_blocks output already has units, overwriting.")
            res = res.magnitude * output_units
        else:
            res = res * output_units

    return res
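
A hedged sketch; the unit handling assumes pint's dask support, so that multiplying the mapped result by a pint unit yields a Quantity wrapping the dask array.

import dask.array as da
import pint
from scida.helpers_misc import map_blocks  # import path assumed

ureg = pint.UnitRegistry()
res = map_blocks(lambda x: 2.0 * x, da.ones(10), output_units=ureg.km)
# res behaves like a dask-backed pint Quantity with units of kilometers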

parse_humansize(size)

Parse a human-readable size string to bytes.

Parameters:

Name Type Description Default
size

Human readable size string, e.g. 1.5GiB

required

Returns:

Type Description
int

Size in bytes

Source code in src/scida/helpers_misc.py
def parse_humansize(size):
    """
    Parse a human-readable size string to bytes.

    Parameters
    ----------
    size: str
        Human readable size string, e.g. 1.5GiB

    Returns
    -------
    int
        Size in bytes
    """
    size = size.upper()
    if not re.match(r" ", size):
        size = re.sub(r"([KMGT]?I*B)", r" \1", size)
    number, unit = [string.strip() for string in size.split()]
    return int(float(number) * units[unit])
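
For example (assuming the module-level units table maps "GIB" to 1024**3):

from scida.helpers_misc import parse_humansize  # import path assumed

parse_humansize("1.5GiB")  # -> 1610612736, i.e. 1.5 * 1024**3 bytes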

sprint(*args, end='\n', **kwargs)

Print to a string.

Parameters:

Name Type Description Default
args

Arguments to print

()
end

String to append at the end

'\n'
kwargs

Keyword arguments to pass to print

{}

Returns:

Type Description
str

String to print

Source code in src/scida/helpers_misc.py
def sprint(*args, end="\n", **kwargs):
    """
    Print to a string.

    Parameters
    ----------
    args: any
        Arguments to print
    end: str
        String to append at the end
    kwargs: any
        Keyword arguments to pass to print

    Returns
    -------
    str
       String to print
    """
    output = io.StringIO()
    print(*args, file=output, end=end, **kwargs)
    contents = output.getvalue()
    output.close()
    return contents
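
A one-line sketch:

from scida.helpers_misc import sprint  # import path assumed

line = sprint("fields:", 3)  # "fields: 3\n", returned instead of written to stdout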

interface

Base dataset class and its handling.

BaseDataset

Base class for all datasets.

Source code in src/scida/interface.py
class BaseDataset(metaclass=MixinMeta):
    """
    Base class for all datasets.
    """

    def __init__(
        self,
        path,
        chunksize="auto",
        virtualcache=True,
        overwrite_cache=False,
        fileprefix="",
        hints=None,
        **kwargs
    ):
        """
        Initialize a dataset object.

        Parameters
        ----------
        path: str
            Path to the dataset.
        chunksize: int or str
            Chunksize for dask arrays.
        virtualcache: bool
            Whether to use virtual caching.
        overwrite_cache: bool
            Whether to overwrite existing cache.
        fileprefix: str
            Prefix for files to scan for.
        hints: dict
            Hints for the dataset.
        kwargs: dict
            Additional keyword arguments.
        """
        super().__init__()
        self.hints = hints if hints is not None else {}
        self.path = path
        self.file = None
        # keep this 'tempfile' reference so the temporary file is not garbage collected
        self.tempfile = None
        self.location = str(path)
        self.chunksize = chunksize
        self.virtualcache = virtualcache
        self.overwrite_cache = overwrite_cache
        self.withunits = kwargs.get("units", False)

        # Let's find the data and metadata for the object at 'path'
        self.metadata = {}
        self._metadata_raw = {}
        self.data = FieldContainer(withunits=self.withunits)

        if not os.path.exists(self.path):
            raise Exception("Specified path '%s' does not exist." % self.path)

        loadkwargs = dict(
            overwrite_cache=self.overwrite_cache,
            fileprefix=fileprefix,
            virtualcache=virtualcache,
            derivedfields_kwargs=dict(snap=self),
            token=self.__dask_tokenize__(),
            withunits=self.withunits,
        )
        if "choose_prefix" in kwargs:
            loadkwargs["choose_prefix"] = kwargs["choose_prefix"]

        res = scida.io.load(path, **loadkwargs)
        self.data = res[0]
        self._metadata_raw = res[1]
        self.file = res[2]
        self.tempfile = res[3]
        self._cached = False

        # any identifying metadata?
        if "dsname" not in self.hints:
            candidates = check_config_for_dataset(self._metadata_raw, path=self.path)
            if len(candidates) > 0:
                dsname = candidates[0]
                log.debug("Dataset is identified as '%s'." % dsname)
                self.hints["dsname"] = dsname

    def _info_custom(self):
        """
        Custom information to be printed by info() method.

        Returns
        -------
        None
        """
        return None

    def info(self, listfields: bool = False):
        """
        Print information about the dataset.

        Parameters
        ----------
        listfields: bool
            If True, list all fields in the dataset.

        Returns
        -------
        None
        """
        rep = ""
        rep += "class: " + sprint(self.__class__.__name__)
        props = self._repr_dict()
        for k, v in props.items():
            rep += sprint("%s: %s" % (k, v))
        if self._info_custom() is not None:
            rep += self._info_custom()
        rep += sprint("=== data ===")
        rep += self.data.info(name="root")
        rep += sprint("============")
        print(rep)

    def _repr_dict(self) -> Dict[str, str]:
        """
        Return a dictionary of properties to be printed by __repr__ method.

        Returns
        -------
        dict
        """
        props = dict()
        props["source"] = self.path
        return props

    def __repr__(self) -> str:
        """
        Return a string representation of the object.

        Returns
        -------
        str
        """
        props = self._repr_dict()
        clsname = self.__class__.__name__
        result = clsname + "["
        for k, v in props.items():
            result += "%s=%s, " % (k, v)
        result = result[:-2] + "]"
        return result

    def _repr_pretty_(self, p, cycle):
        """
        Pretty print representation for IPython.
        Parameters
        ----------
        p
        cycle

        Returns
        -------
        None
        """
        rpr = self.__repr__()
        p.text(rpr)

    def __init_subclass__(cls, *args, **kwargs):
        """
        Register subclasses in the dataset type registry.
        Parameters
        ----------
        args: list
        kwargs: dict

        Returns
        -------
        None
        """
        super().__init_subclass__(*args, **kwargs)
        if cls.__name__ == "Delay":
            return  # nothing to register for Delay objects
        if "Mixin" in cls.__name__:
            return  # do not register classes with Mixins
        dataset_type_registry[cls.__name__] = cls

    @classmethod
    @abc.abstractmethod
    def validate_path(cls, path, *args, **kwargs):
        """
        Validate whether the given path is a valid path for this dataset.
        Parameters
        ----------
        path
        args
        kwargs

        Returns
        -------
        bool

        """
        return False

    def __hash__(self) -> int:
        """
        Hash for Dataset instance to be derived from the file location.

        Returns
        -------
        int
        """
        # deterministic hash; note that hash() on a string is no longer deterministic in python3.
        hash_value = (
            int(hashlib.sha256(self.location.encode("utf-8")).hexdigest(), 16)
            % 10**10
        )
        return hash_value

    def __getitem__(self, item):
        return self.data[item]

    def __dask_tokenize__(self) -> int:
        """
        Token for dask to be derived -- naively from the file location.

        Returns
        -------
        int
        """
        return self.__hash__()

    def return_data(self) -> FieldContainer:
        """
        Return the data container.

        Returns
        -------
        FieldContainer
        """
        return self.data

    def save(
        self,
        fname,
        fields: Union[
            str, Dict[str, Union[List[str], Dict[str, da.Array]]], FieldContainer
        ] = "all",
        overwrite: bool = True,
        zarr_kwargs: Optional[dict] = None,
        cast_uints: bool = False,
        extra_attrs: Optional[dict] = None,
    ) -> None:
        """
        Save the dataset to a file using the 'zarr' format.
        Parameters
        ----------
        extra_attrs: dict
            additional attributes to save in the root group
        fname: str
            Filename to save to.
        fields: str or dict
            dictionary of dask arrays to save. If equal to 'all', save all fields in current dataset.
        overwrite
            overwrite existing file
        zarr_kwargs
            optional arguments to pass to zarr
        cast_uints
            need to potentially cast uints to ints for some compressions; TODO: clean this up

        Returns
        -------
        None
        """
        # We use zarr, as this way we have support to directly write into the file by the workers
        # (rather than passing back the data chunk over the scheduler to the interface)
        # Also, this way we can leverage new features, such as a large variety of compression methods.
        # cast_uints: if true, we cast uints to ints; needed for some compressions (particularly zfp)
        if zarr_kwargs is None:
            zarr_kwargs = {}
        store = zarr.DirectoryStore(fname, **zarr_kwargs)
        root = zarr.group(store, overwrite=overwrite)

        # Metadata
        defaultattributes = ["Config", "Header", "Parameters"]
        for dctname in defaultattributes:
            if dctname.lower() in self.__dict__:
                grp = root.create_group(dctname)
                dct = self.__dict__[dctname.lower()]
                for k, v in dct.items():
                    v = make_serializable(v)
                    grp.attrs[k] = v
        if extra_attrs is not None:
            for k, v in extra_attrs.items():
                root.attrs[k] = v
        # Data
        tasks = []
        ptypes = self.data.keys()
        if isinstance(fields, dict):
            ptypes = fields.keys()
        elif isinstance(fields, str):
            if not fields == "all":
                raise ValueError("Invalid field specifier.")
        else:
            raise ValueError("Invalid type for fields.")
        for p in ptypes:
            root.create_group(p)
            if fields == "all":
                fieldkeys = self.data[p]
            else:
                if isinstance(fields[p], dict):
                    fieldkeys = fields[p].keys()
                else:
                    fieldkeys = fields[p]
            for k in fieldkeys:
                if not isinstance(fields, str) and isinstance(fields[p], dict):
                    arr = fields[p][k]
                else:
                    arr = self.data[p][k]
                if hasattr(arr, "magnitude"):  # if we have units, remove those here
                    # TODO: save units in metadata!
                    arr = arr.magnitude
                if np.any(np.isnan(arr.shape)):
                    arr.compute_chunk_sizes()  # very inefficient (have to do it separately for every array)
                    arr = arr.rechunk(chunks="auto")
                if cast_uints:
                    if arr.dtype == np.uint64:
                        arr = arr.astype(np.int64)
                    elif arr.dtype == np.uint32:
                        arr = arr.astype(np.int32)
                task = da.to_zarr(
                    arr, os.path.join(fname, p, k), overwrite=True, compute=False
                )
                tasks.append(task)
        dask.compute(tasks)

__dask_tokenize__()

Token for dask, derived naively from the file location.

Returns:

Type Description
int
Source code in src/scida/interface.py
251
252
253
254
255
256
257
258
259
def __dask_tokenize__(self) -> int:
    """
    Token for dask, derived naively from the file location.

    Returns
    -------
    int
    """
    return self.__hash__()

__hash__()

Hash for the Dataset instance, derived from the file location.

Returns:

Type Description
int
Source code in src/scida/interface.py
233
234
235
236
237
238
239
240
241
242
243
244
245
246
def __hash__(self) -> int:
    """
    Hash for the Dataset instance, derived from the file location.

    Returns
    -------
    int
    """
    # deterministic hash; note that hash() on a string is not deterministic across sessions in Python 3.
    hash_value = (
        int(hashlib.sha256(self.location.encode("utf-8")).hexdigest(), 16)
        % 10**10
    )
    return hash_value
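
Since both the hash and the dask token reduce to the file location, two handles opened on the same path agree on both. A minimal sketch, assuming a hypothetical local dataset at './snap_000' that the Dataset class (documented below) can open:

from dask.base import tokenize
from scida.interface import Dataset

ds_a = Dataset("./snap_000")  # hypothetical path to a readable dataset
ds_b = Dataset("./snap_000")  # a second handle on the same location

# both hash() and the dask token derive from ds.location,
# so they agree across instances and interpreter sessions
assert hash(ds_a) == hash(ds_b)
assert tokenize(ds_a) == tokenize(ds_b)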

__init__(path, chunksize='auto', virtualcache=True, overwrite_cache=False, fileprefix='', hints=None, **kwargs)

Initialize a dataset object.

Parameters:

Name Type Description Default
path

Path to the dataset.

required
chunksize

Chunksize for dask arrays.

'auto'
virtualcache

Whether to use virtual caching.

True
overwrite_cache

Whether to overwrite existing cache.

False
fileprefix

Prefix for files to scan for.

''
hints

Hints for the dataset.

None
kwargs

Additional keyword arguments.

{}
Source code in src/scida/interface.py
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def __init__(
    self,
    path,
    chunksize="auto",
    virtualcache=True,
    overwrite_cache=False,
    fileprefix="",
    hints=None,
    **kwargs
):
    """
    Initialize a dataset object.

    Parameters
    ----------
    path: str
        Path to the dataset.
    chunksize: int or str
        Chunksize for dask arrays.
    virtualcache: bool
        Whether to use virtual caching.
    overwrite_cache: bool
        Whether to overwrite existing cache.
    fileprefix: str
        Prefix for files to scan for.
    hints: dict
        Hints for the dataset.
    kwargs: dict
        Additional keyword arguments.
    """
    super().__init__()
    self.hints = hints if hints is not None else {}
    self.path = path
    self.file = None
    # need this 'tempfile' reference to keep garbage collection away for the tempfile
    self.tempfile = None
    self.location = str(path)
    self.chunksize = chunksize
    self.virtualcache = virtualcache
    self.overwrite_cache = overwrite_cache
    self.withunits = kwargs.get("units", False)

    # Let's find the data and metadata for the object at 'path'
    self.metadata = {}
    self._metadata_raw = {}
    self.data = FieldContainer(withunits=self.withunits)

    if not os.path.exists(self.path):
        raise Exception("Specified path '%s' does not exist." % self.path)

    loadkwargs = dict(
        overwrite_cache=self.overwrite_cache,
        fileprefix=fileprefix,
        virtualcache=virtualcache,
        derivedfields_kwargs=dict(snap=self),
        token=self.__dask_tokenize__(),
        withunits=self.withunits,
    )
    if "choose_prefix" in kwargs:
        loadkwargs["choose_prefix"] = kwargs["choose_prefix"]

    res = scida.io.load(path, **loadkwargs)
    self.data = res[0]
    self._metadata_raw = res[1]
    self.file = res[2]
    self.tempfile = res[3]
    self._cached = False

    # any identifying metadata?
    if "dsname" not in self.hints:
        candidates = check_config_for_dataset(self._metadata_raw, path=self.path)
        if len(candidates) > 0:
            dsname = candidates[0]
            log.debug("Dataset is identified as '%s'." % dsname)
            self.hints["dsname"] = dsname
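
A minimal sketch of constructing a dataset handle directly; the path, fileprefix and group name below are illustrative, and most workflows would go through scida's top-level load function instead:

from scida.interface import Dataset

# hypothetical snapshot directory; fileprefix restricts which files are scanned
ds = Dataset("./output/snapdir_099", fileprefix="snap", units=False)

print(ds)                 # __repr__: Dataset[...]
ds.info(listfields=True)  # overview of metadata and available fields
gas = ds["PartType0"]     # __getitem__ forwards to the FieldContainer in ds.data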

__init_subclass__(*args, **kwargs)

Register subclasses in the dataset type registry.

Parameters:

Name Type Description Default
args
()
kwargs
{}

Returns:

Type Description
None
Source code in src/scida/interface.py
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
def __init_subclass__(cls, *args, **kwargs):
    """
    Register subclasses in the dataset type registry.
    Parameters
    ----------
    args: list
    kwargs: dict

    Returns
    -------
    None
    """
    super().__init_subclass__(*args, **kwargs)
    if cls.__name__ == "Delay":
        return  # nothing to register for Delay objects
    if "Mixin" in cls.__name__:
        return  # do not register classes with Mixins
    dataset_type_registry[cls.__name__] = cls
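
Because registration happens at class-definition time, merely defining a subclass adds it to the registry. A sketch; the dataset_type_registry import location is an assumption based on the registries module documented further below:

from scida.interface import Dataset
from scida.registries import dataset_type_registry  # assumed import location

class MyFormatDataset(Dataset):
    """Toy subclass; defining it is enough to register it."""

    @classmethod
    def validate_path(cls, path, *args, **kwargs):
        # claim paths ending in a made-up suffix
        return str(path).endswith(".myformat")

assert "MyFormatDataset" in dataset_type_registry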

__repr__()

Return a string representation of the object.

Returns:

Type Description
str
Source code in src/scida/interface.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def __repr__(self) -> str:
    """
    Return a string representation of the object.

    Returns
    -------
    str
    """
    props = self._repr_dict()
    clsname = self.__class__.__name__
    result = clsname + "["
    for k, v in props.items():
        result += "%s=%s, " % (k, v)
    result = result[:-2] + "]"
    return result

info(listfields=False)

Print information about the dataset.

Parameters:

Name Type Description Default
listfields bool

If True, list all fields in the dataset.

False

Returns:

Type Description
None
Source code in src/scida/interface.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def info(self, listfields: bool = False):
    """
    Print information about the dataset.

    Parameters
    ----------
    listfields: bool
        If True, list all fields in the dataset.

    Returns
    -------
    None
    """
    rep = ""
    rep += "class: " + sprint(self.__class__.__name__)
    props = self._repr_dict()
    for k, v in props.items():
        rep += sprint("%s: %s" % (k, v))
    if self._info_custom() is not None:
        rep += self._info_custom()
    rep += sprint("=== data ===")
    rep += self.data.info(name="root")
    rep += sprint("============")
    print(rep)

return_data()

Return the data container.

Returns:

Type Description
FieldContainer
Source code in src/scida/interface.py
261
262
263
264
265
266
267
268
269
def return_data(self) -> FieldContainer:
    """
    Return the data container.

    Returns
    -------
    FieldContainer
    """
    return self.data

save(fname, fields='all', overwrite=True, zarr_kwargs=None, cast_uints=False, extra_attrs=None)

Save the dataset to a file using the 'zarr' format.

Parameters:

Name Type Description Default
extra_attrs Optional[dict]

additional attributes to save in the root group

None
fname

Filename to save to.

required
fields Union[str, Dict[str, Union[List[str], Dict[str, Array]]], FieldContainer]

dictionary of dask arrays to save. If equal to 'all', save all fields in current dataset.

'all'
overwrite bool

overwrite existing file

True
zarr_kwargs Optional[dict]

optional arguments to pass to zarr

None
cast_uints bool

need to potentially cast uints to ints for some compressions; TODO: clean this up

False

Returns:

Type Description
None
Source code in src/scida/interface.py
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
def save(
    self,
    fname,
    fields: Union[
        str, Dict[str, Union[List[str], Dict[str, da.Array]]], FieldContainer
    ] = "all",
    overwrite: bool = True,
    zarr_kwargs: Optional[dict] = None,
    cast_uints: bool = False,
    extra_attrs: Optional[dict] = None,
) -> None:
    """
    Save the dataset to a file using the 'zarr' format.
    Parameters
    ----------
    extra_attrs: dict
        additional attributes to save in the root group
    fname: str
        Filename to save to.
    fields: str or dict
        dictionary of dask arrays to save. If equal to 'all', save all fields in current dataset.
    overwrite
        overwrite existing file
    zarr_kwargs
        optional arguments to pass to zarr
    cast_uints
        need to potentially cast uints to ints for some compressions; TODO: clean this up

    Returns
    -------
    None
    """
    # We use zarr so that the workers can write directly into the file
    # (rather than passing the data chunks back to the interface via the scheduler).
    # This also lets us leverage newer features, such as a large variety of compression methods.
    # cast_uints: if True, cast uints to ints; needed for some compressions (particularly zfp).
    if zarr_kwargs is None:
        zarr_kwargs = {}
    store = zarr.DirectoryStore(fname, **zarr_kwargs)
    root = zarr.group(store, overwrite=overwrite)

    # Metadata
    defaultattributes = ["Config", "Header", "Parameters"]
    for dctname in defaultattributes:
        if dctname.lower() in self.__dict__:
            grp = root.create_group(dctname)
            dct = self.__dict__[dctname.lower()]
            for k, v in dct.items():
                v = make_serializable(v)
                grp.attrs[k] = v
    if extra_attrs is not None:
        for k, v in extra_attrs.items():
            root.attrs[k] = v
    # Data
    tasks = []
    ptypes = self.data.keys()
    if isinstance(fields, dict):
        ptypes = fields.keys()
    elif isinstance(fields, str):
        if not fields == "all":
            raise ValueError("Invalid field specifier.")
    else:
        raise ValueError("Invalid type for fields.")
    for p in ptypes:
        root.create_group(p)
        if fields == "all":
            fieldkeys = self.data[p]
        else:
            if isinstance(fields[p], dict):
                fieldkeys = fields[p].keys()
            else:
                fieldkeys = fields[p]
        for k in fieldkeys:
            if not isinstance(fields, str) and isinstance(fields[p], dict):
                arr = fields[p][k]
            else:
                arr = self.data[p][k]
            if hasattr(arr, "magnitude"):  # if we have units, remove those here
                # TODO: save units in metadata!
                arr = arr.magnitude
            if np.any(np.isnan(arr.shape)):
                arr.compute_chunk_sizes()  # very inefficient (have to do it separately for every array)
                arr = arr.rechunk(chunks="auto")
            if cast_uints:
                if arr.dtype == np.uint64:
                    arr = arr.astype(np.int64)
                elif arr.dtype == np.uint32:
                    arr = arr.astype(np.int32)
            task = da.to_zarr(
                arr, os.path.join(fname, p, k), overwrite=True, compute=False
            )
            tasks.append(task)
    dask.compute(tasks)
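
A short usage sketch: write the full dataset, or pass a nested mapping to fields to persist only selected groups and fields (the paths, group and field names below are illustrative):

from scida.interface import Dataset

ds = Dataset("./output/snapdir_099")  # hypothetical path, as above

# save everything
ds.save("snapshot.zarr")

# save only selected fields, casting uints for compressor compatibility
fields = {"PartType0": ["Coordinates", "Masses"]}
ds.save("gas_subset.zarr", fields=fields, cast_uints=True,
        extra_attrs={"note": "reduced copy for plotting"})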

validate_path(path, *args, **kwargs) abstractmethod classmethod

Validate whether the given path is a valid path for this dataset.

Parameters:

Name Type Description Default
path
required
args
()
kwargs
{}

Returns:

Type Description
bool
Source code in src/scida/interface.py
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
@classmethod
@abc.abstractmethod
def validate_path(cls, path, *args, **kwargs):
    """
    Validate whether the given path is a valid path for this dataset.
    Parameters
    ----------
    path
    args
    kwargs

    Returns
    -------
    bool

    """
    return False

Dataset

Bases: BaseDataset

Base class for datasets with some functions to be overwritten by subclass.

Source code in src/scida/interface.py
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
class Dataset(BaseDataset):
    """
    Base class for datasets with some functions to be overwritten by subclass.
    """

    @classmethod
    def validate_path(cls, path, *args, **kwargs):
        """
        Validate whether the given path is a valid path for this dataset.

        Parameters
        ----------
        path: str
            Path to the dataset.
        args: list
        kwargs: dict

        Returns
        -------
        bool
        """
        return True

    @classmethod
    def _clean_metadata_from_raw(cls, rawmetadata):
        """Clean metadata from raw metadata"""
        return {}

validate_path(path, *args, **kwargs) classmethod

Validate whether the given path is a valid path for this dataset.

Parameters:

Name Type Description Default
path

Path to the dataset.

required
args
()
kwargs
{}

Returns:

Type Description
bool
Source code in src/scida/interface.py
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
@classmethod
def validate_path(cls, path, *args, **kwargs):
    """
    Validate whether the given path is a valid path for this dataset.

    Parameters
    ----------
    path: str
        Path to the dataset.
    args: list
    kwargs: dict

    Returns
    -------
    bool
    """
    return True

DatasetWithUnitMixin

Bases: UnitMixin, Dataset

Dataset with units.

Source code in src/scida/interface.py
396
397
398
399
400
401
402
403
class DatasetWithUnitMixin(UnitMixin, Dataset):
    """
    Dataset with units.
    """

    def __init__(self, *args, **kwargs):
        """Initialize dataset with units."""
        super().__init__(*args, **kwargs)

__init__(*args, **kwargs)

Initialize dataset with units.

Source code in src/scida/interface.py
401
402
403
def __init__(self, *args, **kwargs):
    """Initialize dataset with units."""
    super().__init__(*args, **kwargs)

MixinMeta

Bases: type

Metaclass for Mixin classes.

Source code in src/scida/interface.py
26
27
28
29
30
31
32
33
34
class MixinMeta(type):
    """
    Metaclass for Mixin classes.
    """

    def __call__(cls, *args, **kwargs):
        mixins = kwargs.pop("mixins", None)
        newcls = create_datasetclass_with_mixins(cls, mixins)
        return type.__call__(newcls, *args, **kwargs)

Selector

Bases: object

Base Class for data selection decorator factory

Source code in src/scida/interface.py
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
class Selector(object):
    """Base Class for data selection decorator factory"""

    def __init__(self):
        """
        Initialize the selector.
        """
        self.keys = None  # the keys we check for.
        # holds a copy of the species' fields
        self.data_backup = FieldContainer()
        # holds the species' fields we operate on
        self.data: FieldContainer = FieldContainer()

    def __call__(self, fn, *args, **kwargs):
        """
        Call the selector.

        Parameters
        ----------
        fn: function
            Function to be decorated.
        args: list
        kwargs: dict

        Returns
        -------
        function
            Decorated function.

        """

        def newfn(*args, **kwargs):
            # TODO: Add graceful exit/restore after exception in self.prepare
            self.data_backup = args[0].data
            self.data = args[0].data.copy_skeleton()
            # deepdictkeycopy(self.data_backup, self.data)

            self.prepare(*args, **kwargs)
            if self.keys is None:
                raise NotImplementedError(
                    "Subclass implementation needed for self.keys!"
                )
            if kwargs.pop("dropkeys", True):
                for k in self.keys:
                    kwargs.pop(k, None)
            try:
                result = fn(*args, **kwargs)
                return result
            finally:
                self.finalize(*args, **kwargs)

        return newfn

    def prepare(self, *args, **kwargs) -> None:
        """
        Prepare the data for selection. To be implemented in subclasses.

        Parameters
        ----------
        args
        kwargs

        Returns
        -------

        """
        raise NotImplementedError("Subclass implementation needed!")

    def finalize(self, *args, **kwargs) -> None:
        """
        Finalize the data after selection. To be implemented in subclasses.

        Parameters
        ----------
        args
        kwargs

        Returns
        -------

        """
        args[0].data = self.data_backup

__call__(fn, *args, **kwargs)

Call the selector.

Parameters:

Name Type Description Default
fn

Function to be decorated.

required
args
()
kwargs
{}

Returns:

Type Description
function

Decorated function.

Source code in src/scida/interface.py
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
def __call__(self, fn, *args, **kwargs):
    """
    Call the selector.

    Parameters
    ----------
    fn: function
        Function to be decorated.
    args: list
    kwargs: dict

    Returns
    -------
    function
        Decorated function.

    """

    def newfn(*args, **kwargs):
        # TODO: Add graceful exit/restore after exception in self.prepare
        self.data_backup = args[0].data
        self.data = args[0].data.copy_skeleton()
        # deepdictkeycopy(self.data_backup, self.data)

        self.prepare(*args, **kwargs)
        if self.keys is None:
            raise NotImplementedError(
                "Subclass implementation needed for self.keys!"
            )
        if kwargs.pop("dropkeys", True):
            for k in self.keys:
                kwargs.pop(k, None)
        try:
            result = fn(*args, **kwargs)
            return result
        finally:
            self.finalize(*args, **kwargs)

    return newfn

__init__()

Initialize the selector.

Source code in src/scida/interface.py
409
410
411
412
413
414
415
416
417
def __init__(self):
    """
    Initialize the selector.
    """
    self.keys = None  # the keys we check for.
    # holds a copy of the species' fields
    self.data_backup = FieldContainer()
    # holds the species' fields we operate on
    self.data: FieldContainer = FieldContainer()

finalize(*args, **kwargs)

Finalize the data after selection. To be implemented in subclasses.

Parameters:

Name Type Description Default
args
()
kwargs
{}
Source code in src/scida/interface.py
474
475
476
477
478
479
480
481
482
483
484
485
486
487
def finalize(self, *args, **kwargs) -> None:
    """
    Finalize the data after selection. To be implemented in subclasses.

    Parameters
    ----------
    args
    kwargs

    Returns
    -------

    """
    args[0].data = self.data_backup

prepare(*args, **kwargs)

Prepare the data for selection. To be implemented in subclasses.

Parameters:

Name Type Description Default
args
()
kwargs
{}
Source code in src/scida/interface.py
459
460
461
462
463
464
465
466
467
468
469
470
471
472
def prepare(self, *args, **kwargs) -> None:
    """
    Prepare the data for selection. To be implemented in subclasses.

    Parameters
    ----------
    args
    kwargs

    Returns
    -------

    """
    raise NotImplementedError("Subclass implementation needed!")
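
A minimal sketch of a concrete selector: the subclass sets self.keys to the keyword(s) it consumes and implements prepare(), which fills the skeleton container and swaps it in; the inherited finalize() restores the original fields afterwards. The 'haloID' keyword and the masking detail are illustrative only:

from scida.interface import Selector

class HaloSelector(Selector):
    """Illustrative selector consuming a hypothetical 'haloID' keyword."""

    def __init__(self):
        super().__init__()
        self.keys = ["haloID"]  # keywords handled (and dropped) by this selector

    def prepare(self, *args, **kwargs):
        obj = args[0]                  # the dataset-like object being operated on
        hid = kwargs.get("haloID", None)
        if hid is None:
            return
        # fill self.data (a skeleton copy made by __call__) with masked fields here,
        # then make the decorated function see the selection:
        obj.data = self.data

# usage: decorate a function whose first argument carries a .data FieldContainer
@HaloSelector()
def compute_something(snap, haloID=None):
    return snap.data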

create_datasetclass_with_mixins(cls, mixins)

Create a new class from a given class and a list of mixins.

Parameters:

Name Type Description Default
cls

dataset class to be extended

required
mixins Optional[List]

list of mixin classes to be added

required

Returns:

Type Description
Type[BaseDataset]

new class with mixins

Source code in src/scida/interface.py
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
def create_datasetclass_with_mixins(cls, mixins: Optional[List]):
    """
    Create a new class from a given class and a list of mixins.

    Parameters
    ----------
    cls:
        dataset class to be extended
    mixins:
        list of mixin classes to be added

    Returns
    -------
    Type[BaseDataset]
        new class with mixins
    """
    newcls = cls
    if isinstance(mixins, list) and len(mixins) > 0:
        name = cls.__name__ + "With" + "And".join([m.__name__ for m in mixins])
        # adjust entry point if __init__ available in some mixin
        nms = dict(cls.__dict__)
        # need to make sure first mixin init is called over cls init
        nms["__init__"] = mixins[0].__init__
        newcls = type(name, (*mixins, cls), nms)
    return newcls
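
A short sketch with a toy mixin (the mixin itself is made up for illustration); note how the generated class name encodes the combination:

from scida.interface import Dataset, create_datasetclass_with_mixins

class VerboseMixin:
    """Toy mixin: announce initialization, then defer to the dataset class."""

    def __init__(self, *args, **kwargs):
        print("initializing with VerboseMixin")
        super().__init__(*args, **kwargs)

cls = create_datasetclass_with_mixins(Dataset, [VerboseMixin])
print(cls.__name__)  # "DatasetWithVerboseMixin"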

io

scida.io

fits

FITS file reader for scida

fitsrecords_to_daskarrays(fitsrecords)

Convert a FITS record array to a dictionary of dask arrays.

Parameters:

Name Type Description Default
fitsrecords

FITS record array

required

Returns:

Type Description
dict

dictionary of dask arrays

Source code in src/scida/io/fits.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def fitsrecords_to_daskarrays(fitsrecords):
    """
    Convert a FITS record array to a dictionary of dask arrays.
    Parameters
    ----------
    fitsrecords: np.ndarray
        FITS record array

    Returns
    -------
    dict
        dictionary of dask arrays

    """
    load_arr = delayed(lambda slc, field: fitsrecords[slc][field])
    shape = fitsrecords.shape
    darrdict = {}
    csize = dask.config.get("array.chunk-size")

    csize = parse_size(csize)  # need int

    nbytes_dtype_max = 1
    for fieldname in fitsrecords.dtype.names:
        nbytes_dtype = fitsrecords.dtype[fieldname].itemsize
        nbytes_dtype_max = max(nbytes_dtype_max, nbytes_dtype)
    chunksize = csize // nbytes_dtype_max

    for fieldname in fitsrecords.dtype.names:
        chunks = []
        for index in range(0, shape[-1], chunksize):
            dtype = fitsrecords.dtype[fieldname]
            chunk_size = min(chunksize, shape[-1] - index)
            slc = slice(index, index + chunk_size)
            shp = (chunk_size,)
            if dtype.subdtype is not None:
                # for now, we expect this to be void type
                assert dtype.type is np.void
                break  # do not handle void type for now => skip field
                # shp = shp + dtype.subdtype[0].shape
                # dtype = dtype.subdtype[0].base
            chunk = da.from_delayed(load_arr(slc, fieldname), shape=shp, dtype=dtype)
            chunks.append(chunk)
        if len(chunks) > 0:
            darrdict[fieldname] = da.concatenate(chunks, axis=0)
    return darrdict
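
A usage sketch; reading the FITS table itself is done with astropy here, which is an assumption of this example rather than something prescribed by the function (the file and column names are made up):

from astropy.io import fits
from scida.io.fits import fitsrecords_to_daskarrays

with fits.open("catalog.fits") as hdul:
    records = hdul[1].data                # FITS_rec, a numpy record array
    darrs = fitsrecords_to_daskarrays(records)
    print(list(darrs)[:5])                # column names now map to dask arrays
    mean_z = darrs["Z"].mean().compute()  # evaluate the lazily built graph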

misc

Miscellaneous helper functions.

check_config_for_dataset(metadata, path=None, unique=True)

Check whether the given dataset can be identified as a certain simulation (type) by its metadata.

Parameters:

Name Type Description Default
metadata

metadata of the dataset used for identification

required
path Optional[str]

path to the dataset, sometimes helpful for identification

None
unique bool

whether to expect return to be unique

True

Returns:

Type Description
list

candidates

Source code in src/scida/misc.py
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
def check_config_for_dataset(metadata, path: Optional[str] = None, unique: bool = True):
    """
    Check whether the given dataset can be identified as a certain simulation (type) by its metadata.

    Parameters
    ----------
    metadata: dict
        metadata of the dataset used for identification
    path: str
        path to the dataset, sometimes helpful for identification
    unique: bool
        whether to expect return to be unique

    Returns
    -------
    list
        candidates

    """
    c = get_simulationconfig()

    candidates = []
    if "data" not in c:
        return candidates
    simdct = c["data"]
    if simdct is None:
        simdct = {}
    for k, vals in simdct.items():
        if vals is None:
            continue
        possible_candidate = True
        if "identifiers" in vals:
            idtfrs = vals["identifiers"]
            # special key not specifying identifying metadata
            specialkeys = ["name_contains"]
            allkeys = idtfrs.keys()
            keys = list([k for k in allkeys if k not in specialkeys])
            if "name_contains" in idtfrs and path is not None:
                p = pathlib.Path(path)
                # we only check the last three path elements
                dirnames = [p.name, p.parents[0].name, p.parents[1].name]
                substring = idtfrs["name_contains"]
                if not any([substring.lower() in d.lower() for d in dirnames]):
                    possible_candidate = False
            if len(allkeys) == 0:
                possible_candidate = False
            for grp in keys:
                v = idtfrs[grp]
                h5path = "/" + grp
                if h5path not in metadata:
                    possible_candidate = False
                    break
                attrs = metadata[h5path]
                for ikey, ival in v.items():
                    if ikey not in attrs:
                        possible_candidate = False
                        break
                    av = attrs[ikey]
                    matchtype = None
                    if isinstance(ival, dict):
                        matchtype = ival.get("match", matchtype)  # default means equal
                        ival = ival["content"]

                    if isinstance(av, bytes):
                        av = av.decode("UTF-8")
                    if matchtype is None:
                        if av != ival:
                            possible_candidate = False
                            break
                    elif matchtype == "substring":
                        if ival not in av:
                            possible_candidate = False
                            break
        else:
            possible_candidate = False
        if possible_candidate:
            candidates.append(k)
    if unique and len(candidates) > 1:
        raise ValueError("Multiple dataset candidates (set unique=False?):", candidates)
    return candidates
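
A usage sketch with made-up metadata; the attribute paths mirror HDF5 group paths, and matching is driven by the 'identifiers' entries of the user's simulation configuration:

from scida.misc import check_config_for_dataset

# metadata maps attribute paths (e.g. "/Header") onto dicts of attributes,
# as produced by the scida loader for HDF5-like data
metadata = {"/Header": {"BoxSize": 35000.0, "NumFilesPerSnapshot": 1}}
candidates = check_config_for_dataset(
    metadata, path="./output/snapdir_099", unique=False
)
print(candidates)  # names of matching config entries, possibly an empty list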

deepdictkeycopy(olddict, newdict)

Recursively walk nested dictionary, only creating empty dictionaries for entries that are dictionaries themselves.

Parameters:

Name Type Description Default
olddict object
required
newdict object
required

Returns:

Type Description
None
Source code in src/scida/misc.py
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
def deepdictkeycopy(olddict: object, newdict: object) -> None:
    """
    Recursively walk nested dictionary, only creating empty dictionaries for entries that are dictionaries themselves.
    Parameters
    ----------
    olddict
    newdict

    Returns
    -------
    None
    """
    cls = olddict.__class__
    for k, v in olddict.items():
        if isinstance(v, MutableMapping):
            newdict[k] = cls()
            deepdictkeycopy(v, newdict[k])

get_container_from_path(element, container=None, create_missing=False)

Get a container from a path.

Parameters:

Name Type Description Default
element str
required
container FieldContainer
None
create_missing bool
False

Returns:

Name Type Description
FieldContainer FieldContainer

container specified by path

Source code in src/scida/misc.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def get_container_from_path(
    element: str, container: FieldContainer = None, create_missing: bool = False
) -> FieldContainer:
    """
    Get a container from a path.
    Parameters
    ----------
    element: str
    container: FieldContainer
    create_missing: bool

    Returns
    -------
    FieldContainer:
        container specified by path

    """
    keys = element.split("/")
    rv = container
    for key in keys:
        if key == "":
            continue
        if key not in rv._containers:
            if not create_missing:
                raise ValueError("Container '%s' not found in '%s'" % (key, rv))
            rv.add_container(key, name=key)
        rv = rv._containers[key]
    return rv
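
A short sketch; nested containers are addressed with '/'-separated paths and create_missing builds intermediate levels on demand (the FieldContainer import location is an assumption):

from scida.fields import FieldContainer  # assumed import location
from scida.misc import get_container_from_path

root = FieldContainer()
sub = get_container_from_path("PartType0/Subgroup", root, create_missing=True)
assert "PartType0" in root._containers
# without create_missing, a missing level raises a ValueError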

map_interface_args(paths, *args, **kwargs)

Map arguments for interface if they are not lists.

Parameters:

Name Type Description Default
paths list
required
args
()
kwargs
{}

Returns:

Type Description
generator

yields path, args, kwargs

Source code in src/scida/misc.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
def map_interface_args(paths: list, *args, **kwargs):
    """
    Map arguments for interface if they are not lists.
    Parameters
    ----------
    paths
    args
    kwargs

    Returns
    -------
    generator
        yields path, args, kwargs

    """
    n = len(paths)
    for i, path in enumerate(paths):
        targs = []
        for arg in args:
            if not (isinstance(arg, list)) or len(arg) != n:
                targs.append(arg)
            else:
                targs.append(arg[i])
        tkwargs = {}
        for k, v in kwargs.items():
            if not (isinstance(v, list)) or len(v) != n:
                tkwargs[k] = v
            else:
                tkwargs[k] = v[i]
        yield path, targs, tkwargs
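
A short sketch: scalar arguments are repeated for every path, while list arguments whose length matches the number of paths are distributed element-wise:

from scida.misc import map_interface_args

paths = ["snap_000", "snap_001"]
for path, args, kwargs in map_interface_args(paths, "auto", units=[False, True]):
    print(path, args, kwargs)
# snap_000 ['auto'] {'units': False}
# snap_001 ['auto'] {'units': True}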

parse_size(size)

Parse a size string to a number in bytes.

Parameters:

Name Type Description Default
size
required

Returns:

Type Description
int

size in bytes

Source code in src/scida/misc.py
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
def parse_size(size):
    """
    Parse a size string to a number in bytes.
    Parameters
    ----------
    size: str

    Returns
    -------
    int
        size in bytes
    """
    idx = 0
    for c in size:
        if c.isnumeric():
            continue
        idx += 1
    number = size[:idx]
    unit = size[idx:]
    return int(float(number) * _sizeunits[unit.lower().strip()])
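
A short sketch (the exact unit factors come from the module-level _sizeunits table, which is not reproduced here):

from scida.misc import parse_size

nbytes = parse_size("128 MiB")
print(nbytes)  # integer number of bytes corresponding to 128 MiB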

path_hdf5cachefile_exists(path, **kwargs)

Checks whether a cache file exists for given path.

Parameters:

Name Type Description Default
path

path to the dataset

required
kwargs

passed to return_hdf5cachepath

{}

Returns:

Type Description
bool
Source code in src/scida/misc.py
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def path_hdf5cachefile_exists(path, **kwargs) -> bool:
    """
    Checks whether a cache file exists for given path.
    Parameters
    ----------
    path:
        path to the dataset
    kwargs:
        passed to return_hdf5cachepath
    Returns
    -------
    bool

    """
    fp = return_hdf5cachepath(path, **kwargs)
    if os.path.isfile(fp):
        return True
    return False

rectangular_cutout_mask(center, width, coords, pbc=True, boxsize=None, backend='dask', chunksize='auto')

Create a rectangular mask for a given set of coordinates.

Parameters:

Name Type Description Default
center

center of the rectangle

required
width

widths of the rectangle

required
coords

coordinates to mask

required
pbc

whether to apply PBC

True
boxsize

boxsize for PBC

None
backend

backend to use (dask or numpy)

'dask'
chunksize

chunksize for dask

'auto'

Returns:

Name Type Description
ndarray

mask

Source code in src/scida/misc.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
def rectangular_cutout_mask(
    center, width, coords, pbc=True, boxsize=None, backend="dask", chunksize="auto"
):
    """
    Create a rectangular mask for a given set of coordinates.
    Parameters
    ----------
    center: list
        center of the rectangle
    width: list
        widths of the rectangle
    coords: ndarray
        coordinates to mask
    pbc: bool
        whether to apply PBC
    boxsize:
        boxsize for PBC
    backend: str
        backend to use (dask or numpy)
    chunksize: str
        chunksize for dask

    Returns
    -------
    ndarray:
        mask

    """
    center = np.array(center)
    width = np.array(width)
    if backend == "dask":
        be = da
    else:
        be = np

    dists = coords - center
    dists = be.fabs(dists)
    if pbc:
        if boxsize is None:
            raise ValueError("Need to specify for boxsize for PBC.")
        dists = be.where(dists > 0.5 * boxsize, be.fabs(boxsize - dists), dists)

    kwargs = {}
    if backend == "dask":
        kwargs["chunks"] = chunksize
    mask = be.ones(coords.shape[0], dtype=bool, **kwargs)  # note: the np.bool alias was removed in NumPy >= 1.24
    for i in range(3):
        mask &= dists[:, i] < (
            width[i] / 2.0
        )  # TODO: This interval is not closed on the left side.
    return mask
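
A usage sketch with random coordinates and the numpy backend (all values are illustrative):

import numpy as np
from scida.misc import rectangular_cutout_mask

coords = np.random.uniform(0.0, 100.0, size=(1000, 3))
mask = rectangular_cutout_mask(
    center=[50.0, 50.0, 50.0],
    width=[10.0, 10.0, 10.0],
    coords=coords,
    pbc=True,
    boxsize=100.0,
    backend="numpy",
)
print(mask.sum(), "of", len(coords), "points fall inside the cutout")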

return_cachefile_path(fname)

Return the path to the cache file, return None if path cannot be generated.

Parameters:

Name Type Description Default
fname str

filename of cache file

required

Returns:

Type Description
str or None
Source code in src/scida/misc.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def return_cachefile_path(fname: str) -> Optional[str]:
    """
    Return the path to the cache file, return None if path cannot be generated.

    Parameters
    ----------
    fname: str
        filename of cache file

    Returns
    -------
    str or None

    """
    config = get_config()
    if "cache_path" not in config:
        return None
    cp = config["cache_path"]
    cp = os.path.expanduser(cp)
    path = pathlib.Path(cp)
    path.mkdir(parents=True, exist_ok=True)
    fp = os.path.join(cp, fname)
    fp = os.path.expanduser(fp)
    bp = os.path.dirname(fp)
    if not os.path.exists(bp):
        try:
            os.mkdir(bp)
        except FileExistsError:
            pass  # can happen due to parallel access
    return fp

return_hdf5cachepath(path, fileprefix=None)

Returns the path to the cache file for a given path.

Parameters:

Name Type Description Default
path

path to the dataset

required
fileprefix Optional[str]

Can be used to specify the fileprefix used for the dataset.

None

Returns:

Type Description
str
Source code in src/scida/misc.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def return_hdf5cachepath(path, fileprefix: Optional[str] = None) -> str:
    """
    Returns the path to the cache file for a given path.

    Parameters
    ----------
    path: str
        path to the dataset
    fileprefix: Optional[str]
        Can be used to specify the fileprefix used for the dataset.

    Returns
    -------
    str

    """
    if fileprefix is not None:
        path = os.path.join(path, fileprefix)
    hsh = hash_path(path)
    fp = return_cachefile_path(os.path.join(hsh, "data.hdf5"))
    return fp

str_is_float(element)

Check whether a string can be converted to a float.

Parameters:

Name Type Description Default
element str

string to check

required

Returns:

Type Description
bool
Source code in src/scida/misc.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def str_is_float(element: str) -> bool:
    """
    Check whether a string can be converted to a float.
    Parameters
    ----------
    element: str
        string to check

    Returns
    -------
    bool

    """
    try:
        float(element)
        return True
    except ValueError:
        return False

registries

This module contains registries for dataset and dataseries subclasses. Subclasses are automatically registered via __init_subclass__.

series

This module contains the base class DatasetSeries, a container for collections of dataset instances.

DatasetSeries

Bases: object

A container for collections of dataset instances

Attributes:

Name Type Description
datasets list

list of dataset instances

paths list

list of paths to data

names list

list of names for datasets

hash str

hash of the object, constructed from dataset paths.

Source code in src/scida/series.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
class DatasetSeries(object):
    """A container for collections of dataset instances

    Attributes
    ----------
    datasets: list
        list of dataset instances
    paths: list
        list of paths to data
    names: list
        list of names for datasets
    hash: str
        hash of the object, constructed from dataset paths.
    """

    def __init__(
        self,
        paths: Union[List[str], List[Path]],
        *interface_args,
        datasetclass=None,
        overwrite_cache=False,
        lazy=True,  # lazy will only initialize data sets on demand.
        names=None,
        **interface_kwargs
    ):
        """

        Parameters
        ----------
        paths: list
            list of paths to data
        interface_args:
            arguments to pass to interface class
        datasetclass:
            class to use for dataset instances
        overwrite_cache:
            whether to overwrite existing cache
        lazy:
            whether to initialize datasets lazily
        names:
            names for datasets
        interface_kwargs:
            keyword arguments to pass to interface class
        """
        self.paths = paths
        self.names = names
        self.hash = hash_path("".join([str(p) for p in paths]))
        self._metadata = None
        self._metadatafile = return_cachefile_path(os.path.join(self.hash, "data.json"))
        self.lazy = lazy
        if overwrite_cache and os.path.exists(self._metadatafile):
            os.remove(self._metadatafile)
        for p in paths:
            if not (isinstance(p, Path)):
                p = Path(p)
            if not (p.exists()):
                raise ValueError("Specified path '%s' does not exist." % p)
        dec = delay_init  # lazy loading

        # Catch Mixins and create type:
        ikw = dict(overwrite_cache=overwrite_cache)
        ikw.update(**interface_kwargs)
        mixins = ikw.pop("mixins", [])
        datasetclass = create_datasetclass_with_mixins(datasetclass, mixins)
        self._dataset_cls = datasetclass

        gen = map_interface_args(paths, *interface_args, **ikw)
        self.datasets = [dec(datasetclass)(p, *a, **kw) for p, a, kw in gen]

        if self.metadata is None:
            print("Have not cached this data series. Can take a while.")
            dct = {}
            for i, (path, d) in enumerate(
                tqdm(zip(self.paths, self.datasets), total=len(self.paths))
            ):
                rawmeta = load_metadata(
                    path, choose_prefix=True, use_cachefile=not (overwrite_cache)
                )
                # class method does not initiate obj.
                dct[i] = d._clean_metadata_from_raw(rawmeta)
            self.metadata = dct

    def __init_subclass__(cls, *args, **kwargs):
        """
        Register datasetseries subclass in registry.
        Parameters
        ----------
        args:
            (unused)
        kwargs:
            (unused)
        Returns
        -------
        None
        """
        super().__init_subclass__(*args, **kwargs)
        dataseries_type_registry[cls.__name__] = cls

    def __len__(self):
        """Return number of datasets in series.

        Returns
        -------
        int
        """
        return len(self.datasets)

    def __getitem__(self, key):
        """
        Return dataset by index.
        Parameters
        ----------
        key

        Returns
        -------
        Dataset

        """
        return self.datasets[key]

    def info(self):
        """
        Print information about this datasetseries.

        Returns
        -------
        None
        """
        rep = ""
        rep += "class: " + sprint(self.__class__.__name__)
        props = self._repr_dict()
        for k, v in props.items():
            rep += sprint("%s: %s" % (k, v))
        if self.metadata is not None:
            rep += sprint("=== metadata ===")
            # we print the range of each metadata attribute
            minmax_dct = {}
            for mdct in self.metadata.values():
                for k, v in mdct.items():
                    if k not in minmax_dct:
                        minmax_dct[k] = [v, v]
                    else:
                        if not np.isscalar(v):
                            continue  # cannot compare arrays
                        minmax_dct[k][0] = min(minmax_dct[k][0], v)
                        minmax_dct[k][1] = max(minmax_dct[k][1], v)
            for k in minmax_dct:
                reprval1, reprval2 = minmax_dct[k][0], minmax_dct[k][1]
                if isinstance(reprval1, float):
                    reprval1 = "%.2f" % reprval1
                    reprval2 = "%.2f" % reprval2
                m1 = minmax_dct[k][0]
                m2 = minmax_dct[k][1]
                if (not np.isscalar(m1)) or (np.isscalar(m1) and m1 == m2):
                    rep += sprint("%s: %s" % (k, minmax_dct[k][0]))
                else:
                    rep += sprint(
                        "%s: %s -- %s" % (k, minmax_dct[k][0], minmax_dct[k][1])
                    )
            rep += sprint("============")
        print(rep)

    @property
    def data(self) -> None:
        """
        Dummy property to make user aware this is not a Dataset instance.
        Returns
        -------
        None

        """
        raise AttributeError(
            "Series do not have 'data' attribute. Load a dataset from series.get_dataset()."
        )

    def _repr_dict(self) -> Dict[str, str]:
        """
        Return a dictionary of properties to be printed by __repr__ method.

        Returns
        -------
        dict
        """
        props = dict()
        sources = [str(p) for p in self.paths]
        props["source(id=0)"] = sources[0]
        props["Ndatasets"] = len(self.datasets)
        return props

    def __repr__(self) -> str:
        """
        Return a string representation of the datasetseries object.

        Returns
        -------
        str
        """
        props = self._repr_dict()
        clsname = self.__class__.__name__
        result = clsname + "["
        for k, v in props.items():
            result += "%s=%s, " % (k, v)
        result = result[:-2] + "]"
        return result

    @classmethod
    def validate_path(cls, path, *args, **kwargs) -> CandidateStatus:
        """
        Check whether a given path is a valid path for this dataseries class.
        Parameters
        ----------
        path: str
            path to check
        args:
            (unused)
        kwargs:
            (unused)

        Returns
        -------
        CandidateStatus
        """
        return CandidateStatus.NO  # base class dummy

    @classmethod
    def from_directory(
        cls, path, *interface_args, datasetclass=None, pattern=None, **interface_kwargs
    ) -> "DatasetSeries":
        """
        Create a datasetseries instance from a directory.
        Parameters
        ----------
        path: str
            path to directory
        interface_args:
            arguments to pass to interface class
        datasetclass: Optional[Dataset]
            force class to use for dataset instances
        pattern:
            pattern to match files in directory
        interface_kwargs:
            keyword arguments to pass to interface class
        Returns
        -------
        DatasetSeries

        """
        p = Path(path)
        if not (p.exists()):
            raise ValueError("Specified path does not exist.")
        if pattern is None:
            pattern = "*"
        paths = [f for f in p.glob(pattern)]
        return cls(
            paths, *interface_args, datasetclass=datasetclass, **interface_kwargs
        )

    def get_dataset(
        self,
        index: Optional[int] = None,
        name: Optional[str] = None,
        reltol=1e-2,
        **kwargs
    ):
        """
        Get dataset by some metadata property. In the base class, we go by list index.

        Parameters
        ----------
        index: int
            index of dataset to get
        name: str
            name of dataset to get
        reltol:
            relative tolerance for metadata comparison
        kwargs:
            metadata properties to compare for selection

        Returns
        -------
        Dataset

        """
        if index is None and name is None and len(kwargs) == 0:
            raise ValueError("Specify index/name or some parameter to select for.")
        # aliases for index:
        aliases = ["snap", "snapshot"]
        aliases_given = [k for k in aliases if k in kwargs]
        if index is not None:
            aliases_given += [index]
        if len(aliases_given) > 1:
            raise ValueError("Multiple aliases for index specified.")
        for a in aliases_given:
            if kwargs.get(a) is not None:
                index = kwargs.pop(a)
        if index is not None:
            return self.datasets[index]

        if name is not None:
            if self.names is None:
                raise ValueError("No names specified for members of this series.")
            if name not in self.names:
                raise ValueError("Name %s not found in this series." % name)
            return self.datasets[self.names.index(name)]
        if len(kwargs) > 0 and self.metadata is None:
            if self.lazy:
                raise ValueError(
                    "Cannot select by given keys before dataset evaluation."
                )
            raise ValueError("Unknown error.")  # should not happen?

        # find candidates from metadata
        candidates = []
        candidates_props = {}
        props_compare = set()  # save names of fields we want to compare
        for k, v in kwargs.items():
            candidates_props[k] = []
        for i, (j, dm) in enumerate(self.metadata.items()):
            assert int(i) == int(j)
            is_candidate = True
            for k, v in kwargs.items():
                if k not in dm:
                    is_candidate = False
                    continue
                if isinstance(v, int) or isinstance(v, float):
                    candidates_props[k].append(dm[k])
                    props_compare.add(k)
                elif v != dm[k]:
                    is_candidate = False
            if is_candidate:
                candidates.append(i)
            else:  # unroll changes
                for lst in candidates_props.values():
                    if len(lst) > len(candidates):
                        lst.pop()

        # find candidate closest to request
        if len(candidates) == 0:
            raise ValueError("No candidate found for given metadata.")
        idxlist = []
        for k in props_compare:
            idx = np.argmin(np.abs(np.array(candidates_props[k]) - kwargs[k]))
            idxlist.append(idx)
        if len(set(idxlist)) > 1:
            raise ValueError("Ambiguous selection request")
        elif len(idxlist) == 0:
            raise ValueError("No candidate found.")
        index = candidates[idxlist[0]]
        # tolerance check
        for k in props_compare:
            if not np.isclose(kwargs[k], self.metadata[index][k], rtol=reltol):
                msg = (
                    "Candidate does not match tolerance for %s (%s vs %s requested)"
                    % (
                        k,
                        self.metadata[index][k],
                        kwargs[k],
                    )
                )
                raise ValueError(msg)
        return self.get_dataset(index=index)

    @property
    def metadata(self):
        """
        Return metadata dictionary for this series.

        Returns
        -------
        Optional[dict]
            metadata dictionary
        """
        if self._metadata is not None:
            return self._metadata
        fp = self._metadatafile
        if os.path.exists(fp):
            md = json.load(open(fp, "r"))
            ikeys = sorted([int(k) for k in md.keys()])
            mdnew = {}
            for ik in ikeys:
                mdnew[ik] = md[str(ik)]
            self._metadata = mdnew
            return self._metadata
        return None

    @metadata.setter
    def metadata(self, dct):
        """
        Set metadata dictionary for this series, and save to disk.
        Parameters
        ----------
        dct: dict
            metadata dictionary

        Returns
        -------
        None

        """

        class ComplexEncoder(json.JSONEncoder):
            """
            JSON encoder that can handle numpy arrays and bytes.
            """

            def default(self, obj):
                """
                Default recipe for encoding objects.
                Parameters
                ----------
                obj: object
                    object to encode

                Returns
                -------
                object

                """
                if isinstance(obj, np.int64):
                    return int(obj)
                if isinstance(obj, np.int32):
                    return int(obj)
                if isinstance(obj, np.uint32):
                    return int(obj)
                if isinstance(obj, bytes):
                    return obj.decode("utf-8")
                if isinstance(obj, np.ndarray):
                    assert len(obj) < 1000  # do not want large obs here...
                    return list(obj)
                try:
                    return json.JSONEncoder.default(self, obj)
                except TypeError as e:
                    print("obj failing json encoding:", obj)
                    raise e

        self._metadata = dct
        fp = self._metadatafile
        if not os.path.exists(fp):
            json.dump(dct, open(fp, "w"), cls=ComplexEncoder)
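
A usage sketch; the snapshot paths are hypothetical, and selection by a metadata property (commented out) requires that property to be present in the cached per-dataset metadata:

from scida.interface import Dataset
from scida.series import DatasetSeries

paths = ["./output/snapdir_000", "./output/snapdir_001"]
series = DatasetSeries(paths, datasetclass=Dataset, lazy=True)

print(len(series))            # number of member datasets
series.info()                 # summary including metadata ranges
first = series[0]             # __getitem__ returns a dataset instance
snap = series.get_dataset(index=1)
# snap = series.get_dataset(redshift=2.0, reltol=1e-2)  # select by cached metadata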

data: None property

Dummy property to make user aware this is not a Dataset instance.

Returns:

Type Description
None

metadata property writable

Return metadata dictionary for this series.

Returns:

Type Description
Optional[dict]

metadata dictionary

__getitem__(key)

Return dataset by index.

Parameters:

Name Type Description Default
key
required

Returns:

Type Description
Dataset
Source code in src/scida/series.py
191
192
193
194
195
196
197
198
199
200
201
202
203
def __getitem__(self, key):
    """
    Return dataset by index.
    Parameters
    ----------
    key

    Returns
    -------
    Dataset

    """
    return self.datasets[key]

__init__(paths, *interface_args, datasetclass=None, overwrite_cache=False, lazy=True, names=None, **interface_kwargs)

Parameters:

Name Type Description Default
paths Union[List[str], List[Path]]

list of paths to data

required
interface_args

arguments to pass to interface class

()
datasetclass

class to use for dataset instances

None
overwrite_cache

whether to overwrite existing cache

False
lazy

whether to initialize datasets lazily

True
names

names for datasets

None
interface_kwargs

keyword arguments to pass to interface class

{}
Source code in src/scida/series.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def __init__(
    self,
    paths: Union[List[str], List[Path]],
    *interface_args,
    datasetclass=None,
    overwrite_cache=False,
    lazy=True,  # lazy will only initialize data sets on demand.
    names=None,
    **interface_kwargs
):
    """

    Parameters
    ----------
    paths: list
        list of paths to data
    interface_args:
        arguments to pass to interface class
    datasetclass:
        class to use for dataset instances
    overwrite_cache:
        whether to overwrite existing cache
    lazy:
        whether to initialize datasets lazily
    names:
        names for datasets
    interface_kwargs:
        keyword arguments to pass to interface class
    """
    self.paths = paths
    self.names = names
    self.hash = hash_path("".join([str(p) for p in paths]))
    self._metadata = None
    self._metadatafile = return_cachefile_path(os.path.join(self.hash, "data.json"))
    self.lazy = lazy
    if overwrite_cache and os.path.exists(self._metadatafile):
        os.remove(self._metadatafile)
    for p in paths:
        if not (isinstance(p, Path)):
            p = Path(p)
        if not (p.exists()):
            raise ValueError("Specified path '%s' does not exist." % p)
    dec = delay_init  # lazy loading

    # Catch Mixins and create type:
    ikw = dict(overwrite_cache=overwrite_cache)
    ikw.update(**interface_kwargs)
    mixins = ikw.pop("mixins", [])
    datasetclass = create_datasetclass_with_mixins(datasetclass, mixins)
    self._dataset_cls = datasetclass

    gen = map_interface_args(paths, *interface_args, **ikw)
    self.datasets = [dec(datasetclass)(p, *a, **kw) for p, a, kw in gen]

    if self.metadata is None:
        print("Have not cached this data series. Can take a while.")
        dct = {}
        for i, (path, d) in enumerate(
            tqdm(zip(self.paths, self.datasets), total=len(self.paths))
        ):
            rawmeta = load_metadata(
                path, choose_prefix=True, use_cachefile=not (overwrite_cache)
            )
            # class method does not initiate obj.
            dct[i] = d._clean_metadata_from_raw(rawmeta)
        self.metadata = dct
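
Putting the constructor arguments together, a construction sketch (the paths are placeholders and the Dataset import is an assumption; substitute the dataset class appropriate for your data):

from pathlib import Path
from scida.series import DatasetSeries
from scida.interface import Dataset  # assumption: adjust to the dataset class you actually use

paths = sorted(Path("/data/sim").glob("snapshot_*"))  # hypothetical member paths
series = DatasetSeries(
    paths,
    datasetclass=Dataset,    # class used to instantiate each member
    lazy=True,               # members are only initialized on first access
    overwrite_cache=False,   # reuse an existing metadata cache if present
)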

__init_subclass__(*args, **kwargs)

Register datasetseries subclass in registry.

Parameters:

Name Type Description Default
args

(unused)

()
kwargs

(unused)

{}

Returns:

Type Description
None
Source code in src/scida/series.py
def __init_subclass__(cls, *args, **kwargs):
    """
    Register datasetseries subclass in registry.
    Parameters
    ----------
    args:
        (unused)
    kwargs:
        (unused)
    Returns
    -------
    None
    """
    super().__init_subclass__(*args, **kwargs)
    dataseries_type_registry[cls.__name__] = cls
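
Registration happens automatically when the subclass is defined; a sketch (the registry import path is an assumption, adjust it if the registry lives elsewhere):

from scida.series import DatasetSeries
from scida.registries import dataseries_type_registry  # assumption: location of the registry

class MySeries(DatasetSeries):
    """Custom series type; no extra code is needed for registration."""

assert "MySeries" in dataseries_type_registry  # added by __init_subclass__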

__len__()

Return number of datasets in series.

Returns:

Type Description
int
Source code in src/scida/series.py
def __len__(self):
    """Return number of datasets in series.

    Returns
    -------
    int
    """
    return len(self.datasets)

__repr__()

Return a string representation of the datasetseries object.

Returns:

Type Description
str
Source code in src/scida/series.py
def __repr__(self) -> str:
    """
    Return a string representation of the datasetseries object.

    Returns
    -------
    str
    """
    props = self._repr_dict()
    clsname = self.__class__.__name__
    result = clsname + "["
    for k, v in props.items():
        result += "%s=%s, " % (k, v)
    result = result[:-2] + "]"
    return result

from_directory(path, *interface_args, datasetclass=None, pattern=None, **interface_kwargs) classmethod

Create a datasetseries instance from a directory.

Parameters:

Name Type Description Default
path

path to directory

required
interface_args

arguments to pass to interface class

()
datasetclass

force class to use for dataset instances

None
pattern

pattern to match files in directory

None
interface_kwargs

keyword arguments to pass to interface class

{}

Returns:

Type Description
DatasetSeries
Source code in src/scida/series.py
@classmethod
def from_directory(
    cls, path, *interface_args, datasetclass=None, pattern=None, **interface_kwargs
) -> "DatasetSeries":
    """
    Create a datasetseries instance from a directory.
    Parameters
    ----------
    path: str
        path to directory
    interface_args:
        arguments to pass to interface class
    datasetclass: Optional[Dataset]
        force class to use for dataset instances
    pattern:
        pattern to match files in directory
    interface_kwargs:
        keyword arguments to pass to interface class
    Returns
    -------
    DatasetSeries

    """
    p = Path(path)
    if not (p.exists()):
        raise ValueError("Specified path does not exist.")
    if pattern is None:
        pattern = "*"
    paths = [f for f in p.glob(pattern)]
    return cls(
        paths, *interface_args, datasetclass=datasetclass, **interface_kwargs
    )
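
Usage sketch (directory and glob pattern are hypothetical):

from scida.series import DatasetSeries

# collect every entry matching the pattern inside one directory
series = DatasetSeries.from_directory(
    "/data/sim/output",    # hypothetical directory
    pattern="snapshot_*",  # defaults to "*", i.e. all entries
)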

get_dataset(index=None, name=None, reltol=0.01, **kwargs)

Get dataset by some metadata property. In the base class, we go by list index.

Parameters:

Name Type Description Default
index Optional[int]

index of dataset to get

None
name Optional[str]

name of dataset to get

None
reltol

relative tolerance for metadata comparison

0.01
kwargs

metadata properties to compare for selection

{}

Returns:

Type Description
Dataset
Source code in src/scida/series.py
def get_dataset(
    self,
    index: Optional[int] = None,
    name: Optional[str] = None,
    reltol=1e-2,
    **kwargs
):
    """
    Get dataset by some metadata property. In the base class, we go by list index.

    Parameters
    ----------
    index: int
        index of dataset to get
    name: str
        name of dataset to get
    reltol:
        relative tolerance for metadata comparison
    kwargs:
        metadata properties to compare for selection

    Returns
    -------
    Dataset

    """
    if index is None and name is None and len(kwargs) == 0:
        raise ValueError("Specify index/name or some parameter to select for.")
    # aliases for index:
    aliases = ["snap", "snapshot"]
    aliases_given = [k for k in aliases if k in kwargs]
    if index is not None:
        aliases_given += [index]
    if len(aliases_given) > 1:
        raise ValueError("Multiple aliases for index specified.")
    for a in aliases_given:
        if kwargs.get(a) is not None:
            index = kwargs.pop(a)
    if index is not None:
        return self.datasets[index]

    if name is not None:
        if self.names is None:
            raise ValueError("No names specified for members of this series.")
        if name not in self.names:
            raise ValueError("Name %s not found in this series." % name)
        return self.datasets[self.names.index(name)]
    if len(kwargs) > 0 and self.metadata is None:
        if self.lazy:
            raise ValueError(
                "Cannot select by given keys before dataset evaluation."
            )
        raise ValueError("Unknown error.")  # should not happen?

    # find candidates from metadata
    candidates = []
    candidates_props = {}
    props_compare = set()  # save names of fields we want to compare
    for k, v in kwargs.items():
        candidates_props[k] = []
    for i, (j, dm) in enumerate(self.metadata.items()):
        assert int(i) == int(j)
        is_candidate = True
        for k, v in kwargs.items():
            if k not in dm:
                is_candidate = False
                continue
            if isinstance(v, int) or isinstance(v, float):
                candidates_props[k].append(dm[k])
                props_compare.add(k)
            elif v != dm[k]:
                is_candidate = False
        if is_candidate:
            candidates.append(i)
        else:  # unroll changes
            for lst in candidates_props.values():
                if len(lst) > len(candidates):
                    lst.pop()

    # find candidate closest to request
    if len(candidates) == 0:
        raise ValueError("No candidate found for given metadata.")
    idxlist = []
    for k in props_compare:
        idx = np.argmin(np.abs(np.array(candidates_props[k]) - kwargs[k]))
        idxlist.append(idx)
    if len(set(idxlist)) > 1:
        raise ValueError("Ambiguous selection request")
    elif len(idxlist) == 0:
        raise ValueError("No candidate found.")
    index = candidates[idxlist[0]]
    # tolerance check
    for k in props_compare:
        if not np.isclose(kwargs[k], self.metadata[index][k], rtol=reltol):
            msg = (
                "Candidate does not match tolerance for %s (%s vs %s requested)"
                % (
                    k,
                    self.metadata[index][k],
                    kwargs[k],
                )
            )
            raise ValueError(msg)
    return self.get_dataset(index=index)
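
Selection sketch, assuming series is an existing instance whose member metadata contains a numeric redshift entry (whether that key exists depends on the dataset type):

ds = series.get_dataset(index=3)                     # plain list index
ds = series.get_dataset(snapshot=3)                  # "snap"/"snapshot" are aliases for index
ds = series.get_dataset(redshift=2.0)                # nearest metadata match, checked against reltol
ds = series.get_dataset(redshift=2.0, reltol=0.05)   # loosen the tolerance check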

info()

Print information about this datasetseries.

Returns:

Type Description
None
Source code in src/scida/series.py
def info(self):
    """
    Print information about this datasetseries.

    Returns
    -------
    None
    """
    rep = ""
    rep += "class: " + sprint(self.__class__.__name__)
    props = self._repr_dict()
    for k, v in props.items():
        rep += sprint("%s: %s" % (k, v))
    if self.metadata is not None:
        rep += sprint("=== metadata ===")
        # we print the range of each metadata attribute
        minmax_dct = {}
        for mdct in self.metadata.values():
            for k, v in mdct.items():
                if k not in minmax_dct:
                    minmax_dct[k] = [v, v]
                else:
                    if not np.isscalar(v):
                        continue  # cannot compare arrays
                    minmax_dct[k][0] = min(minmax_dct[k][0], v)
                    minmax_dct[k][1] = max(minmax_dct[k][1], v)
        for k in minmax_dct:
            reprval1, reprval2 = minmax_dct[k][0], minmax_dct[k][1]
            if isinstance(reprval1, float):
                reprval1 = "%.2f" % reprval1
                reprval2 = "%.2f" % reprval2
            m1 = minmax_dct[k][0]
            m2 = minmax_dct[k][1]
            if (not np.isscalar(m1)) or (np.isscalar(m1) and m1 == m2):
                rep += sprint("%s: %s" % (k, minmax_dct[k][0]))
            else:
                rep += sprint(
                    "%s: %s -- %s" % (k, minmax_dct[k][0], minmax_dct[k][1])
                )
        rep += sprint("============")
    print(rep)
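
There is no return value; the summary is printed directly:

series.info()  # prints the class name, basic properties, and the range of each cached metadata entry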

validate_path(path, *args, **kwargs) classmethod

Check whether a given path is a valid path for this dataseries class.

Parameters:

Name Type Description Default
path

path to check

required
args

(unused)

()
kwargs

(unused)

{}

Returns:

Type Description
CandidateStatus
Source code in src/scida/series.py
@classmethod
def validate_path(cls, path, *args, **kwargs) -> CandidateStatus:
    """
    Check whether a given path is a valid path for this dataseries class.
    Parameters
    ----------
    path: str
        path to check
    args:
        (unused)
    kwargs:
        (unused)

    Returns
    -------
    CandidateStatus
    """
    return CandidateStatus.NO  # base class dummy
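
Subclasses override this to advertise which paths they can handle. A sketch, assuming CandidateStatus can be imported alongside DatasetSeries and exposes a YES member (both are assumptions based on the base implementation):

from scida.series import CandidateStatus, DatasetSeries  # assumption: CandidateStatus importable here

class NpzSeries(DatasetSeries):  # hypothetical subclass for illustration
    @classmethod
    def validate_path(cls, path, *args, **kwargs):
        # claim paths ending in ".npz"; reject everything else
        if str(path).endswith(".npz"):
            return CandidateStatus.YES
        return CandidateStatus.NO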

DirectoryCatalog

Bases: object

A catalog consisting of interface instances contained in a directory.

Source code in src/scida/series.py
class DirectoryCatalog(object):
    """A catalog consisting of interface instances contained in a directory."""

    def __init__(self, path):
        """
        Initialize a directory catalog.

        Parameters
        ----------
        path: str
            path to directory
        """
        self.path = path

__init__(path)

Initialize a directory catalog.

Parameters:

Name Type Description Default
path

path to directory

required
Source code in src/scida/series.py
def __init__(self, path):
    """
    Initialize a directory catalog.

    Parameters
    ----------
    path: str
        path to directory
    """
    self.path = path

HomogeneousSeries

Bases: DatasetSeries

Series consisting of same-type data sets.

Source code in src/scida/series.py
class HomogeneousSeries(DatasetSeries):
    """Series consisting of same-type data sets."""

    def __init__(self, path, **interface_kwargs):
        """
        Initialize a homogeneous series.
        Parameters
        ----------
        path:
            path to data
        interface_kwargs:
            keyword arguments to pass to interface class
        """
        super().__init__()

__init__(path, **interface_kwargs)

Initialize a homogeneous series.

Parameters:

Name Type Description Default
path

path to data

required
interface_kwargs

keyword arguments to pass to interface class

{}
Source code in src/scida/series.py
def __init__(self, path, **interface_kwargs):
    """
    Initialize a homogeneous series.
    Parameters
    ----------
    path:
        path to data
    interface_kwargs:
        keyword arguments to pass to interface class
    """
    super().__init__()

delay_init(cls)

Decorate class to delay initialization until an attribute is requested.

Parameters:

Name Type Description Default
cls

class to decorate

required

Returns:

Type Description
Delay
Source code in src/scida/series.py
def delay_init(cls):
    """
    Decorate class to delay initialization until an attribute is requested.
    Parameters
    ----------
    cls:
        class to decorate

    Returns
    -------
    Delay
    """

    class Delay(cls):
        """
        Delayed initialization of a class. The class is replaced by the actual class
        when an attribute is requested.
        """

        def __init__(self, *args, **kwargs):
            """Store arguments for later initialization."""
            self._args = args
            self._kwargs = kwargs

        def __getattribute__(self, name):
            """Replace the class with the actual class and initialize it if needed."""
            # a few special calls do not trigger initialization:
            specialattrs = [
                "__repr__",
                "__str__",
                "__dir__",
                "__class__",
                "_args",
                "_kwargs",
                # https://github.com/jupyter/notebook/issues/2014
                "_ipython_canary_method_should_not_exist_",
            ]
            if name in specialattrs or name.startswith("_repr"):
                return object.__getattribute__(self, name)
            elif hasattr(cls, name) and inspect.ismethod(getattr(cls, name)):
                # do not need to initialize for class methods
                return getattr(cls, name)
            arg = self._args
            kwarg = self._kwargs
            self.__class__ = cls
            del self._args
            del self._kwargs
            self.__init__(*arg, **kwarg)
            if name == "evaluate_lazy":
                return getattr(self, "__repr__")  # some dummy
            return getattr(self, name)

        def __repr__(self):
            """Return a string representation of the lazy class."""
            return "<Lazy %s>" % cls.__name__

    return Delay
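
A sketch of the deferral behaviour with a made-up class (Expensive exists only for illustration):

from scida.series import delay_init

class Expensive:
    def __init__(self, n):
        print("initializing")  # side effect to show when __init__ actually runs
        self.n = n

lazy_obj = delay_init(Expensive)(10)  # nothing printed yet; arguments are only stored
print(repr(lazy_obj))                 # "<Lazy Expensive>" -- still uninitialized
print(lazy_obj.n)                     # first real attribute access triggers __init__, then prints 10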

utilities

Some utility functions

copy_to_zarr(fp_in, fp_out, compressor=None)

Reads and converts a scida Dataset to a zarr object on disk

Parameters:

Name Type Description Default
fp_in

object path to convert

required
fp_out

output path

required
compressor

zarr compressor to use, see https://zarr.readthedocs.io/en/stable/tutorial.html#compressors

None

Returns:

Type Description
None
Source code in src/scida/utilities.py
def copy_to_zarr(fp_in, fp_out, compressor=None):
    """
    Reads and converts a scida Dataset to a zarr object on disk

    Parameters
    ----------
    fp_in: str
        object path to convert
    fp_out: str
        output path
    compressor:
        zarr compressor to use, see https://zarr.readthedocs.io/en/stable/tutorial.html#compressors

    Returns
    -------
    None
    """
    ds = Dataset(fp_in)
    compressor_dflt = zarr.storage.default_compressor
    zarr.storage.default_compressor = compressor
    ds.save(fp_out, cast_uints=True)
    zarr.storage.default_compressor = compressor_dflt
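
Conversion sketch (the paths are placeholders; the Blosc compressor from numcodecs is optional and can be omitted):

from numcodecs import Blosc
from scida.utilities import copy_to_zarr

copy_to_zarr(
    "/data/snapshot_099.hdf5",                # hypothetical input readable as a scida Dataset
    "/data/snapshot_099.zarr",                # output zarr store
    compressor=Blosc(cname="zstd", clevel=3),
)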