
Advanced API

This page collects the remaining package API documentation.

Scida is a Python package for reading and analyzing large scientific datasets.

config

Configuration handling.

combine_configs(configs, mode='overwrite_keys')

Combine multiple configurations recursively. Entries in the first configuration are replaced by entries from the later ones.

Parameters:

- configs (List[Dict], required): The list of configurations to combine.
- mode (str, default 'overwrite_keys'): The mode for combining the configurations. "overwrite_keys": overwrite keys in the first config with keys from the latter (default). "overwrite_values": overwrite values in the first config with values from the latter.

Returns:

- dict: The combined configuration.

Source code in src/scida/config.py
def combine_configs(configs: List[Dict], mode="overwrite_keys") -> Dict:
    """
    Combine multiple configurations recursively.
    Replacing entries in the first config with entries from the latter

    Parameters
    ----------
    configs: list
        The list of configurations to combine.
    mode: str
        The mode for combining the configurations.
        "overwrite_keys": overwrite keys in the first config with keys from the latter (default).
        "overwrite_values": overwrite values in the first config with values from the latter.

    Returns
    -------
    dict
        The combined configuration.
    """
    if mode == "overwrite_values":

        def mergefunc_values(a, b):
            """merge values"""
            if b is None:
                return a
            return b  # just overwrite by latter entry

        mergefunc_keys = None
    elif mode == "overwrite_keys":

        def mergefunc_keys(a, b):
            """merge keys"""
            if b is None:
                return a
            return b

        mergefunc_values = None
    else:
        raise ValueError("Unknown mode '%s'" % mode)
    conf = configs[0]
    for c in configs[1:]:
        merge_dicts_recursively(
            conf, c, mergefunc_keys=mergefunc_keys, mergefunc_values=mergefunc_values
        )
    return conf
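
A minimal usage sketch with made-up keys and paths, illustrating that with mode="overwrite_keys" a key present in a later config replaces the whole entry in the earlier one (the first config is modified in place):

from scida.config import combine_configs

base = {"data": {"somesim": {"path": "/data/somesim"}}, "nthreads": 4}
override = {"data": {"somesim": {"path": "/scratch/somesim"}}}

# Later configs take precedence; "nthreads" is untouched, "data" is replaced.
combined = combine_configs([base, override], mode="overwrite_keys")
print(combined["data"]["somesim"]["path"])  # -> /scratch/somesim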

copy_defaultconfig(overwrite=False)

Copy the configuration example to the user's home directory.

Parameters:

- overwrite (bool, default False): Overwrite an existing configuration file.

Returns:

- None
Source code in src/scida/config.py
def copy_defaultconfig(overwrite=False) -> None:
    """
    Copy the configuration example to the user's home directory.
    Parameters
    ----------
    overwrite: bool
        Overwrite existing configuration file.

    Returns
    -------
    None
    """

    path_user = os.path.expanduser("~")
    path_confdir = os.path.join(path_user, ".config/scida")
    if not os.path.exists(path_confdir):
        os.makedirs(path_confdir, exist_ok=True)
    path_conf = os.path.join(path_confdir, "config.yaml")
    if os.path.exists(path_conf) and not overwrite:
        raise ValueError("Configuration file already exists at '%s'" % path_conf)
    with importlib.resources.path("scida.configfiles", "config.yaml") as fp:
        with open(fp, "r") as file:
            content = file.read()
            with open(path_conf, "w") as newfile:
                newfile.write(content)
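
A hedged usage sketch; the target is ~/.config/scida/config.yaml, and a ValueError is raised if it already exists and overwrite is False:

from scida.config import copy_defaultconfig

try:
    copy_defaultconfig(overwrite=False)
except ValueError as e:
    print(e)  # configuration file already present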

get_config(reload=False, update_global=True)

Load the configuration from the default path.

Parameters:

- reload (bool, default False): Reload the configuration, even if it has already been loaded.
- update_global (bool, default True): Update the global configuration dictionary.

Returns:

- dict: The configuration dictionary.

Source code in src/scida/config.py
def get_config(reload: bool = False, update_global=True) -> dict:
    """
    Load the configuration from the default path.

    Parameters
    ----------
    reload: bool
        Reload the configuration, even if it has already been loaded.
    update_global: bool
        Update the global configuration dictionary.

    Returns
    -------
    dict
        The configuration dictionary.
    """
    global _conf
    prefix = "SCIDA_"
    envconf = {
        k.replace(prefix, "").lower(): v
        for k, v in os.environ.items()
        if k.startswith(prefix)
    }

    # in any case, we make sure that there is some config in the default path.
    path_confdir = _access_confdir()
    path_conf = os.path.join(path_confdir, "config.yaml")

    # next, we load the config from the default path, unless explicitly overridden.
    path = envconf.pop("config_path", None)
    if path is None:
        path = path_conf
    if not reload and len(_conf) > 0:
        return _conf
    config = get_config_fromfile(path)
    if config.get("copied_default", False):
        print(
            "Warning! Using default configuration. Please adjust/replace in '%s'."
            % path
        )

    config.update(**envconf)
    if update_global:
        _conf = config
    return config
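
A sketch of the environment-variable override visible in the code above (SCIDA_NTHREADS is a hypothetical key used only for illustration):

import os
from scida.config import get_config

# Variables prefixed with SCIDA_ are merged into the configuration with the
# prefix stripped and the key lower-cased; SCIDA_CONFIG_PATH overrides the
# config file location.
os.environ["SCIDA_NTHREADS"] = "8"
conf = get_config(reload=True)   # reload=True bypasses the cached global config
print(conf.get("nthreads"))      # -> "8" (environment values arrive as strings)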

get_config_fromfile(resource)

Load config from a YAML file.

Parameters:

- resource (str, required): The name of the resource or file path.

Returns:

- dict: The loaded configuration.
Source code in src/scida/config.py
def get_config_fromfile(resource: str) -> Dict:
    """
    Load config from a YAML file.
    Parameters
    ----------
    resource
        The name of the resource or file path.

    Returns
    -------
    dict
        The loaded configuration.
    """
    if resource == "":
        raise ValueError("Config name cannot be empty.")
    # order (in descending order of priority):
    # 1. absolute path?
    path = os.path.expanduser(resource)
    if os.path.isabs(path):
        with open(path, "r") as file:
            conf = yaml.safe_load(file)
        return conf
    # 2. non-absolute path?
    # 2.1. check ~/.config/scida/
    bpath = os.path.expanduser("~/.config/scida")
    path = os.path.join(bpath, resource)
    if os.path.isfile(path):
        with open(path, "r") as file:
            conf = yaml.safe_load(file)
        return conf
    # 2.2 check scida package resources
    resource_path = "scida.configfiles"
    resource_elements = resource.split("/")
    rname = resource_elements[-1]
    if len(resource_elements) > 1:
        resource_path += "." + ".".join(resource_elements[:-1])
    with importlib.resources.path(resource_path, rname) as fp:
        with open(fp, "r") as file:
            conf = yaml.safe_load(file)
    return conf
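
A short sketch of the resolution order implemented above (absolute/expanded path, then ~/.config/scida/, then packaged resources under scida.configfiles):

from scida.config import get_config_fromfile

# "simulations.yaml" ships with scida as a packaged resource.
conf = get_config_fromfile("simulations.yaml")
print(type(conf))  # <class 'dict'>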

get_config_fromfiles(paths, subconf_keys=None)

Load and merge multiple YAML config files

Parameters:

- paths (List[str], required): Paths to the config files.
- subconf_keys (Optional[List[str]], default None): The keys to the correct sub-configuration within each config.

Returns:

- dict: The merged configuration.
Source code in src/scida/config.py
def get_config_fromfiles(paths: List[str], subconf_keys: Optional[List[str]] = None):
    """
    Load and merge multiple YAML config files
    Parameters
    ----------
    paths
        Paths to the config files.
    subconf_keys
        The keys to the correct sub configuration within each config.

    Returns
    -------
    dict
        The merged configuration.
    """
    confs = []
    for path in paths:
        confs.append(get_config_fromfile(path))
    conf = {}
    for confdict in confs:
        conf = merge_dicts_recursively(conf, confdict)
    return conf
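
A minimal sketch (the file names are those shipped or created by scida, but any resolvable names work); the loaded dictionaries are merged in order via merge_dicts_recursively, so conflicting leaf values raise an exception:

from scida.config import get_config_fromfiles

conf = get_config_fromfiles(["config.yaml", "simulations.yaml"])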

get_simulationconfig()

Get the simulation configuration. Search regular user config file, scida simulation config file, and user's simulation config file.

Returns:

- dict: The simulation configuration.

Source code in src/scida/config.py
def get_simulationconfig():
    """
    Get the simulation configuration.
    Search regular user config file, scida simulation config file, and user's simulation config file.

    Returns
    -------
    dict
        The simulation configuration.
    """
    conf_user = get_config()
    conf_sims = get_config_fromfile("simulations.yaml")
    conf_sims_user = _get_simulationconfig_user()

    confs_base = [conf_sims, conf_user]
    if conf_sims_user is not None:
        confs_base.append(conf_sims_user)
    # we only want to overwrite keys within "data", otherwise no merging of simkeys would take place
    confs = []
    for c in confs_base:
        if "data" in c:
            confs.append(c["data"])

    conf_sims = combine_configs(confs, mode="overwrite_keys")
    conf_sims = {"data": conf_sims}

    return conf_sims
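
A brief usage sketch; per the code above, simulation entries are collected under the "data" key from the shipped simulations.yaml, the user config, and (if present) the user's simulation config:

from scida.config import get_simulationconfig

simconf = get_simulationconfig()
print(list(simconf["data"])[:5])  # first few known simulation keys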

merge_dicts_recursively(dict_a, dict_b, path=None, mergefunc_keys=None, mergefunc_values=None)

Merge two dictionaries recursively.

Parameters:

- dict_a (Dict, required): The first dictionary.
- dict_b (Dict, required): The second dictionary.
- path (Optional[List], default None): The path to the current node.
- mergefunc_keys (Optional[callable], default None): The function to use for merging dict keys. If None, we recursively enter the dictionary.
- mergefunc_values (Optional[callable], default None): The function to use for merging dict values. If None, collisions will raise an exception.

Returns:

- dict: The merged dictionary.
Source code in src/scida/config.py
def merge_dicts_recursively(
    dict_a: Dict,
    dict_b: Dict,
    path: Optional[List] = None,
    mergefunc_keys: Optional[callable] = None,
    mergefunc_values: Optional[callable] = None,
) -> Dict:
    """
    Merge two dictionaries recursively.
    Parameters
    ----------
    dict_a
        The first dictionary.
    dict_b
        The second dictionary.
    path
        The path to the current node.
    mergefunc_keys: callable
        The function to use for merging dict keys.
        If None, we recursively enter the dictionary.
    mergefunc_values: callable
        The function to use for merging dict values.
        If None, collisions will raise an exception.

    Returns
    -------
    dict
    """
    if path is None:
        path = []
    for key in dict_b:
        if key in dict_a:
            if mergefunc_keys is not None:
                dict_a[key] = mergefunc_keys(dict_a[key], dict_b[key])
            elif isinstance(dict_a[key], dict) and isinstance(dict_b[key], dict):
                merge_dicts_recursively(
                    dict_a[key],
                    dict_b[key],
                    path + [str(key)],
                    mergefunc_keys=mergefunc_keys,
                    mergefunc_values=mergefunc_values,
                )
            elif dict_a[key] == dict_b[key]:
                pass  # same leaf value
            else:
                if mergefunc_values is not None:
                    dict_a[key] = mergefunc_values(dict_a[key], dict_b[key])
                else:
                    raise Exception("Conflict at %s" % ".".join(path + [str(key)]))
        else:
            dict_a[key] = dict_b[key]
    return dict_a
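
A small sketch of the default behavior (dictionary contents are made up): nested dicts are merged key by key, identical leaves pass, and a differing leaf without a merge function raises:

from scida.config import merge_dicts_recursively

a = {"units": {"length": "kpc"}, "nthreads": 4}
b = {"units": {"mass": "Msun"}, "nthreads": 4}
merged = merge_dicts_recursively(a, b)
print(merged["units"])  # {'length': 'kpc', 'mass': 'Msun'}

try:
    merge_dicts_recursively({"nthreads": 4}, {"nthreads": 8})
except Exception as e:
    print(e)  # Conflict at nthreads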

convenience

download_and_extract(url, path, progressbar=True, overwrite=False)

Download and extract a file from a given url.

Parameters:

- url (str, required): The url to download from.
- path (pathlib.Path, required): The path to download to.
- progressbar (bool, default True): Whether to show a progress bar.
- overwrite (bool, default False): Whether to overwrite an existing file.

Returns:

- str: The path to the downloaded and extracted file(s).

Source code in src/scida/convenience.py
def download_and_extract(
    url: str, path: pathlib.Path, progressbar: bool = True, overwrite: bool = False
):
    """
    Download and extract a file from a given url.
    Parameters
    ----------
    url: str
        The url to download from.
    path: pathlib.Path
        The path to download to.
    progressbar: bool
        Whether to show a progress bar.
    overwrite: bool
        Whether to overwrite an existing file.
    Returns
    -------
    str
        The path to the downloaded and extracted file(s).
    """
    if path.exists() and not overwrite:
        raise ValueError("Target path '%s' already exists." % path)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        totlength = int(r.headers.get("content-length", 0))
        lread = 0
        t1 = time.time()
        with open(path, "wb") as f:
            for chunk in r.iter_content(chunk_size=2**22):  # chunks of 4MB
                t2 = time.time()
                f.write(chunk)
                lread += len(chunk)
                if progressbar:
                    rate = (lread / 2**20) / (t2 - t1)
                    sys.stdout.write(
                        "\rDownloaded %.2f/%.2f Megabytes (%.2f%%, %.2f MB/s)"
                        % (
                            lread / 2**20,
                            totlength / 2**20,
                            100.0 * lread / totlength,
                            rate,
                        )
                    )
                    sys.stdout.flush()
        sys.stdout.write("\n")
    tar = tarfile.open(path, "r:gz")
    for t in tar:
        if t.isdir():
            t.mode = int("0755", base=8)
        else:
            t.mode = int("0644", base=8)
    tar.extractall(path.parents[0])
    foldername = tar.getmembers()[0].name  # parent folder of extracted tar.gz
    tar.close()
    os.remove(path)
    return os.path.join(path.parents[0], foldername)
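
A hedged usage sketch with a placeholder URL; the remote file must be a gzipped tarball, and path is the local file it is downloaded to before extraction:

import pathlib
from scida.convenience import download_and_extract

url = "https://example.org/somedataset.tar.gz"        # hypothetical URL
target = pathlib.Path("/tmp/scida_download/somedataset.tar.gz")
target.parent.mkdir(parents=True, exist_ok=True)

# Returns the folder the archive was extracted into; the tarball is removed afterwards.
folder = download_and_extract(url, target, progressbar=True, overwrite=False)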

find_path(path, overwrite=False)

Find path to dataset.

Parameters:

- path (str, required): Path or identifier of the dataset; can be a local path, a remote URL, a "testdata://" name, a configured resource identifier, or a name resolved against the configured datafolders.
- overwrite (bool, default False): Only for remote datasets. Whether to overwrite an existing download.

Returns:

- str: The resolved local path.
Source code in src/scida/convenience.py
def find_path(path, overwrite=False) -> str:
    """
    Find path to dataset.

    Parameters
    ----------
    path: str
    overwrite: bool
        Only for remote datasets. Whether to overwrite an existing download.

    Returns
    -------
    str

    """
    config = get_config()
    path = os.path.expanduser(path)
    if os.path.exists(path):
        # datasets on disk
        pass
    elif len(path.split(":")) > 1:
        # check different alternative backends
        databackend = path.split("://")[0]
        dataname = path.split("://")[1]
        if databackend in ["http", "https"]:
            # dataset on the internet
            savepath = config.get("download_path", None)
            if savepath is None:
                print(
                    "Have not specified 'download_path' in config. Using 'cache_path' instead."
                )
                savepath = config.get("cache_path")
            savepath = os.path.expanduser(savepath)
            savepath = pathlib.Path(savepath)
            savepath.mkdir(parents=True, exist_ok=True)
            urlhash = str(
                int(hashlib.sha256(path.encode("utf-8")).hexdigest(), 16) % 10**8
            )
            savepath = savepath / ("download" + urlhash)
            filename = "archive.tar.gz"
            if not savepath.exists():
                os.makedirs(savepath, exist_ok=True)
            elif overwrite:
                # delete all files in folder
                for f in os.listdir(savepath):
                    fp = os.path.join(savepath, f)
                    if os.path.isfile(fp):
                        os.unlink(fp)
                    else:
                        shutil.rmtree(fp)
            foldercontent = [f for f in savepath.glob("*")]
            if len(foldercontent) == 0:
                savepath = savepath / filename
                extractpath = download_and_extract(
                    path, savepath, progressbar=True, overwrite=overwrite
                )
            else:
                extractpath = savepath
            extractpath = pathlib.Path(extractpath)

            # count folders in savepath
            nfolders = len([f for f in extractpath.glob("*") if f.is_dir()])
            nobjects = len([f for f in extractpath.glob("*") if f.is_dir()])
            if nobjects == 1 and nfolders == 1:
                extractpath = (
                    extractpath / [f for f in extractpath.glob("*") if f.is_dir()][0]
                )
            path = extractpath
        elif databackend == "testdata":
            path = get_testdata(dataname)
        else:
            # potentially custom dataset.
            resources = config.get("resources", {})
            if databackend not in resources:
                raise ValueError("Unknown resource '%s'" % databackend)
            r = resources[databackend]
            if dataname not in r:
                raise ValueError(
                    "Unknown dataset '%s' in resource '%s'" % (dataname, databackend)
                )
            path = os.path.expanduser(r[dataname]["path"])
    else:
        found = False
        if "datafolders" in config:
            for folder in config["datafolders"]:
                folder = os.path.expanduser(folder)
                if os.path.exists(os.path.join(folder, path)):
                    path = os.path.join(folder, path)
                    found = True
                    break
        if not found:
            raise ValueError("Specified path '%s' unknown." % path)
    return path
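
A sketch of the accepted path forms handled above (all names are hypothetical):

from scida.convenience import find_path

# 1. Existing local paths are expanded and returned as-is.
p1 = find_path("~/sims/somesim/output")

# 2. http(s) URLs are downloaded into the configured 'download_path'
#    (or 'cache_path') and extracted; the extracted folder is returned.
p2 = find_path("https://example.org/somedataset.tar.gz")

# 3. "testdata://NAME" resolves against 'testdata_path' from the configuration.
p3 = find_path("testdata://somedataset")

# 4. Other "backend://name" identifiers are looked up in the 'resources'
#    section of the configuration.
p4 = find_path("myresource://somedataset")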

get_dataset(name=None, props=None)

Get dataset by name or properties.

Parameters:

- name (Optional[str], default None): Name or alias of dataset.
- props (Optional[dict], default None): Properties to match.

Returns:

- str: Dataset name.

Source code in src/scida/convenience.py
def get_dataset(name=None, props=None):
    """
    Get dataset by name or properties.
    Parameters
    ----------
    name: Optional[str]
        Name or alias of dataset.
    props: Optional[dict]
        Properties to match.

    Returns
    -------
    str:
        Dataset name.

    """
    dnames = get_dataset_candidates(name=name, props=props)
    if len(dnames) > 1:
        raise ValueError("Too many dataset candidates.")
    elif len(dnames) == 0:
        raise ValueError("No dataset candidate found.")
    return dnames[0]

get_dataset_by_name(name)

Get dataset name from alias or name found in the configuration files.

Parameters:

- name (str, required): Name or alias of dataset.

Returns:

- Optional[str]: The dataset name, or None if no match is found.
Source code in src/scida/convenience.py
def get_dataset_by_name(name: str) -> Optional[str]:
    """
    Get dataset name from alias or name found in the configuration files.

    Parameters
    ----------
    name: str
        Name or alias of dataset.

    Returns
    -------
    str
    """
    dname = None
    c = get_config()
    if "datasets" not in c:
        return dname
    datasets = copy.deepcopy(c["datasets"])
    if name in datasets:
        dname = name
    else:
        # could still be an alias
        for k, v in datasets.items():
            if "aliases" not in v:
                continue
            if name in v["aliases"]:
                dname = k
                break
    return dname

get_dataset_candidates(name=None, props=None)

Get dataset candidates by name or properties.

Parameters:

- name (Optional[str], default None): Name or alias of dataset.
- props (Optional[dict], default None): Properties to match.

Returns:

- list[str]: List of candidate dataset names.

Source code in src/scida/convenience.py
def get_dataset_candidates(name=None, props=None):
    """
    Get dataset candidates by name or properties.

    Parameters
    ----------
    name: Optional[str]
        Name or alias of dataset.
    props: Optional[dict]
        Properties to match.

    Returns
    -------
    list[str]:
        List of candidate dataset names.

    """
    if name is not None:
        dnames = [get_dataset_by_name(name)]
        return dnames
    if props is not None:
        dnames = get_datasets_by_props(**props)
        return dnames
    raise ValueError("Need to specify name or properties.")

get_datasets_by_props(**kwargs)

Get dataset names by properties.

Parameters:

- kwargs (dict, default {}): Properties to match.

Returns:

- list[str]: List of dataset names.

Source code in src/scida/convenience.py
def get_datasets_by_props(**kwargs):
    """
    Get dataset names by properties.

    Parameters
    ----------
    kwargs: dict
        Properties to match.

    Returns
    -------
    list[str]:
        List of dataset names.
    """
    dnames = []
    c = get_config()
    if "datasets" not in c:
        return dnames
    datasets = copy.deepcopy(c["datasets"])
    for k, v in datasets.items():
        props = v.get("properties", {})
        match = True
        for pk, pv in kwargs.items():
            if pk not in props:
                match = False
                break
            if props[pk] != pv:
                match = False
                break
        if match:
            dnames.append(k)
    return dnames
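
A hedged sketch, assuming the user config defines datasets with a "properties" block (the property names below are made up):

from scida.convenience import get_datasets_by_props

# Returns every configured dataset whose "properties" block matches all given keywords.
names = get_datasets_by_props(code="arepo", res="high")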

get_testdata(name)

Get path to test data identifier.

Parameters:

- name (str, required): Name of test data.

Returns:

- str: Path to the requested test data.
Source code in src/scida/convenience.py
def get_testdata(name: str) -> str:
    """
    Get path to test data identifier.

    Parameters
    ----------
    name: str
        Name of test data.
    Returns
    -------
    str
    """
    config = get_config()
    tdpath: Optional[str] = config.get("testdata_path", None)
    if tdpath is None:
        raise ValueError("Test data directory not specified in configuration")
    if not os.path.isdir(tdpath):
        raise ValueError("Invalid test data path")
    res = {f: os.path.join(tdpath, f) for f in os.listdir(tdpath)}
    if name not in res.keys():
        raise ValueError("Specified test data not available.")
    return res[name]
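
A minimal sketch; it requires 'testdata_path' to be set in the configuration, and "somedataset" stands in for an entry (file or folder) inside that directory:

from scida.convenience import get_testdata

path = get_testdata("somedataset")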

load(path, units=True, unitfile='', overwrite=False, force_class=None, **kwargs)

Load a dataset or dataset series from a given path. This function will automatically determine the best-matching class to use and return the initialized instance.

Parameters:

- path (str, required): Path to dataset or dataset series. Usually the base folder containing all files of a given dataset/series.
- units (Union[bool, str], default True): Whether to load units.
- unitfile (str, default ''): Can explicitly pass path to a unitfile to use.
- overwrite (bool, default False): Whether to overwrite an existing cache.
- force_class (Optional[object], default None): Force a specific class to be used.
- kwargs (default {}): Additional keyword arguments to pass to the class.

Returns:

- Union[Dataset, DatasetSeries]: Initialized dataset or dataset series.

Source code in src/scida/convenience.py
def load(
    path: str,
    units: Union[bool, str] = True,
    unitfile: str = "",
    overwrite: bool = False,
    force_class: Optional[object] = None,
    **kwargs
):
    """
    Load a dataset or dataset series from a given path.
    This function will automatically determine the best-matching
    class to use and return the initialized instance.

    Parameters
    ----------
    path: str
        Path to dataset or dataset series. Usually the base folder containing all files of a given dataset/series.
    units: bool
        Whether to load units.
    unitfile: str
        Can explicitly pass path to a unitfile to use.
    overwrite: bool
        Whether to overwrite an existing cache.
    force_class: object
        Force a specific class to be used.
    kwargs: dict
        Additional keyword arguments to pass to the class.

    Returns
    -------
    Union[Dataset, DatasetSeries]:
        Initialized dataset or dataset series.
    """

    path = find_path(path, overwrite=overwrite)

    if "catalog" in kwargs:
        c = kwargs["catalog"]
        query_path = True
        query_path &= c is not None
        query_path &= not isinstance(c, bool)
        query_path &= not isinstance(c, list)
        query_path &= c != "none"
        if query_path:
            kwargs["catalog"] = find_path(c, overwrite=overwrite)

    # determine dataset class
    reg = dict()
    reg.update(**dataset_type_registry)
    reg.update(**dataseries_type_registry)

    path = os.path.realpath(path)
    cls = _determine_type(path, **kwargs)[1][0]

    msg = "Dataset is identified as '%s' via _determine_type." % cls
    log.debug(msg)

    # any identifying metadata?
    classtype = "dataset"
    if issubclass(cls, DatasetSeries):
        classtype = "series"
    cls_simconf = _determine_type_from_simconfig(path, classtype=classtype, reg=reg)

    if cls_simconf and not issubclass(cls, cls_simconf):
        oldcls = cls
        cls = cls_simconf
        if oldcls != cls:
            msg = "Dataset is identified as '%s' via the simulation config replacing prior candidate '%s'."
            log.debug(msg % (cls, oldcls))
        else:
            msg = "Dataset is identified as '%s' via the simulation config, identical to prior candidate."
            log.debug(msg % cls)

    if force_class is not None:
        cls = force_class

    # determine additional mixins not set by class
    mixins = []
    if hasattr(cls, "_unitfile"):
        unitfile = cls._unitfile
    if unitfile:
        if not units:
            units = True
        kwargs["unitfile"] = unitfile

    if units:
        mixins.append(UnitMixin)
        kwargs["units"] = units

    msg = "Inconsistent overwrite_cache, please only use 'overwrite' in load()."
    assert kwargs.get("overwrite_cache", overwrite) == overwrite, msg
    kwargs["overwrite_cache"] = overwrite

    # we append since unit mixin is added outside of this func right now
    metadata_raw = dict()
    if classtype == "dataset":
        metadata_raw = load_metadata(path, fileprefix=None)
    other_mixins = _determine_mixins(path=path, metadata_raw=metadata_raw)
    mixins += other_mixins

    log.debug("Adding mixins '%s' to dataset." % mixins)
    if hasattr(cls, "_mixins"):
        cls_mixins = cls._mixins
        for m in cls_mixins:
            # remove duplicates
            if m in mixins:
                mixins.remove(m)

    instance = cls(path, mixins=mixins, **kwargs)
    return instance
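
A short usage sketch (the path is hypothetical); load() resolves the path via find_path, picks the best-matching dataset or series class, and returns the initialized instance:

from scida import load

ds = load("~/sims/somesim/output/snapdir_099", units=True)

# Units can be disabled, or a custom unit file supplied explicitly:
ds_raw = load("~/sims/somesim/output/snapdir_099", units=False)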

customs

arepo

MTNG

dataset

Support for MTNG-Arepo datasets, see https://www.mtng-project.org/

MTNGArepoCatalog

Bases: ArepoCatalog

A dataset representing a MTNG-Arepo catalog.

Source code in src/scida/customs/arepo/MTNG/dataset.py
class MTNGArepoCatalog(ArepoCatalog):
    """
    A dataset representing a MTNG-Arepo catalog.
    """

    _fileprefix = "fof_subhalo_tab"

    def __init__(self, *args, **kwargs):
        """
        Initialize an MTNGArepoCatalog object.

        Parameters
        ----------
        args: list
        kwargs: dict
        """
        kwargs["iscatalog"] = True
        if "fileprefix" not in kwargs:
            kwargs["fileprefix"] = "fof_subhalo_tab"
        kwargs["choose_prefix"] = True
        super().__init__(*args, **kwargs)

    @classmethod
    def validate_path(
        cls, path: Union[str, os.PathLike], *args, **kwargs
    ) -> CandidateStatus:
        """
        Validate a path as a candidate for the MTNG-Arepo catalog class.

        Parameters
        ----------
        path: str
            Path to validate.
        args: list
        kwargs: dict

        Returns
        -------
        CandidateStatus
            Whether the path is a candidate for this dataset class.
        """
        tkwargs = dict(fileprefix=cls._fileprefix)
        tkwargs.update(**kwargs)
        valid = super().validate_path(path, *args, **tkwargs)
        if valid == CandidateStatus.NO:
            return valid
        metadata_raw = load_metadata(path, **tkwargs)
        if "/Config" not in metadata_raw:
            return CandidateStatus.NO
        if "MTNG" not in metadata_raw["/Config"]:
            return CandidateStatus.NO
        return CandidateStatus.YES
__init__(*args, **kwargs)

Initialize an MTNGArepoCatalog object.

Parameters:

- args (default ()): Positional arguments passed on to the parent class.
- kwargs (default {}): Keyword arguments passed on to the parent class.
Source code in src/scida/customs/arepo/MTNG/dataset.py
def __init__(self, *args, **kwargs):
    """
    Initialize an MTNGArepoCatalog object.

    Parameters
    ----------
    args: list
    kwargs: dict
    """
    kwargs["iscatalog"] = True
    if "fileprefix" not in kwargs:
        kwargs["fileprefix"] = "fof_subhalo_tab"
    kwargs["choose_prefix"] = True
    super().__init__(*args, **kwargs)
validate_path(path, *args, **kwargs) classmethod

Validate a path as a candidate for the MTNG-Arepo catalog class.

Parameters:

- path (Union[str, PathLike], required): Path to validate.
- args (default ()): Additional positional arguments.
- kwargs (default {}): Additional keyword arguments.

Returns:

- CandidateStatus: Whether the path is a candidate for this dataset class.

Source code in src/scida/customs/arepo/MTNG/dataset.py
@classmethod
def validate_path(
    cls, path: Union[str, os.PathLike], *args, **kwargs
) -> CandidateStatus:
    """
    Validate a path as a candidate for the MTNG-Arepo catalog class.

    Parameters
    ----------
    path: str
        Path to validate.
    args: list
    kwargs: dict

    Returns
    -------
    CandidateStatus
        Whether the path is a candidate for this dataset class.
    """
    tkwargs = dict(fileprefix=cls._fileprefix)
    tkwargs.update(**kwargs)
    valid = super().validate_path(path, *args, **tkwargs)
    if valid == CandidateStatus.NO:
        return valid
    metadata_raw = load_metadata(path, **tkwargs)
    if "/Config" not in metadata_raw:
        return CandidateStatus.NO
    if "MTNG" not in metadata_raw["/Config"]:
        return CandidateStatus.NO
    return CandidateStatus.YES
MTNGArepoSnapshot

Bases: ArepoSnapshot

MTNGArepoSnapshot is a snapshot class for the MTNG project.

Source code in src/scida/customs/arepo/MTNG/dataset.py
class MTNGArepoSnapshot(ArepoSnapshot):
    """
    MTNGArepoSnapshot is a snapshot class for the MTNG project.
    """

    _fileprefix_catalog = "fof_subhalo_tab"
    _fileprefix = "snapshot_"  # underscore is important!

    def __init__(self, path, chunksize="auto", catalog=None, **kwargs) -> None:
        """
        Initialize an MTNGArepoSnapshot object.

        Parameters
        ----------
        path: str
            Path to the snapshot folder, should contain "output" folder.
        chunksize: int
            Chunksize for the data.
        catalog: str
            Explicitly state catalog path to use.
        kwargs:
            Additional keyword arguments.
        """
        tkwargs = dict(
            fileprefix=self._fileprefix,
            fileprefix_catalog=self._fileprefix_catalog,
            catalog_cls=MTNGArepoCatalog,
        )
        tkwargs.update(**kwargs)

        # in MTNG, we have two kinds of snapshots:
        # 1. regular (prefix: "snapshot_"), contains all particle types
        # 2. mostbound (prefix: "snapshot-prevmostboundonly_"), only contains DM particles
        # Most snapshots are of type 2, but some selected snapshots have type 1 and type 2.

        # attempt to load regular snapshot
        super().__init__(path, chunksize=chunksize, catalog=catalog, **tkwargs)
        # need to add mtng unit peculiarities
        # later unit file takes precedence
        self._defaultunitfiles += ["units/mtng.yaml"]

        if tkwargs["fileprefix"] == "snapshot-prevmostboundonly_":
            # this is a mostbound snapshot, so we are done
            return

        # next, attempt to load mostbound snapshot. This is done by loading into sub-object.
        self.mostbound = None
        tkwargs.update(fileprefix="snapshot-prevmostboundonly_", catalog="none")
        self.mostbound = MTNGArepoSnapshot(path, chunksize=chunksize, **tkwargs)
        # hacky: remove unused containers from mostbound snapshot
        for k in [
            "PartType0",
            "PartType2",
            "PartType3",
            "PartType4",
            "PartType5",
            "Group",
            "Subhalo",
        ]:
            if k in self.mostbound.data:
                del self.mostbound.data[k]
        self.merge_data(self.mostbound, fieldname_suffix="_mostbound")

    @classmethod
    def validate_path(
        cls, path: Union[str, os.PathLike], *args, **kwargs
    ) -> CandidateStatus:
        """
        Validate a path as a candidate for the MTNG-Arepo snapshot class.

        Parameters
        ----------
        path:
            Path to validate.
        args:  list
        kwargs: dict

        Returns
        -------
        CandidateStatus
            Whether the path is a candidate for this dataset class.

        """
        tkwargs = dict(
            fileprefix=cls._fileprefix, fileprefix_catalog=cls._fileprefix_catalog
        )
        tkwargs.update(**kwargs)
        try:
            valid = super().validate_path(path, *args, **tkwargs)
        except ValueError:
            valid = CandidateStatus.NO
            # might raise ValueError in case of partial snap

        if valid == CandidateStatus.NO:
            # check for partial snap
            tkwargs.update(fileprefix="snapshot-prevmostboundonly_")
            try:
                valid = super().validate_path(path, *args, **tkwargs)
            except ValueError:
                valid = CandidateStatus.NO

        if valid == CandidateStatus.NO:
            return valid
        metadata_raw = load_metadata(path, **tkwargs)
        if "/Config" not in metadata_raw:
            return CandidateStatus.NO
        if "MTNG" not in metadata_raw["/Config"]:
            return CandidateStatus.NO
        if "Ngroups_Total" in metadata_raw["/Header"]:
            return CandidateStatus.NO  # this is a catalog
        return CandidateStatus.YES
__init__(path, chunksize='auto', catalog=None, **kwargs)

Initialize an MTNGArepoSnapshot object.

Parameters:

- path (str, required): Path to the snapshot folder, should contain "output" folder.
- chunksize (default 'auto'): Chunksize for the data.
- catalog (default None): Explicitly state catalog path to use.
- kwargs (default {}): Additional keyword arguments.
Source code in src/scida/customs/arepo/MTNG/dataset.py
def __init__(self, path, chunksize="auto", catalog=None, **kwargs) -> None:
    """
    Initialize an MTNGArepoSnapshot object.

    Parameters
    ----------
    path: str
        Path to the snapshot folder, should contain "output" folder.
    chunksize: int
        Chunksize for the data.
    catalog: str
        Explicitly state catalog path to use.
    kwargs:
        Additional keyword arguments.
    """
    tkwargs = dict(
        fileprefix=self._fileprefix,
        fileprefix_catalog=self._fileprefix_catalog,
        catalog_cls=MTNGArepoCatalog,
    )
    tkwargs.update(**kwargs)

    # in MTNG, we have two kinds of snapshots:
    # 1. regular (prefix: "snapshot_"), contains all particle types
    # 2. mostbound (prefix: "snapshot-prevmostboundonly_"), only contains DM particles
    # Most snapshots are of type 2, but some selected snapshots have type 1 and type 2.

    # attempt to load regular snapshot
    super().__init__(path, chunksize=chunksize, catalog=catalog, **tkwargs)
    # need to add mtng unit peculiarities
    # later unit file takes precedence
    self._defaultunitfiles += ["units/mtng.yaml"]

    if tkwargs["fileprefix"] == "snapshot-prevmostboundonly_":
        # this is a mostbound snapshot, so we are done
        return

    # next, attempt to load mostbound snapshot. This is done by loading into sub-object.
    self.mostbound = None
    tkwargs.update(fileprefix="snapshot-prevmostboundonly_", catalog="none")
    self.mostbound = MTNGArepoSnapshot(path, chunksize=chunksize, **tkwargs)
    # hacky: remove unused containers from mostbound snapshot
    for k in [
        "PartType0",
        "PartType2",
        "PartType3",
        "PartType4",
        "PartType5",
        "Group",
        "Subhalo",
    ]:
        if k in self.mostbound.data:
            del self.mostbound.data[k]
    self.merge_data(self.mostbound, fieldname_suffix="_mostbound")
validate_path(path, *args, **kwargs) classmethod

Validate a path as a candidate for the MTNG-Arepo snapshot class.

Parameters:

- path (Union[str, PathLike], required): Path to validate.
- args (default ()): Additional positional arguments.
- kwargs (default {}): Additional keyword arguments.

Returns:

- CandidateStatus: Whether the path is a candidate for this dataset class.

Source code in src/scida/customs/arepo/MTNG/dataset.py
@classmethod
def validate_path(
    cls, path: Union[str, os.PathLike], *args, **kwargs
) -> CandidateStatus:
    """
    Validate a path as a candidate for the MTNG-Arepo snapshot class.

    Parameters
    ----------
    path:
        Path to validate.
    args:  list
    kwargs: dict

    Returns
    -------
    CandidateStatus
        Whether the path is a candidate for this dataset class.

    """
    tkwargs = dict(
        fileprefix=cls._fileprefix, fileprefix_catalog=cls._fileprefix_catalog
    )
    tkwargs.update(**kwargs)
    try:
        valid = super().validate_path(path, *args, **tkwargs)
    except ValueError:
        valid = CandidateStatus.NO
        # might raise ValueError in case of partial snap

    if valid == CandidateStatus.NO:
        # check for partial snap
        tkwargs.update(fileprefix="snapshot-prevmostboundonly_")
        try:
            valid = super().validate_path(path, *args, **tkwargs)
        except ValueError:
            valid = CandidateStatus.NO

    if valid == CandidateStatus.NO:
        return valid
    metadata_raw = load_metadata(path, **tkwargs)
    if "/Config" not in metadata_raw:
        return CandidateStatus.NO
    if "MTNG" not in metadata_raw["/Config"]:
        return CandidateStatus.NO
    if "Ngroups_Total" in metadata_raw["/Header"]:
        return CandidateStatus.NO  # this is a catalog
    return CandidateStatus.YES
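
MTNG datasets are normally obtained through load() rather than by instantiating these classes directly; a hedged sketch with a hypothetical path (class detection goes through validate_path and the "MTNG" entry in /Config):

from scida import load

snap = load("/path/to/MTNG/output/snapdir_264")

# If a regular snapshot is loaded, the most-bound companion snapshot (prefix
# "snapshot-prevmostboundonly_") is attached as snap.mostbound and its fields
# are merged in with a "_mostbound" suffix.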

TNGcluster

dataset
TNGClusterSelector

Bases: Selector

Selector for TNGClusterSnapshot. Select a given zoom target via zoomID; set withfuzz=True to also include that target's "fuzz" particles, or onlyfuzz=True to return only the "fuzz" particles.

Source code in src/scida/customs/arepo/TNGcluster/dataset.py
class TNGClusterSelector(Selector):
    """
    Selector for TNGClusterSnapshot.  Can select for zoomID, which selects a given zoom target.
    Can specify withfuzz=True to include the "fuzz" particles for a given zoom target.
    Can specify onlyfuzz=True to only return the "fuzz" particles for a given zoom target.
    """

    def __init__(self) -> None:
        """
        Initialize the selector.
        """
        super().__init__()
        self.keys = ["zoomID", "withfuzz", "onlyfuzz"]

    def prepare(self, *args, **kwargs) -> None:
        """
        Prepare the selector.

        Parameters
        ----------
        args: list
        kwargs: dict

        Returns
        -------
        None
        """
        snap: TNGClusterSnapshot = args[0]
        zoom_id = kwargs.get("zoomID", None)
        fuzz = kwargs.get("withfuzz", None)
        onlyfuzz = kwargs.get("onlyfuzz", None)
        if zoom_id is None:
            return
        if zoom_id < 0 or zoom_id > (snap.ntargets - 1):
            raise ValueError("zoomID must be in range 0-%i" % (snap.ntargets - 1))

        for p in self.data_backup:
            if p.startswith("PartType"):
                key = "particles"
            elif p == "Group":
                key = "groups"
            elif p == "Subhalo":
                key = "subgroups"
            else:
                continue
            lengths = snap.lengths_zoom[key][zoom_id]
            offsets = snap.offsets_zoom[key][zoom_id]
            length_fuzz = None
            offset_fuzz = None

            if fuzz and key == "particles":  # fuzz files only for particles
                lengths_fuzz = snap.lengths_zoom[key][zoom_id + snap.ntargets]
                offsets_fuzz = snap.offsets_zoom[key][zoom_id + snap.ntargets]

                splt = p.split("PartType")
                pnum = int(splt[1])
                offset_fuzz = offsets_fuzz[pnum]
                length_fuzz = lengths_fuzz[pnum]

            if key == "particles":
                splt = p.split("PartType")
                pnum = int(splt[1])
                offset = offsets[pnum]
                length = lengths[pnum]
            else:
                offset = offsets
                length = lengths

            def get_slicedarr(
                v, offset, length, offset_fuzz, length_fuzz, key, fuzz=False
            ):
                """
                Get a sliced dask array for a given (length, offset) and (length_fuzz, offset_fuzz).

                Parameters
                ----------
                v: da.Array
                    The array to slice.
                offset: int
                length: int
                offset_fuzz: int
                length_fuzz: int
                key: str
                    ?
                fuzz: bool

                Returns
                -------
                da.Array
                    The sliced array.
                """
                arr = v[offset : offset + length]
                if offset_fuzz is not None:
                    arr_fuzz = v[offset_fuzz : offset_fuzz + length_fuzz]
                    if onlyfuzz:
                        arr = arr_fuzz
                    else:
                        arr = np.concatenate([arr, arr_fuzz])
                return arr

            def get_slicedfunc(
                func, offset, length, offset_fuzz, length_fuzz, key, fuzz=False
            ):
                """
                Slice a function's output for a given (length, offset) and (length_fuzz, offset_fuzz).

                Parameters
                ----------
                func: callable
                offset: int
                length: int
                offset_fuzz: int
                length_fuzz: int
                key: str
                fuzz: bool

                Returns
                -------
                callable
                    The sliced function.
                """

                def newfunc(
                    arrs, o=offset, ln=length, of=offset_fuzz, lnf=length_fuzz, **kwargs
                ):
                    arr_all = func(arrs, **kwargs)
                    arr = arr_all[o : o + ln]
                    if of is None:
                        return arr
                    arr_fuzz = arr_all[of : of + lnf]
                    if onlyfuzz:
                        return arr_fuzz
                    else:
                        return np.concatenate([arr, arr_fuzz])

                return newfunc

            # need to evaluate without recipes first
            for k, v in self.data_backup[p].items(withrecipes=False):
                self.data[p][k] = get_slicedarr(
                    v, offset, length, offset_fuzz, length_fuzz, key, fuzz
                )

            for k, v in self.data_backup[p].items(
                withfields=False, withrecipes=True, evaluate=False
            ):
                if not isinstance(v, FieldRecipe):
                    continue  # already evaluated, no need to port recipe (?)
                rcp: FieldRecipe = v
                func = get_slicedfunc(
                    v.func, offset, length, offset_fuzz, length_fuzz, key, fuzz
                )
                newrcp = DerivedFieldRecipe(rcp.name, func)
                newrcp.type = rcp.type
                newrcp.description = rcp.description
                newrcp.units = rcp.units
                self.data[p][k] = newrcp
        snap.data = self.data
__init__()

Initialize the selector.

Source code in src/scida/customs/arepo/TNGcluster/dataset.py
def __init__(self) -> None:
    """
    Initialize the selector.
    """
    super().__init__()
    self.keys = ["zoomID", "withfuzz", "onlyfuzz"]
prepare(*args, **kwargs)

Prepare the selector.

Parameters:

- args (default ()): Positional arguments; the first argument is the snapshot to prepare the selection for.
- kwargs (default {}): Selection keywords (zoomID, withfuzz, onlyfuzz).

Returns:

- None
Source code in src/scida/customs/arepo/TNGcluster/dataset.py
def prepare(self, *args, **kwargs) -> None:
    """
    Prepare the selector.

    Parameters
    ----------
    args: list
    kwargs: dict

    Returns
    -------
    None
    """
    snap: TNGClusterSnapshot = args[0]
    zoom_id = kwargs.get("zoomID", None)
    fuzz = kwargs.get("withfuzz", None)
    onlyfuzz = kwargs.get("onlyfuzz", None)
    if zoom_id is None:
        return
    if zoom_id < 0 or zoom_id > (snap.ntargets - 1):
        raise ValueError("zoomID must be in range 0-%i" % (snap.ntargets - 1))

    for p in self.data_backup:
        if p.startswith("PartType"):
            key = "particles"
        elif p == "Group":
            key = "groups"
        elif p == "Subhalo":
            key = "subgroups"
        else:
            continue
        lengths = snap.lengths_zoom[key][zoom_id]
        offsets = snap.offsets_zoom[key][zoom_id]
        length_fuzz = None
        offset_fuzz = None

        if fuzz and key == "particles":  # fuzz files only for particles
            lengths_fuzz = snap.lengths_zoom[key][zoom_id + snap.ntargets]
            offsets_fuzz = snap.offsets_zoom[key][zoom_id + snap.ntargets]

            splt = p.split("PartType")
            pnum = int(splt[1])
            offset_fuzz = offsets_fuzz[pnum]
            length_fuzz = lengths_fuzz[pnum]

        if key == "particles":
            splt = p.split("PartType")
            pnum = int(splt[1])
            offset = offsets[pnum]
            length = lengths[pnum]
        else:
            offset = offsets
            length = lengths

        def get_slicedarr(
            v, offset, length, offset_fuzz, length_fuzz, key, fuzz=False
        ):
            """
            Get a sliced dask array for a given (length, offset) and (length_fuzz, offset_fuzz).

            Parameters
            ----------
            v: da.Array
                The array to slice.
            offset: int
            length: int
            offset_fuzz: int
            length_fuzz: int
            key: str
                ?
            fuzz: bool

            Returns
            -------
            da.Array
                The sliced array.
            """
            arr = v[offset : offset + length]
            if offset_fuzz is not None:
                arr_fuzz = v[offset_fuzz : offset_fuzz + length_fuzz]
                if onlyfuzz:
                    arr = arr_fuzz
                else:
                    arr = np.concatenate([arr, arr_fuzz])
            return arr

        def get_slicedfunc(
            func, offset, length, offset_fuzz, length_fuzz, key, fuzz=False
        ):
            """
            Slice a function's output for a given (length, offset) and (length_fuzz, offset_fuzz).

            Parameters
            ----------
            func: callable
            offset: int
            length: int
            offset_fuzz: int
            length_fuzz: int
            key: str
            fuzz: bool

            Returns
            -------
            callable
                The sliced function.
            """

            def newfunc(
                arrs, o=offset, ln=length, of=offset_fuzz, lnf=length_fuzz, **kwargs
            ):
                arr_all = func(arrs, **kwargs)
                arr = arr_all[o : o + ln]
                if of is None:
                    return arr
                arr_fuzz = arr_all[of : of + lnf]
                if onlyfuzz:
                    return arr_fuzz
                else:
                    return np.concatenate([arr, arr_fuzz])

            return newfunc

        # need to evaluate without recipes first
        for k, v in self.data_backup[p].items(withrecipes=False):
            self.data[p][k] = get_slicedarr(
                v, offset, length, offset_fuzz, length_fuzz, key, fuzz
            )

        for k, v in self.data_backup[p].items(
            withfields=False, withrecipes=True, evaluate=False
        ):
            if not isinstance(v, FieldRecipe):
                continue  # already evaluated, no need to port recipe (?)
            rcp: FieldRecipe = v
            func = get_slicedfunc(
                v.func, offset, length, offset_fuzz, length_fuzz, key, fuzz
            )
            newrcp = DerivedFieldRecipe(rcp.name, func)
            newrcp.type = rcp.type
            newrcp.description = rcp.description
            newrcp.units = rcp.units
            self.data[p][k] = newrcp
    snap.data = self.data
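
The selector is applied through TNGClusterSnapshot.return_data(); a hedged usage sketch with a hypothetical path:

from scida import load

snap = load("/path/to/TNG-Cluster/output/snapdir_099")

# Restrict the returned fields to a single zoom target; withfuzz=True also
# includes that target's low-resolution "fuzz" particles.
data = snap.return_data(zoomID=3, withfuzz=True)
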
TNGClusterSnapshot

Bases: ArepoSnapshot

Dataset class for the TNG-Cluster simulation.

Source code in src/scida/customs/arepo/TNGcluster/dataset.py
class TNGClusterSnapshot(ArepoSnapshot):
    """
    Dataset class for the TNG-Cluster simulation.
    """

    _fileprefix_catalog = "fof_subhalo_tab_"
    _fileprefix = "snap_"
    ntargets = 352

    def __init__(self, *args, **kwargs):
        """
        Initialize a TNGClusterSnapshot object.

        Parameters
        ----------
        args
        kwargs
        """
        super().__init__(*args, **kwargs)

        # we can get the offsets from the header as our load routine concats the various values from different files
        # these offsets can be used to select a given halo.
        # see: https://tng-project.org/w/index.php/TNG-Cluster
        # each zoom-target has two entries i and i+N, where N=352 is the number of zoom targets.
        # the first file contains the particles that were contained in the original low-res run
        # the second file contains all other remaining particles in a given zoom target
        def len_to_offsets(lengths):
            """
            From the particle count (field length), get the offset of all zoom targets.

            Parameters
            ----------
            lengths

            Returns
            -------
            np.ndarray
            """
            lengths = np.array(lengths)
            shp = len(lengths.shape)
            n = lengths.shape[-1]
            if shp == 1:
                res = np.concatenate(
                    [np.zeros(1, dtype=np.int64), np.cumsum(lengths.astype(np.int64))]
                )[:-1]
            else:
                res = np.cumsum(
                    np.vstack([np.zeros(n, dtype=np.int64), lengths.astype(np.int64)]),
                    axis=0,
                )[:-1]
            return res

        self.lengths_zoom = dict(particles=self.header["NumPart_ThisFile"])
        self.offsets_zoom = dict(
            particles=len_to_offsets(self.lengths_zoom["particles"])
        )

        if hasattr(self, "catalog") and self.catalog is not None:
            self.lengths_zoom["groups"] = self.catalog.header["Ngroups_ThisFile"]
            self.offsets_zoom["groups"] = len_to_offsets(self.lengths_zoom["groups"])
            self.lengths_zoom["subgroups"] = self.catalog.header["Nsubgroups_ThisFile"]
            self.offsets_zoom["subgroups"] = len_to_offsets(
                self.lengths_zoom["subgroups"]
            )

    @TNGClusterSelector()
    def return_data(self) -> FieldContainer:
        """
        Return the data object.

        Returns
        -------
        FieldContainer
            The data object.
        """
        return super().return_data()

    @classmethod
    def validate_path(
        cls, path: Union[str, os.PathLike], *args, **kwargs
    ) -> CandidateStatus:
        """
        Validate a path as a candidate for TNG-Cluster snapshot class.

        Parameters
        ----------
        path: str
            Path to validate.
        args: list
        kwargs: dict

        Returns
        -------
        CandidateStatus
            Whether the path is a candidate for this simulation class.
        """

        tkwargs = dict(
            fileprefix=cls._fileprefix, fileprefix_catalog=cls._fileprefix_catalog
        )
        tkwargs.update(**kwargs)
        valid = super().validate_path(path, *args, **tkwargs)
        if valid == CandidateStatus.NO:
            return valid
        metadata_raw = load_metadata(path, **tkwargs)

        matchingattrs = True

        parameters = metadata_raw["/Parameters"]
        if "InitCondFile" in parameters:
            matchingattrs &= parameters["InitCondFile"] == "various"
        else:
            return CandidateStatus.NO
        header = metadata_raw["/Header"]
        matchingattrs &= header["BoxSize"] == 680000.0
        matchingattrs &= header["NumPart_Total"][1] == 1944529344
        matchingattrs &= header["NumPart_Total"][2] == 586952200

        if matchingattrs:
            valid = CandidateStatus.YES
        else:
            valid = CandidateStatus.NO
        return valid
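
The zoom offsets set up in __init__ above can also be used by hand. Below is a minimal sketch, not taken from the scida documentation: it assumes ds is an already loaded TNGClusterSnapshot instance, and the zoom-target index i and particle-type column pnum are arbitrary examples.

import numpy as np

i = 0                                    # zoom-target index, 0 <= i < ntargets
pnum = 1                                 # particle-type column, e.g. 1 for dark matter
ntargets = TNGClusterSnapshot.ntargets   # 352 zoom targets

lengths = np.asarray(ds.lengths_zoom["particles"])   # per-file particle counts
offsets = np.asarray(ds.offsets_zoom["particles"])   # cumulative offsets, same shape

# each zoom target owns two file entries: i (original low-res particles) and i + ntargets (the rest)
zoom_slices = [
    slice(int(offsets[j, pnum]), int(offsets[j, pnum] + lengths[j, pnum]))
    for j in (i, i + ntargets)
]
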
__init__(*args, **kwargs)

Initialize a TNGClusterSnapshot object.

Parameters:

Name Type Description Default
args
()
kwargs
{}
Source code in src/scida/customs/arepo/TNGcluster/dataset.py
def __init__(self, *args, **kwargs):
    """
    Initialize a TNGClusterSnapshot object.

    Parameters
    ----------
    args
    kwargs
    """
    super().__init__(*args, **kwargs)

    # we can get the offsets from the header, as our load routine concatenates the values from the different files
    # these offsets can be used to select a given halo.
    # see: https://tng-project.org/w/index.php/TNG-Cluster
    # each zoom-target has two entries i and i+N, where N=352 is the number of zoom targets.
    # the first file contains the particles that were contained in the original low-res run
    # the second file contains all other remaining particles in a given zoom target
    def len_to_offsets(lengths):
        """
        From the particle count (field length), get the offset of all zoom targets.

        Parameters
        ----------
        lengths

        Returns
        -------
        np.ndarray
        """
        lengths = np.array(lengths)
        shp = len(lengths.shape)
        n = lengths.shape[-1]
        if shp == 1:
            res = np.concatenate(
                [np.zeros(1, dtype=np.int64), np.cumsum(lengths.astype(np.int64))]
            )[:-1]
        else:
            res = np.cumsum(
                np.vstack([np.zeros(n, dtype=np.int64), lengths.astype(np.int64)]),
                axis=0,
            )[:-1]
        return res

    self.lengths_zoom = dict(particles=self.header["NumPart_ThisFile"])
    self.offsets_zoom = dict(
        particles=len_to_offsets(self.lengths_zoom["particles"])
    )

    if hasattr(self, "catalog") and self.catalog is not None:
        self.lengths_zoom["groups"] = self.catalog.header["Ngroups_ThisFile"]
        self.offsets_zoom["groups"] = len_to_offsets(self.lengths_zoom["groups"])
        self.lengths_zoom["subgroups"] = self.catalog.header["Nsubgroups_ThisFile"]
        self.offsets_zoom["subgroups"] = len_to_offsets(
            self.lengths_zoom["subgroups"]
        )
return_data()

Return the data object.

Returns:

Type Description
FieldContainer

The data object.

Source code in src/scida/customs/arepo/TNGcluster/dataset.py
@TNGClusterSelector()
def return_data(self) -> FieldContainer:
    """
    Return the data object.

    Returns
    -------
    FieldContainer
        The data object.
    """
    return super().return_data()
validate_path(path, *args, **kwargs) classmethod

Validate a path as a candidate for TNG-Cluster snapshot class.

Parameters:

Name Type Description Default
path Union[str, PathLike]

Path to validate.

required
args
()
kwargs
{}

Returns:

Type Description
CandidateStatus

Whether the path is a candidate for this simulation class.

Source code in src/scida/customs/arepo/TNGcluster/dataset.py
@classmethod
def validate_path(
    cls, path: Union[str, os.PathLike], *args, **kwargs
) -> CandidateStatus:
    """
    Validate a path as a candidate for TNG-Cluster snapshot class.

    Parameters
    ----------
    path: str
        Path to validate.
    args: list
    kwargs: dict

    Returns
    -------
    CandidateStatus
        Whether the path is a candidate for this simulation class.
    """

    tkwargs = dict(
        fileprefix=cls._fileprefix, fileprefix_catalog=cls._fileprefix_catalog
    )
    tkwargs.update(**kwargs)
    valid = super().validate_path(path, *args, **tkwargs)
    if valid == CandidateStatus.NO:
        return valid
    metadata_raw = load_metadata(path, **tkwargs)

    matchingattrs = True

    parameters = metadata_raw["/Parameters"]
    if "InitCondFile" in parameters:
        matchingattrs &= parameters["InitCondFile"] == "various"
    else:
        return CandidateStatus.NO
    header = metadata_raw["/Header"]
    matchingattrs &= header["BoxSize"] == 680000.0
    matchingattrs &= header["NumPart_Total"][1] == 1944529344
    matchingattrs &= header["NumPart_Total"][2] == 586952200

    if matchingattrs:
        valid = CandidateStatus.YES
    else:
        valid = CandidateStatus.NO
    return valid
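
As a hedged sketch of how this validation is typically exercised (the path below is a placeholder):

from scida.customs.arepo.TNGcluster.dataset import TNGClusterSnapshot

path = "/data/TNG-Cluster/output/snapdir_099"   # hypothetical path
status = TNGClusterSnapshot.validate_path(path)
if status.name != "NO":                          # CandidateStatus.YES or CandidateStatus.MAYBE
    ds = TNGClusterSnapshot(path)
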

dataset

ArepoCatalog

Bases: ArepoSnapshot

Dataset class for Arepo group catalogs.

Source code in src/scida/customs/arepo/dataset.py
class ArepoCatalog(ArepoSnapshot):
    """
    Dataset class for Arepo group catalogs.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize an ArepoCatalog object.

        Parameters
        ----------
        args
        kwargs
        """
        kwargs["iscatalog"] = True
        if "fileprefix" not in kwargs:
            kwargs["fileprefix"] = "groups"
        super().__init__(*args, **kwargs)

    @classmethod
    def validate_path(
        cls, path: Union[str, os.PathLike], *args, **kwargs
    ) -> CandidateStatus:
        """
        Validate a path to use for instantiation of this class.

        Parameters
        ----------
        path: str or pathlib.Path
        args:
        kwargs:

        Returns
        -------
        CandidateStatus
        """
        kwargs["fileprefix"] = cls._get_fileprefix(path)
        valid = super().validate_path(path, *args, expect_grp=True, **kwargs)
        return valid
__init__(*args, **kwargs)

Initialize an ArepoCatalog object.

Parameters:

Name Type Description Default
args
()
kwargs
{}
Source code in src/scida/customs/arepo/dataset.py
def __init__(self, *args, **kwargs):
    """
    Initialize an ArepoCatalog object.

    Parameters
    ----------
    args
    kwargs
    """
    kwargs["iscatalog"] = True
    if "fileprefix" not in kwargs:
        kwargs["fileprefix"] = "groups"
    super().__init__(*args, **kwargs)
validate_path(path, *args, **kwargs) classmethod

Validate a path to use for instantiation of this class.

Parameters:

Name Type Description Default
path Union[str, PathLike]
required
args
()
kwargs
{}

Returns:

Type Description
CandidateStatus
Source code in src/scida/customs/arepo/dataset.py
@classmethod
def validate_path(
    cls, path: Union[str, os.PathLike], *args, **kwargs
) -> CandidateStatus:
    """
    Validate a path to use for instantiation of this class.

    Parameters
    ----------
    path: str or pathlib.Path
    args:
    kwargs:

    Returns
    -------
    CandidateStatus
    """
    kwargs["fileprefix"] = cls._get_fileprefix(path)
    valid = super().validate_path(path, *args, expect_grp=True, **kwargs)
    return valid
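
For orientation, a short sketch of what the constructor above amounts to: it forces iscatalog=True and falls back to the "groups" file prefix, so a group-catalog directory can be opened directly (the path is a placeholder):

from scida.customs.arepo.dataset import ArepoCatalog

cat = ArepoCatalog("/data/sims/run/output/groups_099")   # hypothetical path
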
ArepoSnapshot

Bases: SpatialCartesian3DMixin, GadgetStyleSnapshot

Dataset class for Arepo snapshots.

Source code in src/scida/customs/arepo/dataset.py
class ArepoSnapshot(SpatialCartesian3DMixin, GadgetStyleSnapshot):
    """
    Dataset class for Arepo snapshots.
    """

    _fileprefix_catalog = "groups"

    def __init__(self, path, chunksize="auto", catalog=None, **kwargs) -> None:
        """
        Initialize an ArepoSnapshot object.

        Parameters
        ----------
        path: str or pathlib.Path
            Path to snapshot, typically a directory containing multiple hdf5 files.
        chunksize:
            Chunksize to use for dask arrays. Can be "auto" to automatically determine chunksize.
        catalog:
            Path to group catalog. If None, the catalog is searched for in the parent directories.
        kwargs:
            Additional keyword arguments.
        """
        self.iscatalog = kwargs.pop("iscatalog", False)
        self.header = {}
        self.config = {}
        self._defaultunitfiles: List[str] = ["units/gadget_cosmological.yaml"]
        self.parameters = {}
        self._grouplengths = {}
        self._subhalolengths = {}
        # not needed for group catalogs as entries are back-to-back there, we will provide a property for this
        self._subhalooffsets = {}
        self.misc = {}  # for storing misc info
        prfx = kwargs.pop("fileprefix", None)
        if prfx is None:
            prfx = self._get_fileprefix(path)
        super().__init__(path, chunksize=chunksize, fileprefix=prfx, **kwargs)

        self.catalog = catalog
        if not self.iscatalog:
            if self.catalog is None:
                self.discover_catalog()
                # try to discover group catalog in parent directories.
            if self.catalog == "none":
                pass  # this string can be set to explicitly disable catalog
            elif self.catalog:
                catalog_cls = kwargs.get("catalog_cls", None)
                cosmological = False
                if hasattr(self, "_mixins") and "cosmology" in self._mixins:
                    cosmological = True
                self.load_catalog(
                    overwrite_cache=kwargs.get("overwrite_cache", False),
                    catalog_cls=catalog_cls,
                    units=self.withunits,
                    cosmological=cosmological,
                )

        # add aliases
        aliases = dict(
            PartType0=["gas", "baryons"],
            PartType1=["dm", "dark matter"],
            PartType2=["lowres", "lowres dm"],
            PartType3=["tracer", "tracers"],
            PartType4=["stars"],
            PartType5=["bh", "black holes"],
        )
        for k, lst in aliases.items():
            if k not in self.data:
                continue
            for v in lst:
                self.data.add_alias(v, k)

        # set metadata
        self._set_metadata()

        # add some default fields
        self.data.merge(fielddefs)

    def load_catalog(
        self, overwrite_cache=False, units=False, cosmological=False, catalog_cls=None
    ):
        """
        Load the group catalog.

        Parameters
        ----------
        overwrite_cache: bool
            Overwrite an existing catalog cache.
        units: bool
            Load the catalog with units attached.
        cosmological: bool
            Attach the cosmology mixin to the catalog class.
        catalog_cls: type
            Class to use for the catalog. If None, the default catalog class is used.

        Returns
        -------
        None
        """
        virtualcache = False  # copy catalog for better performance
        # fileprefix = catalog_kwargs.get("fileprefix", self._fileprefix_catalog)
        prfx = self._get_fileprefix(self.catalog)

        # explicitly need to create unitaware class for catalog as needed
        # TODO: should just be determined from mixins of parent?
        if catalog_cls is None:
            cls = ArepoCatalog
        else:
            cls = catalog_cls
        withunits = units
        mixins = []
        if withunits:
            mixins += [UnitMixin]

        other_mixins = _determine_mixins(path=self.path)
        mixins += other_mixins
        if cosmological and CosmologyMixin not in mixins:
            mixins.append(CosmologyMixin)

        cls = create_datasetclass_with_mixins(cls, mixins)

        ureg = None
        if hasattr(self, "ureg"):
            ureg = self.ureg

        self.catalog = cls(
            self.catalog,
            overwrite_cache=overwrite_cache,
            virtualcache=virtualcache,
            fileprefix=prfx,
            units=self.withunits,
            ureg=ureg,
        )
        if "Redshift" in self.catalog.header and "Redshift" in self.header:
            z_catalog = self.catalog.header["Redshift"]
            z_snap = self.header["Redshift"]
            if not np.isclose(z_catalog, z_snap):
                raise ValueError(
                    "Redshift mismatch between snapshot and catalog: "
                    f"{z_snap:.2f} vs {z_catalog:.2f}"
                )

        # merge data
        self.merge_data(self.catalog)

        # first snapshots often do not have groups
        if "Group" in self.catalog.data:
            ngkeys = self.catalog.data["Group"].keys()
            if len(ngkeys) > 0:
                self.add_catalogIDs()

        # merge hints from snap and catalog
        self.merge_hints(self.catalog)

    @classmethod
    def validate_path(
        cls, path: Union[str, os.PathLike], *args, **kwargs
    ) -> CandidateStatus:
        """
        Validate a path to use for instantiation of this class.

        Parameters
        ----------
        path: str or pathlib.Path
        args:
        kwargs:

        Returns
        -------
        CandidateStatus
        """
        valid = super().validate_path(path, *args, **kwargs)
        if valid.value > CandidateStatus.MAYBE.value:
            valid = CandidateStatus.MAYBE
        else:
            return valid
        # Arepo has no dedicated attribute to identify such runs,
        # so let's just query a bunch of attributes that are present for Arepo runs
        metadata_raw = load_metadata(path, **kwargs)
        matchingattrs = True
        matchingattrs &= "Git_commit" in metadata_raw["/Header"]
        # not existent for any arepo run?
        matchingattrs &= "Compactify_Version" not in metadata_raw["/Header"]

        if matchingattrs:
            valid = CandidateStatus.MAYBE

        return valid

    @ArepoSelector()
    def return_data(self) -> FieldContainer:
        """
        Return data object of this snapshot.

        Returns
        -------
        FieldContainer
            The data object.
        """
        return super().return_data()

    def discover_catalog(self):
        """
        Discover the group catalog given the current path

        Returns
        -------
        None
        """
        p = str(self.path)
        # order of candidates matters. For Illustris "groups" must precede "fof_subhalo_tab"
        candidates = [
            p.replace("snapshot", "group"),
            p.replace("snapshot", "groups"),
            p.replace("snapdir", "groups").replace("snap", "groups"),
            p.replace("snapdir", "groups").replace("snap", "fof_subhalo_tab"),
        ]
        for candidate in candidates:
            if not os.path.exists(candidate):
                continue
            if candidate == self.path:
                continue
            self.catalog = candidate
            break

    def register_field(self, parttype: str, name: str = None, construct: bool = False):
        """
        Register a field.
        Parameters
        ----------
        parttype: str
            name of particle type
        name: str
            name of field
        construct: bool
            construct field immediately

        Returns
        -------
        None
        """
        num = part_type_num(parttype)
        if construct:  # TODO: introduce (immediate) construct option later
            raise NotImplementedError
        if num == -1:  # TODO: all particle species
            key = "all"
            raise NotImplementedError
        elif isinstance(num, int):
            key = "PartType" + str(num)
        else:
            key = parttype
        return super().register_field(key, name=name)

    def add_catalogIDs(self) -> None:
        """
        Add field for halo and subgroup IDs for all particle types.

        Returns
        -------
        None
        """
        # TODO: make these delayed objects and properly pass into (delayed?) numba functions:
        # https://docs.dask.org/en/stable/delayed-best-practices.html#avoid-repeatedly-putting-large-inputs-into-delayed-calls

        maxint = np.iinfo(np.int64).max
        self.misc["unboundID"] = maxint

        # Group ID
        if "Group" not in self.data:  # can happen for empty catalogs
            for key in self.data:
                if not (key.startswith("PartType")):
                    continue
                uid = self.data[key]["uid"]
                self.data[key]["GroupID"] = self.misc["unboundID"] * da.ones_like(
                    uid, dtype=np.int64
                )
                self.data[key]["SubhaloID"] = self.misc["unboundID"] * da.ones_like(
                    uid, dtype=np.int64
                )
            return

        glen = self.data["Group"]["GroupLenType"]
        ngrp = glen.shape[0]
        da_halocelloffsets = da.concatenate(
            [
                np.zeros((1, 6), dtype=np.int64),
                da.cumsum(glen, axis=0, dtype=np.int64),
            ]
        )
        # remove last entry to match shape
        self.data["Group"]["GroupOffsetsType"] = da_halocelloffsets[:-1].rechunk(
            glen.chunks
        )
        halocelloffsets = da_halocelloffsets.rechunk(-1)

        index_unbound = self.misc["unboundID"]

        for key in self.data:
            if not (key.startswith("PartType")):
                continue
            num = int(key[-1])
            if "uid" not in self.data[key]:
                continue  # can happen for empty containers
            gidx = self.data[key]["uid"]
            hidx = compute_haloindex(
                gidx, halocelloffsets[:, num], index_unbound=index_unbound
            )
            self.data[key]["GroupID"] = hidx

        # Subhalo ID
        if "Subhalo" not in self.data:  # can happen for empty catalogs
            for key in self.data:
                if not (key.startswith("PartType")):
                    continue
                self.data[key]["SubhaloID"] = -1 * da.ones_like(
                    da[key]["uid"], dtype=np.int64
                )
            return

        shnr_attr = "SubhaloGrNr"
        if shnr_attr not in self.data["Subhalo"]:
            shnr_attr = "SubhaloGroupNr"  # what MTNG does
        if shnr_attr not in self.data["Subhalo"]:
            raise ValueError(
                f"Could not find 'SubhaloGrNr' or 'SubhaloGroupNr' in {self.catalog}"
            )

        subhalogrnr = self.data["Subhalo"][shnr_attr]
        subhalocellcounts = self.data["Subhalo"]["SubhaloLenType"]

        # remove "units" for numba funcs
        if hasattr(subhalogrnr, "magnitude"):
            subhalogrnr = subhalogrnr.magnitude
        if hasattr(subhalocellcounts, "magnitude"):
            subhalocellcounts = subhalocellcounts.magnitude

        grp = self.data["Group"]
        if "GroupFirstSub" not in grp or "GroupNsubs" not in grp:
            # if not provided, we calculate:
            # "GroupFirstSub": First subhalo index for each halo
            # "GroupNsubs": Number of subhalos for each halo
            dlyd = delayed(get_shcounts_shcells)(subhalogrnr, ngrp)
            grp["GroupFirstSub"] = dask.compute(dlyd[1])[0]
            grp["GroupNsubs"] = dask.compute(dlyd[0])[0]

        # remove "units" for numba funcs
        grpfirstsub = grp["GroupFirstSub"]
        if hasattr(grpfirstsub, "magnitude"):
            grpfirstsub = grpfirstsub.magnitude
        grpnsubs = grp["GroupNsubs"]
        if hasattr(grpnsubs, "magnitude"):
            grpnsubs = grpnsubs.magnitude

        for key in self.data:
            if not (key.startswith("PartType")):
                continue
            num = int(key[-1])
            pdata = self.data[key]
            if "uid" not in self.data[key]:
                continue  # can happen for empty containers
            gidx = pdata["uid"]

            # we need to make the other dask arrays delayed,
            # so that map_blocks does not incorrectly infer the output shape from them
            halocelloffsets_dlyd = delayed(halocelloffsets[:, num])
            grpfirstsub_dlyd = delayed(grpfirstsub)
            grpnsubs_dlyd = delayed(grpnsubs)
            subhalocellcounts_dlyd = delayed(subhalocellcounts[:, num])

            sidx = compute_localsubhaloindex(
                gidx,
                halocelloffsets_dlyd,
                grpfirstsub_dlyd,
                grpnsubs_dlyd,
                subhalocellcounts_dlyd,
                index_unbound=index_unbound,
            )

            pdata["LocalSubhaloID"] = sidx

            # reconstruct SubhaloID from Group's GroupFirstSub and LocalSubhaloID
            # should be easier to do it directly, but quicker to write down like this:

            # calculate first subhalo of each halo that a particle belongs to
            self.add_groupquantity_to_particles("GroupFirstSub", parttype=key)
            pdata["SubhaloID"] = pdata["GroupFirstSub"] + pdata["LocalSubhaloID"]
            pdata["SubhaloID"] = da.where(
                pdata["SubhaloID"] == index_unbound, index_unbound, pdata["SubhaloID"]
            )

    @computedecorator
    def map_group_operation(
        self,
        func,
        cpucost_halo=1e4,
        nchunks_min=None,
        chunksize_bytes=None,
        nmax=None,
        idxlist=None,
        objtype="halo",
    ):
        """
        Apply a function to each halo in the catalog.

        Parameters
        ----------
        objtype: str
            Type of object to process. Can be "halo" or "subhalo". Default: "halo"
        idxlist: Optional[np.ndarray]
            List of halo indices to process. If not provided, all halos are processed.
        func: function
            Function to apply to each halo. Must take a dictionary of arrays as input.
        cpucost_halo:
            "CPU cost" of processing a single halo, relative to the processing time per input particle;
            used for calculating the dask chunks. Default: 1e4
        nchunks_min: Optional[int]
            Minimum number of particles in a halo to process it. Default: None
        chunksize_bytes: Optional[int]
        nmax: Optional[int]
            Only process the first nmax halos.

        Returns
        -------
        None
        """
        dfltkwargs = get_kwargs(func)
        fieldnames = dfltkwargs.get("fieldnames", None)
        if fieldnames is None:
            fieldnames = get_args(func)
        parttype = dfltkwargs.get("parttype", "PartType0")
        entry_nbytes_in = np.sum([self.data[parttype][f][0].nbytes for f in fieldnames])
        objtype = grp_type_str(objtype)
        if objtype == "halo":
            lengths = self.get_grouplengths(parttype=parttype)
            offsets = self.get_groupoffsets(parttype=parttype)
        elif objtype == "subhalo":
            lengths = self.get_subhalolengths(parttype=parttype)
            offsets = self.get_subhalooffsets(parttype=parttype)
        else:
            raise ValueError(f"objtype must be 'halo' or 'subhalo', not {objtype}")
        arrdict = self.data[parttype]
        return map_group_operation(
            func,
            offsets,
            lengths,
            arrdict,
            cpucost_halo=cpucost_halo,
            nchunks_min=nchunks_min,
            chunksize_bytes=chunksize_bytes,
            entry_nbytes_in=entry_nbytes_in,
            nmax=nmax,
            idxlist=idxlist,
        )

    def add_groupquantity_to_particles(self, name, parttype="PartType0"):
        """
        Map a quantity from the group catalog to the particles based on a particle's group index.

        Parameters
        ----------
        name: str
            Name of quantity to map
        parttype: str
            Name of particle type

        Returns
        -------
        None
        """
        pdata = self.data[parttype]
        assert (
            name not in pdata
        )  # we simply map the name from Group to Particle for now. Should work (?)
        glen = self.data["Group"]["GroupLenType"]
        da_halocelloffsets = da.concatenate(
            [np.zeros((1, 6), dtype=np.int64), da.cumsum(glen, axis=0)]
        )
        if "GroupOffsetsType" not in self.data["Group"]:
            self.data["Group"]["GroupOffsetsType"] = da_halocelloffsets[:-1].rechunk(
                glen.chunks
            )  # remove last entry to match shape
        halocelloffsets = da_halocelloffsets.compute()

        gidx = pdata["uid"]
        num = int(parttype[-1])
        hquantity = compute_haloquantity(
            gidx, halocelloffsets[:, num], self.data["Group"][name]
        )
        pdata[name] = hquantity

    def get_grouplengths(self, parttype="PartType0"):
        """
        Get the lengths, i.e. the total number of particles, of a given type in all halos.

        Parameters
        ----------
        parttype: str
            Name of particle type

        Returns
        -------
        np.ndarray
        """
        pnum = part_type_num(parttype)
        ptype = "PartType%i" % pnum
        if ptype not in self._grouplengths:
            lengths = self.data["Group"]["GroupLenType"][:, pnum].compute()
            if isinstance(lengths, pint.Quantity):
                lengths = lengths.magnitude
            self._grouplengths[ptype] = lengths
        return self._grouplengths[ptype]

    def get_groupoffsets(self, parttype="PartType0"):
        """
        Get the array index offset of the first particle of a given type in each halo.

        Parameters
        ----------
        parttype: str
            Name of particle type

        Returns
        -------
        np.ndarray
        """
        if parttype not in self._grouplengths:
            # need to calculate group lengths first
            self.get_grouplengths(parttype=parttype)
        return self._groupoffsets[parttype]

    @property
    def _groupoffsets(self):
        lengths = self._grouplengths
        offsets = {
            k: np.concatenate([[0], np.cumsum(v)[:-1]]) for k, v in lengths.items()
        }
        return offsets

    def get_subhalolengths(self, parttype="PartType0"):
        """
        Get the lengths, i.e. the total number of particles, of a given type in all subhalos.

        Parameters
        ----------
        parttype: str

        Returns
        -------
        np.ndarray
        """
        pnum = part_type_num(parttype)
        ptype = "PartType%i" % pnum
        if ptype in self._subhalolengths:
            return self._subhalolengths[ptype]
        lengths = self.data["Subhalo"]["SubhaloLenType"][:, pnum].compute()
        if isinstance(lengths, pint.Quantity):
            lengths = lengths.magnitude
        self._subhalolengths[ptype] = lengths
        return self._subhalolengths[ptype]

    def get_subhalooffsets(self, parttype="PartType0"):
        """
        Get the array index offset of the first particle of a given type in each subhalo.

        Parameters
        ----------
        parttype: str

        Returns
        -------
        np.ndarray
        """

        pnum = part_type_num(parttype)
        ptype = "PartType%i" % pnum
        if ptype in self._subhalooffsets:
            return self._subhalooffsets[ptype]  # use cached result
        goffsets = self.get_groupoffsets(ptype)
        shgrnr = self.data["Subhalo"]["SubhaloGrNr"]
        # calculate the index of the first particle for the central subhalo of each subhalo's parent halo
        shoffset_central = goffsets[shgrnr]

        grpfirstsub = self.data["Group"]["GroupFirstSub"]
        shlens = self.get_subhalolengths(ptype)
        shoffsets = np.concatenate([[0], np.cumsum(shlens)[:-1]])

        # particle offset for the first subhalo of each group that a subhalo belongs to
        shfirstshoffset = shoffsets[grpfirstsub[shgrnr]]

        # "LocalSubhaloOffset": particle offset of each subhalo in the parent group
        shoffset_local = shoffsets - shfirstshoffset

        # "SubhaloOffset": particle offset of each subhalo in the simulation
        offsets = shoffset_central + shoffset_local

        self._subhalooffsets[ptype] = offsets

        return offsets

    def grouped(
        self,
        fields: Union[str, da.Array, List[str], Dict[str, da.Array]] = "",
        parttype="PartType0",
        objtype="halo",
    ):
        """
        Create a GroupAwareOperation object for applying operations to groups.

        Parameters
        ----------
        fields: Union[str, da.Array, List[str], Dict[str, da.Array]]
            Fields to pass to the operation. Can be a string, a dask array, a list of strings or a dictionary of dask arrays.
        parttype: str
            Particle type to operate on.
        objtype: str
            Type of object to operate on. Can be "halo" or "subhalo". Default: "halo"

        Returns
        -------
        GroupAwareOperation
        """
        inputfields = None
        if isinstance(fields, str):
            if fields == "":  # if nothing is specified, we pass all we have.
                arrdict = self.data[parttype]
            else:
                arrdict = dict(field=self.data[parttype][fields])
                inputfields = [fields]
        elif isinstance(fields, da.Array) or isinstance(fields, pint.Quantity):
            arrdict = dict(daskarr=fields)
            inputfields = [fields.name]
        elif isinstance(fields, list):
            arrdict = {k: self.data[parttype][k] for k in fields}
            inputfields = fields
        elif isinstance(fields, dict):
            arrdict = {}
            arrdict.update(**fields)
            inputfields = list(arrdict.keys())
        else:
            raise ValueError("Unknown input type '%s'." % type(fields))
        objtype = grp_type_str(objtype)
        if objtype == "halo":
            offsets = self.get_groupoffsets(parttype=parttype)
            lengths = self.get_grouplengths(parttype=parttype)
        elif objtype == "subhalo":
            offsets = self.get_subhalooffsets(parttype=parttype)
            lengths = self.get_subhalolengths(parttype=parttype)
        else:
            raise ValueError("Unknown object type '%s'." % objtype)

        gop = GroupAwareOperation(
            offsets,
            lengths,
            arrdict,
            inputfields=inputfields,
        )
        return gop
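
To tie the pieces above together, a minimal usage sketch; the path is a placeholder and load is scida's top-level entry point, which dispatches to this class when validate_path matches:

from scida import load

ds = load("/data/sims/run/output/snapdir_099")   # hypothetical path
gas = ds.data["gas"]        # alias for PartType0, registered in __init__ above
halo_ids = gas["GroupID"]   # lazy dask array added by add_catalogIDs() once a catalog was found
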
__init__(path, chunksize='auto', catalog=None, **kwargs)

Initialize an ArepoSnapshot object.

Parameters:

Name Type Description Default
path

Path to snapshot, typically a directory containing multiple hdf5 files.

required
chunksize

Chunksize to use for dask arrays. Can be "auto" to automatically determine chunksize.

'auto'
catalog

Path to group catalog. If None, the catalog is searched for in the parent directories.

None
kwargs

Additional keyword arguments.

{}
Source code in src/scida/customs/arepo/dataset.py
def __init__(self, path, chunksize="auto", catalog=None, **kwargs) -> None:
    """
    Initialize an ArepoSnapshot object.

    Parameters
    ----------
    path: str or pathlib.Path
        Path to snapshot, typically a directory containing multiple hdf5 files.
    chunksize:
        Chunksize to use for dask arrays. Can be "auto" to automatically determine chunksize.
    catalog:
        Path to group catalog. If None, the catalog is searched for in the parent directories.
    kwargs:
        Additional keyword arguments.
    """
    self.iscatalog = kwargs.pop("iscatalog", False)
    self.header = {}
    self.config = {}
    self._defaultunitfiles: List[str] = ["units/gadget_cosmological.yaml"]
    self.parameters = {}
    self._grouplengths = {}
    self._subhalolengths = {}
    # not needed for group catalogs as entries are back-to-back there, we will provide a property for this
    self._subhalooffsets = {}
    self.misc = {}  # for storing misc info
    prfx = kwargs.pop("fileprefix", None)
    if prfx is None:
        prfx = self._get_fileprefix(path)
    super().__init__(path, chunksize=chunksize, fileprefix=prfx, **kwargs)

    self.catalog = catalog
    if not self.iscatalog:
        if self.catalog is None:
            self.discover_catalog()
            # try to discover group catalog in parent directories.
        if self.catalog == "none":
            pass  # this string can be set to explicitly disable catalog
        elif self.catalog:
            catalog_cls = kwargs.get("catalog_cls", None)
            cosmological = False
            if hasattr(self, "_mixins") and "cosmology" in self._mixins:
                cosmological = True
            self.load_catalog(
                overwrite_cache=kwargs.get("overwrite_cache", False),
                catalog_cls=catalog_cls,
                units=self.withunits,
                cosmological=cosmological,
            )

    # add aliases
    aliases = dict(
        PartType0=["gas", "baryons"],
        PartType1=["dm", "dark matter"],
        PartType2=["lowres", "lowres dm"],
        PartType3=["tracer", "tracers"],
        PartType4=["stars"],
        PartType5=["bh", "black holes"],
    )
    for k, lst in aliases.items():
        if k not in self.data:
            continue
        for v in lst:
            self.data.add_alias(v, k)

    # set metadata
    self._set_metadata()

    # add some default fields
    self.data.merge(fielddefs)
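
As noted in the constructor, catalog discovery can be steered explicitly. A hedged sketch, assuming load forwards keyword arguments to this class (paths are placeholders):

from scida import load

snappath = "/data/sims/run/output/snapdir_099"
ds = load(snappath, catalog="/data/sims/run/output/groups_099")   # explicit group catalog
ds_nocat = load(snappath, catalog="none")                         # disable catalog discovery entirely
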
add_catalogIDs()

Add field for halo and subgroup IDs for all particle types.

Returns:

Type Description
None
Source code in src/scida/customs/arepo/dataset.py
def add_catalogIDs(self) -> None:
    """
    Add field for halo and subgroup IDs for all particle types.

    Returns
    -------
    None
    """
    # TODO: make these delayed objects and properly pass into (delayed?) numba functions:
    # https://docs.dask.org/en/stable/delayed-best-practices.html#avoid-repeatedly-putting-large-inputs-into-delayed-calls

    maxint = np.iinfo(np.int64).max
    self.misc["unboundID"] = maxint

    # Group ID
    if "Group" not in self.data:  # can happen for empty catalogs
        for key in self.data:
            if not (key.startswith("PartType")):
                continue
            uid = self.data[key]["uid"]
            self.data[key]["GroupID"] = self.misc["unboundID"] * da.ones_like(
                uid, dtype=np.int64
            )
            self.data[key]["SubhaloID"] = self.misc["unboundID"] * da.ones_like(
                uid, dtype=np.int64
            )
        return

    glen = self.data["Group"]["GroupLenType"]
    ngrp = glen.shape[0]
    da_halocelloffsets = da.concatenate(
        [
            np.zeros((1, 6), dtype=np.int64),
            da.cumsum(glen, axis=0, dtype=np.int64),
        ]
    )
    # remove last entry to match shape
    self.data["Group"]["GroupOffsetsType"] = da_halocelloffsets[:-1].rechunk(
        glen.chunks
    )
    halocelloffsets = da_halocelloffsets.rechunk(-1)

    index_unbound = self.misc["unboundID"]

    for key in self.data:
        if not (key.startswith("PartType")):
            continue
        num = int(key[-1])
        if "uid" not in self.data[key]:
            continue  # can happen for empty containers
        gidx = self.data[key]["uid"]
        hidx = compute_haloindex(
            gidx, halocelloffsets[:, num], index_unbound=index_unbound
        )
        self.data[key]["GroupID"] = hidx

    # Subhalo ID
    if "Subhalo" not in self.data:  # can happen for empty catalogs
        for key in self.data:
            if not (key.startswith("PartType")):
                continue
            self.data[key]["SubhaloID"] = -1 * da.ones_like(
                da[key]["uid"], dtype=np.int64
            )
        return

    shnr_attr = "SubhaloGrNr"
    if shnr_attr not in self.data["Subhalo"]:
        shnr_attr = "SubhaloGroupNr"  # what MTNG does
    if shnr_attr not in self.data["Subhalo"]:
        raise ValueError(
            f"Could not find 'SubhaloGrNr' or 'SubhaloGroupNr' in {self.catalog}"
        )

    subhalogrnr = self.data["Subhalo"][shnr_attr]
    subhalocellcounts = self.data["Subhalo"]["SubhaloLenType"]

    # remove "units" for numba funcs
    if hasattr(subhalogrnr, "magnitude"):
        subhalogrnr = subhalogrnr.magnitude
    if hasattr(subhalocellcounts, "magnitude"):
        subhalocellcounts = subhalocellcounts.magnitude

    grp = self.data["Group"]
    if "GroupFirstSub" not in grp or "GroupNsubs" not in grp:
        # if not provided, we calculate:
        # "GroupFirstSub": First subhalo index for each halo
        # "GroupNsubs": Number of subhalos for each halo
        dlyd = delayed(get_shcounts_shcells)(subhalogrnr, ngrp)
        grp["GroupFirstSub"] = dask.compute(dlyd[1])[0]
        grp["GroupNsubs"] = dask.compute(dlyd[0])[0]

    # remove "units" for numba funcs
    grpfirstsub = grp["GroupFirstSub"]
    if hasattr(grpfirstsub, "magnitude"):
        grpfirstsub = grpfirstsub.magnitude
    grpnsubs = grp["GroupNsubs"]
    if hasattr(grpnsubs, "magnitude"):
        grpnsubs = grpnsubs.magnitude

    for key in self.data:
        if not (key.startswith("PartType")):
            continue
        num = int(key[-1])
        pdata = self.data[key]
        if "uid" not in self.data[key]:
            continue  # can happen for empty containers
        gidx = pdata["uid"]

        # we need to make the other dask arrays delayed,
        # so that map_blocks does not incorrectly infer the output shape from them
        halocelloffsets_dlyd = delayed(halocelloffsets[:, num])
        grpfirstsub_dlyd = delayed(grpfirstsub)
        grpnsubs_dlyd = delayed(grpnsubs)
        subhalocellcounts_dlyd = delayed(subhalocellcounts[:, num])

        sidx = compute_localsubhaloindex(
            gidx,
            halocelloffsets_dlyd,
            grpfirstsub_dlyd,
            grpnsubs_dlyd,
            subhalocellcounts_dlyd,
            index_unbound=index_unbound,
        )

        pdata["LocalSubhaloID"] = sidx

        # reconstruct SubhaloID from Group's GroupFirstSub and LocalSubhaloID
        # should be easier to do it directly, but quicker to write down like this:

        # calculate first subhalo of each halo that a particle belongs to
        self.add_groupquantity_to_particles("GroupFirstSub", parttype=key)
        pdata["SubhaloID"] = pdata["GroupFirstSub"] + pdata["LocalSubhaloID"]
        pdata["SubhaloID"] = da.where(
            pdata["SubhaloID"] == index_unbound, index_unbound, pdata["SubhaloID"]
        )
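
The GroupID and SubhaloID fields created above are ordinary dask arrays, so halo membership can be used as a lazy mask. A small sketch, assuming ds is a loaded ArepoSnapshot and group 42 is an arbitrary example:

gas = ds.data["PartType0"]
mask = gas["GroupID"] == 42          # lazy membership mask for FoF group 42
n_cells = int(mask.sum().compute())  # number of gas cells in that group
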
add_groupquantity_to_particles(name, parttype='PartType0')

Map a quantity from the group catalog to the particles based on a particle's group index.

Parameters:

Name Type Description Default
name

Name of quantity to map

required
parttype

Name of particle type

'PartType0'

Returns:

Type Description
None
Source code in src/scida/customs/arepo/dataset.py
def add_groupquantity_to_particles(self, name, parttype="PartType0"):
    """
    Map a quantity from the group catalog to the particles based on a particle's group index.

    Parameters
    ----------
    name: str
        Name of quantity to map
    parttype: str
        Name of particle type

    Returns
    -------
    None
    """
    pdata = self.data[parttype]
    assert (
        name not in pdata
    )  # we simply map the name from Group to Particle for now. Should work (?)
    glen = self.data["Group"]["GroupLenType"]
    da_halocelloffsets = da.concatenate(
        [np.zeros((1, 6), dtype=np.int64), da.cumsum(glen, axis=0)]
    )
    if "GroupOffsetsType" not in self.data["Group"]:
        self.data["Group"]["GroupOffsetsType"] = da_halocelloffsets[:-1].rechunk(
            glen.chunks
        )  # remove last entry to match shape
    halocelloffsets = da_halocelloffsets.compute()

    gidx = pdata["uid"]
    num = int(parttype[-1])
    hquantity = compute_haloquantity(
        gidx, halocelloffsets[:, num], self.data["Group"][name]
    )
    pdata[name] = hquantity
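
A hedged usage sketch for the mapping above; GroupMass is a standard Arepo group-catalog field, and any other Group field works the same way (ds is a loaded ArepoSnapshot with a catalog):

ds.add_groupquantity_to_particles("GroupMass", parttype="PartType0")
groupmass_per_cell = ds.data["PartType0"]["GroupMass"]   # each gas cell now carries its parent group's mass
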
discover_catalog()

Discover the group catalog given the current path

Returns:

Type Description
None
Source code in src/scida/customs/arepo/dataset.py
def discover_catalog(self):
    """
    Discover the group catalog given the current path

    Returns
    -------
    None
    """
    p = str(self.path)
    # order of candidates matters. For Illustris "groups" must precede "fof_subhalo_tab"
    candidates = [
        p.replace("snapshot", "group"),
        p.replace("snapshot", "groups"),
        p.replace("snapdir", "groups").replace("snap", "groups"),
        p.replace("snapdir", "groups").replace("snap", "fof_subhalo_tab"),
    ]
    for candidate in candidates:
        if not os.path.exists(candidate):
            continue
        if candidate == self.path:
            continue
        self.catalog = candidate
        break
get_grouplengths(parttype='PartType0')

Get the lengths, i.e. the total number of particles, of a given type in all halos.

Parameters:

Name Type Description Default
parttype

Name of particle type

'PartType0'

Returns:

Type Description
ndarray
Source code in src/scida/customs/arepo/dataset.py
def get_grouplengths(self, parttype="PartType0"):
    """
    Get the lengths, i.e. the total number of particles, of a given type in all halos.

    Parameters
    ----------
    parttype: str
        Name of particle type

    Returns
    -------
    np.ndarray
    """
    pnum = part_type_num(parttype)
    ptype = "PartType%i" % pnum
    if ptype not in self._grouplengths:
        lengths = self.data["Group"]["GroupLenType"][:, pnum].compute()
        if isinstance(lengths, pint.Quantity):
            lengths = lengths.magnitude
        self._grouplengths[ptype] = lengths
    return self._grouplengths[ptype]
get_groupoffsets(parttype='PartType0')

Get the array index offset of the first particle of a given type in each halo.

Parameters:

Name Type Description Default
parttype

Name of particle type

'PartType0'

Returns:

Type Description
ndarray
Source code in src/scida/customs/arepo/dataset.py
def get_groupoffsets(self, parttype="PartType0"):
    """
    Get the array index offset of the first particle of a given type in each halo.

    Parameters
    ----------
    parttype: str
        Name of particle type

    Returns
    -------
    np.ndarray
    """
    if parttype not in self._grouplengths:
        # need to calculate group lengths first
        self.get_grouplengths(parttype=parttype)
    return self._groupoffsets[parttype]
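
Since snapshot particles are stored contiguously per group, get_groupoffsets and get_grouplengths together define the particle slice of each halo. A minimal sketch, with an arbitrary halo index and ds a loaded ArepoSnapshot:

offsets = ds.get_groupoffsets(parttype="PartType0")
lengths = ds.get_grouplengths(parttype="PartType0")

i = 10  # halo index (arbitrary example)
halo_slice = slice(int(offsets[i]), int(offsets[i] + lengths[i]))
coords_halo = ds.data["PartType0"]["Coordinates"][halo_slice]   # still a lazy dask selection
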
get_subhalolengths(parttype='PartType0')

Get the lengths, i.e. the total number of particles, of a given type in all subhalos.

Parameters:

Name Type Description Default
parttype
'PartType0'

Returns:

Type Description
ndarray
Source code in src/scida/customs/arepo/dataset.py
def get_subhalolengths(self, parttype="PartType0"):
    """
    Get the lengths, i.e. the total number of particles, of a given type in all subhalos.

    Parameters
    ----------
    parttype: str

    Returns
    -------
    np.ndarray
    """
    pnum = part_type_num(parttype)
    ptype = "PartType%i" % pnum
    if ptype in self._subhalolengths:
        return self._subhalolengths[ptype]
    lengths = self.data["Subhalo"]["SubhaloLenType"][:, pnum].compute()
    if isinstance(lengths, pint.Quantity):
        lengths = lengths.magnitude
    self._subhalolengths[ptype] = lengths
    return self._subhalolengths[ptype]
get_subhalooffsets(parttype='PartType0')

Get the array index offset of the first particle of a given type in each subhalo.

Parameters:

Name Type Description Default
parttype
'PartType0'

Returns:

Type Description
ndarray
Source code in src/scida/customs/arepo/dataset.py
def get_subhalooffsets(self, parttype="PartType0"):
    """
    Get the array index offset of the first particle of a given type in each subhalo.

    Parameters
    ----------
    parttype: str

    Returns
    -------
    np.ndarray
    """

    pnum = part_type_num(parttype)
    ptype = "PartType%i" % pnum
    if ptype in self._subhalooffsets:
        return self._subhalooffsets[ptype]  # use cached result
    goffsets = self.get_groupoffsets(ptype)
    shgrnr = self.data["Subhalo"]["SubhaloGrNr"]
    # calculate the index of the first particle for the central subhalo of each subhalo's parent halo
    shoffset_central = goffsets[shgrnr]

    grpfirstsub = self.data["Group"]["GroupFirstSub"]
    shlens = self.get_subhalolengths(ptype)
    shoffsets = np.concatenate([[0], np.cumsum(shlens)[:-1]])

    # particle offset for the first subhalo of each group that a subhalo belongs to
    shfirstshoffset = shoffsets[grpfirstsub[shgrnr]]

    # "LocalSubhaloOffset": particle offset of each subhalo in the parent group
    shoffset_local = shoffsets - shfirstshoffset

    # "SubhaloOffset": particle offset of each subhalo in the simulation
    offsets = shoffset_central + shoffset_local

    self._subhalooffsets[ptype] = offsets

    return offsets
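
The analogous sketch for subhalos, using the offsets derived above (subhalo index is an arbitrary example; ds is a loaded ArepoSnapshot):

shoffsets = ds.get_subhalooffsets(parttype="PartType4")
shlengths = ds.get_subhalolengths(parttype="PartType4")

j = 0
star_slice = slice(int(shoffsets[j]), int(shoffsets[j] + shlengths[j]))
star_ids = ds.data["PartType4"]["ParticleIDs"][star_slice]
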
grouped(fields='', parttype='PartType0', objtype='halo')

Create a GroupAwareOperation object for applying operations to groups.

Parameters:

Name Type Description Default
fields Union[str, Array, List[str], Dict[str, Array]]

Fields to pass to the operation. Can be a string, a dask array, a list of strings or a dictionary of dask arrays.

''
parttype

Particle type to operate on.

'PartType0'
objtype

Type of object to operate on. Can be "halo" or "subhalo". Default: "halo"

'halo'

Returns:

Type Description
GroupAwareOperation
Source code in src/scida/customs/arepo/dataset.py
def grouped(
    self,
    fields: Union[str, da.Array, List[str], Dict[str, da.Array]] = "",
    parttype="PartType0",
    objtype="halo",
):
    """
    Create a GroupAwareOperation object for applying operations to groups.

    Parameters
    ----------
    fields: Union[str, da.Array, List[str], Dict[str, da.Array]]
        Fields to pass to the operation. Can be a string, a dask array, a list of strings or a dictionary of dask arrays.
    parttype: str
        Particle type to operate on.
    objtype: str
        Type of object to operate on. Can be "halo" or "subhalo". Default: "halo"

    Returns
    -------
    GroupAwareOperation
    """
    inputfields = None
    if isinstance(fields, str):
        if fields == "":  # if nothing is specified, we pass all we have.
            arrdict = self.data[parttype]
        else:
            arrdict = dict(field=self.data[parttype][fields])
            inputfields = [fields]
    elif isinstance(fields, da.Array) or isinstance(fields, pint.Quantity):
        arrdict = dict(daskarr=fields)
        inputfields = [fields.name]
    elif isinstance(fields, list):
        arrdict = {k: self.data[parttype][k] for k in fields}
        inputfields = fields
    elif isinstance(fields, dict):
        arrdict = {}
        arrdict.update(**fields)
        inputfields = list(arrdict.keys())
    else:
        raise ValueError("Unknown input type '%s'." % type(fields))
    objtype = grp_type_str(objtype)
    if objtype == "halo":
        offsets = self.get_groupoffsets(parttype=parttype)
        lengths = self.get_grouplengths(parttype=parttype)
    elif objtype == "subhalo":
        offsets = self.get_subhalooffsets(parttype=parttype)
        lengths = self.get_subhalolengths(parttype=parttype)
    else:
        raise ValueError("Unknown object type '%s'." % objtype)

    gop = GroupAwareOperation(
        offsets,
        lengths,
        arrdict,
        inputfields=inputfields,
    )
    return gop
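
A hedged usage sketch for grouped; the chained reduction interface (sum, evaluate) is GroupAwareOperation's as understood from scida's user documentation, so treat the exact chain as an assumption:

masses_per_halo = ds.grouped("Masses", parttype="PartType0").sum().evaluate()
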
load_catalog(overwrite_cache=False, units=False, cosmological=False, catalog_cls=None)

Load the group catalog.

Parameters:

Name Type Description Default
overwrite_cache

Overwrite an existing catalog cache.

False
units

Load the catalog with units attached.

False
cosmological

Attach the cosmology mixin to the catalog class.

False
catalog_cls

Class to use for the catalog. If None, the default catalog class is used.

None

Returns:

Type Description
None
Source code in src/scida/customs/arepo/dataset.py
def load_catalog(
    self, overwrite_cache=False, units=False, cosmological=False, catalog_cls=None
):
    """
    Load the group catalog.

    Parameters
    ----------
    overwrite_cache: bool
        Overwrite an existing catalog cache.
    units: bool
        Load the catalog with units attached.
    cosmological: bool
        Attach the cosmology mixin to the catalog class.
    catalog_cls: type
        Class to use for the catalog. If None, the default catalog class is used.

    Returns
    -------
    None
    """
    virtualcache = False  # copy catalog for better performance
    # fileprefix = catalog_kwargs.get("fileprefix", self._fileprefix_catalog)
    prfx = self._get_fileprefix(self.catalog)

    # explicitly need to create unitaware class for catalog as needed
    # TODO: should just be determined from mixins of parent?
    if catalog_cls is None:
        cls = ArepoCatalog
    else:
        cls = catalog_cls
    withunits = units
    mixins = []
    if withunits:
        mixins += [UnitMixin]

    other_mixins = _determine_mixins(path=self.path)
    mixins += other_mixins
    if cosmological and CosmologyMixin not in mixins:
        mixins.append(CosmologyMixin)

    cls = create_datasetclass_with_mixins(cls, mixins)

    ureg = None
    if hasattr(self, "ureg"):
        ureg = self.ureg

    self.catalog = cls(
        self.catalog,
        overwrite_cache=overwrite_cache,
        virtualcache=virtualcache,
        fileprefix=prfx,
        units=self.withunits,
        ureg=ureg,
    )
    if "Redshift" in self.catalog.header and "Redshift" in self.header:
        z_catalog = self.catalog.header["Redshift"]
        z_snap = self.header["Redshift"]
        if not np.isclose(z_catalog, z_snap):
            raise ValueError(
                "Redshift mismatch between snapshot and catalog: "
                f"{z_snap:.2f} vs {z_catalog:.2f}"
            )

    # merge data
    self.merge_data(self.catalog)

    # first snapshots often do not have groups
    if "Group" in self.catalog.data:
        ngkeys = self.catalog.data["Group"].keys()
        if len(ngkeys) > 0:
            self.add_catalogIDs()

    # merge hints from snap and catalog
    self.merge_hints(self.catalog)
map_group_operation(func, cpucost_halo=10000.0, nchunks_min=None, chunksize_bytes=None, nmax=None, idxlist=None, objtype='halo')

Apply a function to each halo in the catalog.

Parameters:

Name Type Description Default
objtype

Type of object to process. Can be "halo" or "subhalo". Default: "halo"

'halo'
idxlist

List of halo indices to process. If not provided, all halos are processed.

None
func

Function to apply to each halo. Must take a dictionary of arrays as input.

required
cpucost_halo

"CPU cost" of processing a single halo. This is a relative value to the processing time per input particle used for calculating the dask chunks. Default: 1e4

10000.0
nchunks_min

Minimum number of particles in a halo to process it. Default: None

None
chunksize_bytes
None
nmax

Only process the first nmax halos.

None

Returns:

Type Description
None
Source code in src/scida/customs/arepo/dataset.py
@computedecorator
def map_group_operation(
    self,
    func,
    cpucost_halo=1e4,
    nchunks_min=None,
    chunksize_bytes=None,
    nmax=None,
    idxlist=None,
    objtype="halo",
):
    """
    Apply a function to each halo in the catalog.

    Parameters
    ----------
    objtype: str
        Type of object to process. Can be "halo" or "subhalo". Default: "halo"
    idxlist: Optional[np.ndarray]
        List of halo indices to process. If not provided, all halos are processed.
    func: function
        Function to apply to each halo. Must take a dictionary of arrays as input.
    cpucost_halo:
        "CPU cost" of processing a single halo, relative to the processing time per input particle;
        used for calculating the dask chunks. Default: 1e4
    nchunks_min: Optional[int]
        Minimum number of particles in a halo to process it. Default: None
    chunksize_bytes: Optional[int]
    nmax: Optional[int]
        Only process the first nmax halos.

    Returns
    -------
    None
    """
    dfltkwargs = get_kwargs(func)
    fieldnames = dfltkwargs.get("fieldnames", None)
    if fieldnames is None:
        fieldnames = get_args(func)
    parttype = dfltkwargs.get("parttype", "PartType0")
    entry_nbytes_in = np.sum([self.data[parttype][f][0].nbytes for f in fieldnames])
    objtype = grp_type_str(objtype)
    if objtype == "halo":
        lengths = self.get_grouplengths(parttype=parttype)
        offsets = self.get_groupoffsets(parttype=parttype)
    elif objtype == "subhalo":
        lengths = self.get_subhalolengths(parttype=parttype)
        offsets = self.get_subhalooffsets(parttype=parttype)
    else:
        raise ValueError(f"objtype must be 'halo' or 'subhalo', not {objtype}")
    arrdict = self.data[parttype]
    return map_group_operation(
        func,
        offsets,
        lengths,
        arrdict,
        cpucost_halo=cpucost_halo,
        nchunks_min=nchunks_min,
        chunksize_bytes=chunksize_bytes,
        entry_nbytes_in=entry_nbytes_in,
        nmax=nmax,
        idxlist=idxlist,
    )
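
For orientation, a minimal usage sketch (not from the source): snap is assumed to be a loaded ArepoSnapshot, and the argument names of the passed function select the input fields of the default parttype (PartType0).

def total_gas_mass(Masses):
    # receives the gas masses of one halo at a time
    return Masses.sum()

# hypothetical call; restrict to the first 100 halos for a quick test
res = snap.map_group_operation(total_gas_mass, nmax=100)
# res is expected to be a dask array with one entry per processed halo;
# call res.compute() to evaluate it.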
register_field(parttype, name=None, construct=False)

Register a field.

Parameters:

parttype (str, required): name of particle type
name (str, default None): name of field
construct (bool, default False): construct field immediately

Returns:

None

Source code in src/scida/customs/arepo/dataset.py, lines 252-278
def register_field(self, parttype: str, name: str = None, construct: bool = False):
    """
    Register a field.
    Parameters
    ----------
    parttype: str
        name of particle type
    name: str
        name of field
    construct: bool
        construct field immediately

    Returns
    -------
    None
    """
    num = part_type_num(parttype)
    if construct:  # TODO: introduce (immediate) construct option later
        raise NotImplementedError
    if num == -1:  # TODO: all particle species
        key = "all"
        raise NotImplementedError
    elif isinstance(num, int):
        key = "PartType" + str(num)
    else:
        key = parttype
    return super().register_field(key, name=name)
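
A hedged sketch of registering a simple derived field this way, following the decorator pattern used for the Temperature field further below (ds is assumed to be a loaded Arepo snapshot; the field name and definition are made up for illustration):

@ds.register_field("gas")  # "gas" resolves to PartType0 via part_type_num
def MassesSquared(arrs, **kwargs):
    # toy derived field computed from the native Masses field
    return arrs["Masses"] ** 2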
return_data()

Return data object of this snapshot.

Returns:

FieldContainer

Source code in src/scida/customs/arepo/dataset.py, lines 217-226
@ArepoSelector()
def return_data(self) -> FieldContainer:
    """
    Return data object of this snapshot.

    Returns
    -------
    None
    """
    return super().return_data()
validate_path(path, *args, **kwargs) classmethod

Validate a path to use for instantiation of this class.

Parameters:

path (Union[str, os.PathLike], required)
args (default ())
kwargs (default {})

Returns:

CandidateStatus

Source code in src/scida/customs/arepo/dataset.py, lines 182-215
@classmethod
def validate_path(
    cls, path: Union[str, os.PathLike], *args, **kwargs
) -> CandidateStatus:
    """
    Validate a path to use for instantiation of this class.

    Parameters
    ----------
    path: str or pathlib.Path
    args:
    kwargs:

    Returns
    -------
    CandidateStatus
    """
    valid = super().validate_path(path, *args, **kwargs)
    if valid.value > CandidateStatus.MAYBE.value:
        valid = CandidateStatus.MAYBE
    else:
        return valid
    # Arepo has no dedicated attribute to identify such runs.
    # lets just query a bunch of attributes that are present for arepo runs
    metadata_raw = load_metadata(path, **kwargs)
    matchingattrs = True
    matchingattrs &= "Git_commit" in metadata_raw["/Header"]
    # not existent for any arepo run?
    matchingattrs &= "Compactify_Version" not in metadata_raw["/Header"]

    if matchingattrs:
        valid = CandidateStatus.MAYBE

    return valid
ChainOps

Chain operations together.

Source code in src/scida/customs/arepo/dataset.py, lines 727-761
class ChainOps:
    """
    Chain operations together.
    """

    def __init__(self, *funcs):
        """
        Initialize a ChainOps object.

        Parameters
        ----------
        funcs: List[function]
            Functions to chain together.
        """
        self.funcs = funcs
        self.kwargs = get_kwargs(
            funcs[-1]
        )  # so we can pass info from kwargs to map_halo_operation
        if self.kwargs.get("dtype") is None:
            self.kwargs["dtype"] = float

        def chained_call(*args):
            cf = None
            for i, f in enumerate(funcs):
                # first chain element can be multiple fields. treat separately
                if i == 0:
                    cf = f(*args)
                else:
                    cf = f(cf)
            return cf

        self.call = chained_call

    def __call__(self, *args, **kwargs):
        return self.call(*args, **kwargs)
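
As a toy illustration (not from the source): the chained call feeds each function's output into the next, and only the first function may take several arguments.

import numpy as np

def square(x):
    return x ** 2

def total(x, dtype=None):  # the dtype kwarg of the last function is picked up by ChainOps
    return x.sum()

chained = ChainOps(square, total)
print(chained(np.array([1.0, 2.0, 3.0])))  # 14.0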
__init__(*funcs)

Initialize a ChainOps object.

Parameters:

funcs (default ()): Functions to chain together.

Source code in src/scida/customs/arepo/dataset.py, lines 732-758
def __init__(self, *funcs):
    """
    Initialize a ChainOps object.

    Parameters
    ----------
    funcs: List[function]
        Functions to chain together.
    """
    self.funcs = funcs
    self.kwargs = get_kwargs(
        funcs[-1]
    )  # so we can pass info from kwargs to map_halo_operation
    if self.kwargs.get("dtype") is None:
        self.kwargs["dtype"] = float

    def chained_call(*args):
        cf = None
        for i, f in enumerate(funcs):
            # first chain element can be multiple fields. treat separately
            if i == 0:
                cf = f(*args)
            else:
                cf = f(cf)
        return cf

    self.call = chained_call
GroupAwareOperation

Class for applying operations to groups.

Source code in src/scida/customs/arepo/dataset.py, lines 764-954
class GroupAwareOperation:
    """
    Class for applying operations to groups.
    """

    opfuncs = dict(min=np.min, max=np.max, sum=np.sum, half=lambda x: x[::2])
    finalops = {"min", "max", "sum"}
    __slots__ = (
        "arrs",
        "ops",
        "offsets",
        "lengths",
        "final",
        "inputfields",
        "opfuncs_custom",
    )

    def __init__(
        self,
        offsets: NDArray,
        lengths: NDArray,
        arrs: Dict[str, da.Array],
        ops=None,
        inputfields=None,
    ):
        self.offsets = offsets
        self.lengths = lengths
        self.arrs = arrs
        self.opfuncs_custom = {}
        self.final = False
        self.inputfields = inputfields
        if ops is None:
            self.ops = []
        else:
            self.ops = ops

    def chain(self, add_op=None, final=False):
        """
        Chain another operation to this one.

        Parameters
        ----------
        add_op: str or function
            Operation to add. Can be a string (e.g. "min", "max", "sum") or a function.
        final: bool
            Whether this is the final operation in the chain.

        Returns
        -------
        GroupAwareOperation
        """
        if self.final:
            raise ValueError("Cannot chain any additional operation.")
        c = copy.copy(self)
        c.final = final
        c.opfuncs_custom = self.opfuncs_custom
        if add_op is not None:
            if isinstance(add_op, str):
                c.ops.append(add_op)
            elif callable(add_op):
                name = "custom" + str(len(self.opfuncs_custom) + 1)
                c.opfuncs_custom[name] = add_op
                c.ops.append(name)
            else:
                raise ValueError("Unknown operation of type '%s'" % str(type(add_op)))
        return c

    def min(self, field=None):
        """
        Get the minimum value for each group member.
        """
        if field is not None:
            if self.inputfields is not None:
                raise ValueError("Cannot change input field anymore.")
            self.inputfields = [field]
        return self.chain(add_op="min", final=True)

    def max(self, field=None):
        """
        Get the maximum value for each group member.
        """
        if field is not None:
            if self.inputfields is not None:
                raise ValueError("Cannot change input field anymore.")
            self.inputfields = [field]
        return self.chain(add_op="max", final=True)

    def sum(self, field=None):
        """
        Sum the values for each group member.
        """
        if field is not None:
            if self.inputfields is not None:
                raise ValueError("Cannot change input field anymore.")
            self.inputfields = [field]
        return self.chain(add_op="sum", final=True)

    def half(self):
        """
        Half the number of particles in each group member. For testing purposes.

        Returns
        -------
        GroupAwareOperation
        """
        return self.chain(add_op="half", final=False)

    def apply(self, func, final=False):
        """
        Apply a passed function.

        Parameters
        ----------
        func: function
            Function to apply.
        final: bool
            Whether this is the final operation in the chain.

        Returns
        -------
        GroupAwareOperation
        """
        return self.chain(add_op=func, final=final)

    def __copy__(self):
        # overwrite method so that copy holds a new ops list.
        c = type(self)(
            self.offsets,
            self.lengths,
            self.arrs,
            ops=list(self.ops),
            inputfields=self.inputfields,
        )
        return c

    def evaluate(self, nmax=None, idxlist=None, compute=True):
        """
        Evaluate the operation.

        Parameters
        ----------
        nmax: Optional[int]
            Maximum number of halos to process.
        idxlist: Optional[np.ndarray]
            List of halo indices to process. If not provided, (and nmax not set) all halos are processed.
        compute: bool
            Whether to compute the result immediately or return a dask object to compute later.

        Returns
        -------

        """
        # TODO: figure out return type
        # final operations: those that can only be at end of chain
        # intermediate operations: those that can only be prior to end of chain
        funcdict = dict()
        funcdict.update(**self.opfuncs)
        funcdict.update(**self.opfuncs_custom)

        func = ChainOps(*[funcdict[k] for k in self.ops])

        fieldnames = list(self.arrs.keys())
        if self.inputfields is None:
            opname = self.ops[0]
            if opname.startswith("custom"):
                dfltkwargs = get_kwargs(self.opfuncs_custom[opname])
                fieldnames = dfltkwargs.get("fieldnames", None)
                if isinstance(fieldnames, str):
                    fieldnames = [fieldnames]
                if fieldnames is None:
                    raise ValueError(
                        "Either pass fields to grouped(fields=...) "
                        "or specify fieldnames=... in applied func."
                    )
            else:
                raise ValueError(
                    "Specify field to operate on in operation or grouped()."
                )

        res = map_group_operation(
            func,
            self.offsets,
            self.lengths,
            self.arrs,
            fieldnames=fieldnames,
            nmax=nmax,
            idxlist=idxlist,
        )
        if compute:
            res = res.compute()
        return res
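
In practice this class is rarely constructed by hand; it is typically obtained from a snapshot's grouped interface and then chained. A hedged sketch, assuming snap is a loaded ArepoSnapshot exposing a grouped() helper that returns a GroupAwareOperation:

# Hypothetical usage; `snap` is assumed to be a loaded ArepoSnapshot.
gas_mass_per_halo = snap.grouped("Masses").sum().evaluate()

# the same with a custom final operation
def masssum(Masses):
    return Masses.sum()

gas_mass_per_halo2 = snap.grouped("Masses").apply(masssum, final=True).evaluate()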
apply(func, final=False)

Apply a passed function.

Parameters:

func (required): Function to apply.
final (default False): Whether this is the final operation in the chain.

Returns:

GroupAwareOperation

Source code in src/scida/customs/arepo/dataset.py, lines 871-886
def apply(self, func, final=False):
    """
    Apply a passed function.

    Parameters
    ----------
    func: function
        Function to apply.
    final: bool
        Whether this is the final operation in the chain.

    Returns
    -------
    GroupAwareOperation
    """
    return self.chain(add_op=func, final=final)
chain(add_op=None, final=False)

Chain another operation to this one.

Parameters:

add_op (default None): Operation to add. Can be a string (e.g. "min", "max", "sum") or a function.
final (default False): Whether this is the final operation in the chain.

Returns:

GroupAwareOperation

Source code in src/scida/customs/arepo/dataset.py, lines 800-829
def chain(self, add_op=None, final=False):
    """
    Chain another operation to this one.

    Parameters
    ----------
    add_op: str or function
        Operation to add. Can be a string (e.g. "min", "max", "sum") or a function.
    final: bool
        Whether this is the final operation in the chain.

    Returns
    -------
    GroupAwareOperation
    """
    if self.final:
        raise ValueError("Cannot chain any additional operation.")
    c = copy.copy(self)
    c.final = final
    c.opfuncs_custom = self.opfuncs_custom
    if add_op is not None:
        if isinstance(add_op, str):
            c.ops.append(add_op)
        elif callable(add_op):
            name = "custom" + str(len(self.opfuncs_custom) + 1)
            c.opfuncs_custom[name] = add_op
            c.ops.append(name)
        else:
            raise ValueError("Unknown operation of type '%s'" % str(type(add_op)))
    return c
evaluate(nmax=None, idxlist=None, compute=True)

Evaluate the operation.

Parameters:

nmax (default None): Maximum number of halos to process.
idxlist (default None): List of halo indices to process. If not provided (and nmax is not set), all halos are processed.
compute (default True): Whether to compute the result immediately or return a dask object to compute later.

Source code in src/scida/customs/arepo/dataset.py, lines 899-954
def evaluate(self, nmax=None, idxlist=None, compute=True):
    """
    Evaluate the operation.

    Parameters
    ----------
    nmax: Optional[int]
        Maximum number of halos to process.
    idxlist: Optional[np.ndarray]
        List of halo indices to process. If not provided, (and nmax not set) all halos are processed.
    compute: bool
        Whether to compute the result immediately or return a dask object to compute later.

    Returns
    -------

    """
    # TODO: figure out return type
    # final operations: those that can only be at end of chain
    # intermediate operations: those that can only be prior to end of chain
    funcdict = dict()
    funcdict.update(**self.opfuncs)
    funcdict.update(**self.opfuncs_custom)

    func = ChainOps(*[funcdict[k] for k in self.ops])

    fieldnames = list(self.arrs.keys())
    if self.inputfields is None:
        opname = self.ops[0]
        if opname.startswith("custom"):
            dfltkwargs = get_kwargs(self.opfuncs_custom[opname])
            fieldnames = dfltkwargs.get("fieldnames", None)
            if isinstance(fieldnames, str):
                fieldnames = [fieldnames]
            if fieldnames is None:
                raise ValueError(
                    "Either pass fields to grouped(fields=...) "
                    "or specify fieldnames=... in applied func."
                )
        else:
            raise ValueError(
                "Specify field to operate on in operation or grouped()."
            )

    res = map_group_operation(
        func,
        self.offsets,
        self.lengths,
        self.arrs,
        fieldnames=fieldnames,
        nmax=nmax,
        idxlist=idxlist,
    )
    if compute:
        res = res.compute()
    return res
half()

Halve the number of particles in each group member. For testing purposes.

Returns:

GroupAwareOperation

Source code in src/scida/customs/arepo/dataset.py, lines 861-869
def half(self):
    """
    Half the number of particles in each group member. For testing purposes.

    Returns
    -------
    GroupAwareOperation
    """
    return self.chain(add_op="half", final=False)
max(field=None)

Get the maximum value for each group member.

Source code in src/scida/customs/arepo/dataset.py, lines 841-849
def max(self, field=None):
    """
    Get the maximum value for each group member.
    """
    if field is not None:
        if self.inputfields is not None:
            raise ValueError("Cannot change input field anymore.")
        self.inputfields = [field]
    return self.chain(add_op="max", final=True)
min(field=None)

Get the minimum value for each group member.

Source code in src/scida/customs/arepo/dataset.py, lines 831-839
def min(self, field=None):
    """
    Get the minimum value for each group member.
    """
    if field is not None:
        if self.inputfields is not None:
            raise ValueError("Cannot change input field anymore.")
        self.inputfields = [field]
    return self.chain(add_op="min", final=True)
sum(field=None)

Sum the values for each group member.

Source code in src/scida/customs/arepo/dataset.py, lines 851-859
def sum(self, field=None):
    """
    Sum the values for each group member.
    """
    if field is not None:
        if self.inputfields is not None:
            raise ValueError("Cannot change input field anymore.")
        self.inputfields = [field]
    return self.chain(add_op="sum", final=True)
Temperature(arrs, ureg=None, **kwargs)

Compute gas temperature given (ElectronAbundance,InternalEnergy) in [K].

Source code in src/scida/customs/arepo/dataset.py, lines 1612-1636
@fielddefs.register_field("PartType0")
def Temperature(arrs, ureg=None, **kwargs):
    """Compute gas temperature given (ElectronAbundance,InternalEnergy) in [K]."""
    xh = 0.76
    gamma = 5.0 / 3.0

    m_p = 1.672622e-24  # proton mass [g]
    k_B = 1.380650e-16  # boltzmann constant [erg/K]

    UnitEnergy_over_UnitMass = (
        1e10  # standard unit system (TODO: can obtain from snapshot)
    )
    f = UnitEnergy_over_UnitMass
    if ureg is not None:
        f = 1.0
        m_p = m_p * ureg.g
        k_B = k_B * ureg.erg / ureg.K

    xe = arrs["ElectronAbundance"]
    u_internal = arrs["InternalEnergy"]

    mu = 4 / (1 + 3 * xh + 4 * xh * xe) * m_p
    temp = f * (gamma - 1.0) * u_internal / k_B * mu

    return temp
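
The conversion above amounts to T = (gamma - 1) * u * mu / k_B, with mu the mean molecular weight for a hydrogen mass fraction xh = 0.76. A standalone sketch of the same arithmetic (assuming InternalEnergy is given in the usual (km/s)^2 code units and ignoring the unit registry):

def temperature_K(internal_energy, electron_abundance, xh=0.76, gamma=5.0 / 3.0):
    """Toy re-implementation: gas temperature [K] from code-unit InternalEnergy and ElectronAbundance."""
    m_p = 1.672622e-24  # proton mass [g]
    k_B = 1.380650e-16  # Boltzmann constant [erg/K]
    u_cgs = internal_energy * 1e10  # (km/s)^2 -> (cm/s)^2
    mu = 4.0 / (1.0 + 3.0 * xh + 4.0 * xh * electron_abundance) * m_p
    return (gamma - 1.0) * u_cgs * mu / k_B

print(temperature_K(1000.0, 1.16))  # of order 1e4-1e5 K for ionized gas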
compute_haloindex(gidx, halocelloffsets, *args, index_unbound=None)

Computes the halo index for each particle with dask.

Source code in src/scida/customs/arepo/dataset.py, lines 1061-1069
def compute_haloindex(gidx, halocelloffsets, *args, index_unbound=None):
    """Computes the halo index for each particle with dask."""
    return da.map_blocks(
        get_hidx_daskwrap,
        gidx,
        halocelloffsets,
        index_unbound=index_unbound,
        meta=np.array((), dtype=np.int64),
    )
compute_haloquantity(gidx, halocelloffsets, hvals, *args)

Computes a halo quantity for each particle with dask.

Source code in src/scida/customs/arepo/dataset.py, lines 1072-1085
def compute_haloquantity(gidx, halocelloffsets, hvals, *args):
    """Computes a halo quantity for each particle with dask."""
    units = None
    if hasattr(hvals, "units"):
        units = hvals.units
    res = map_blocks(
        get_haloquantity_daskwrap,
        gidx,
        halocelloffsets,
        hvals,
        meta=np.array((), dtype=hvals.dtype),
        output_units=units,
    )
    return res
compute_localsubhaloindex(gidx, halocelloffsets, shnumber, shcounts, shcellcounts, index_unbound=None)

Compute the local subhalo index for each particle with dask. The local subhalo index is the index of the subhalo within each halo, starting at 0 for the central subhalo.

Parameters:

gidx (required)
halocelloffsets (required)
shnumber (required)
shcounts (required)
shcellcounts (required)
index_unbound (default None)

Source code in src/scida/customs/arepo/dataset.py, lines 1206-1237
def compute_localsubhaloindex(
    gidx, halocelloffsets, shnumber, shcounts, shcellcounts, index_unbound=None
) -> da.Array:
    """
    Compute the local subhalo index for each particle with dask.
    The local subhalo index is the index of the subhalo within each halo,
    starting at 0 for the central subhalo.

    Parameters
    ----------
    gidx
    halocelloffsets
    shnumber
    shcounts
    shcellcounts
    index_unbound

    Returns
    -------

    """
    res = da.map_blocks(
        get_local_shidx_daskwrap,
        gidx,
        halocelloffsets,
        shnumber,
        shcounts,
        shcellcounts,
        index_unbound=index_unbound,
        meta=np.array((), dtype=np.int64),
    )
    return res
get_hidx(gidx_start, gidx_count, celloffsets, index_unbound=None)

Get halo index of a given cell

Parameters:

gidx_start (integer, required): The first unique integer ID for the first particle.
gidx_count (integer, required): The amount of halo indices we are querying after "gidx_start".
celloffsets (array, required): An array holding the starting cell offset for each halo. Needs to include the offset after the last halo; the required shape is thus (Nhalo+1,).
index_unbound (integer, default None): The index to use for unbound particles. If None, the maximum integer value of the dtype is used.

Source code in src/scida/customs/arepo/dataset.py, lines 1002-1041
@jit(nopython=True)
def get_hidx(gidx_start, gidx_count, celloffsets, index_unbound=None):
    """Get halo index of a given cell

    Parameters
    ----------
    gidx_start: integer
        The first unique integer ID for the first particle
    gidx_count: integer
        The amount of halo indices we are querying after "gidx_start"
    celloffsets : array
        An array holding the starting cell offset for each halo. Needs to include the
        offset after the last halo. The required shape is thus (Nhalo+1,).
    index_unbound : integer, optional
        The index to use for unbound particles. If None, the maximum integer value
        of the dtype is used.
    """
    dtype = np.int64
    if index_unbound is None:
        index_unbound = np.iinfo(dtype).max
    res = index_unbound * np.ones(gidx_count, dtype=dtype)
    # find initial celloffset
    hidx_idx = np.searchsorted(celloffsets, gidx_start, side="right") - 1
    if hidx_idx + 1 >= celloffsets.shape[0]:
        # we are done. Already out of scope of lookup => all unbound gas.
        return res
    celloffset = celloffsets[hidx_idx + 1]
    endid = celloffset - gidx_start
    startid = 0

    # Now iterate through list.
    while startid < gidx_count:
        res[startid:endid] = hidx_idx
        hidx_idx += 1
        startid = endid
        if hidx_idx >= celloffsets.shape[0] - 1:
            break
        count = celloffsets[hidx_idx + 1] - celloffsets[hidx_idx]
        endid = startid + count
    return res
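
A toy call (not from the source) illustrating the lookup: three halos spanning the cell ranges [0, 3), [3, 5) and [5, 9), queried for five cells starting at global index 2.

import numpy as np

celloffsets = np.array([0, 3, 5, 9])  # per-halo cell offsets, including the end of the last halo
print(get_hidx(2, 5, celloffsets))
# -> [0 1 1 2 2]: cell 2 still belongs to halo 0, cells 3-4 to halo 1, cells 5-6 to halo 2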
get_localshidx(gidx_start, gidx_count, celloffsets, shnumber, shcounts, shcellcounts, index_unbound=None)

Get the local subhalo index for each particle. This is the subhalo index within each halo group. Particles belonging to the central galaxies will have index 0, particles belonging to the first satellite will have index 1, etc.

Parameters:

gidx_start (int, required)
gidx_count (int, required)
celloffsets (NDArray[int64], required)
shnumber (required)
shcounts (required)
shcellcounts (required)
index_unbound (default None): The index to use for unbound particles. If None, the maximum integer value of the dtype is used.

Returns:

np.ndarray

Source code in src/scida/customs/arepo/dataset.py, lines 1088-1181
@jit(nopython=True)
def get_localshidx(
    gidx_start: int,
    gidx_count: int,
    celloffsets: NDArray[np.int64],
    shnumber,
    shcounts,
    shcellcounts,
    index_unbound=None,
):
    """
    Get the local subhalo index for each particle. This is the subhalo index within each
    halo group. Particles belonging to the central galaxies will have index 0, particles
    belonging to the first satellite will have index 1, etc.
    Parameters
    ----------
    gidx_start
    gidx_count
    celloffsets
    shnumber
    shcounts
    shcellcounts
    index_unbound: integer, optional
        The index to use for unbound particles. If None, the maximum integer value
        of the dtype is used.

    Returns
    -------
    np.ndarray
    """
    dtype = np.int32
    if index_unbound is None:
        index_unbound = np.iinfo(dtype).max
    res = index_unbound * np.ones(gidx_count, dtype=dtype)  # fuzz has negative index.

    # find initial Group we are in
    hidx_start_idx = np.searchsorted(celloffsets, gidx_start, side="right") - 1
    if hidx_start_idx + 1 >= celloffsets.shape[0]:
        # we are done. Already out of scope of lookup => all unbound gas.
        return res
    celloffset = celloffsets[hidx_start_idx + 1]
    endid = celloffset - gidx_start
    startid = 0

    # find initial subhalo we are in
    hidx = hidx_start_idx
    shcumsum = np.zeros(shcounts[hidx] + 1, dtype=np.int64)
    shcumsum[1:] = np.cumsum(
        shcellcounts[shnumber[hidx] : shnumber[hidx] + shcounts[hidx]]
    )  # collect halo's subhalo offsets
    shcumsum += celloffsets[hidx_start_idx]
    sidx_start_idx: int = int(np.searchsorted(shcumsum, gidx_start, side="right") - 1)
    if sidx_start_idx < shcounts[hidx]:
        endid = shcumsum[sidx_start_idx + 1] - gidx_start

    # Now iterate through list.
    cont = True
    while cont and (startid < gidx_count):
        res[startid:endid] = (
            sidx_start_idx if sidx_start_idx + 1 < shcumsum.shape[0] else -1
        )
        sidx_start_idx += 1
        if sidx_start_idx < shcounts[hidx_start_idx]:
            # we prepare to fill the next available subhalo for current halo
            count = shcumsum[sidx_start_idx + 1] - shcumsum[sidx_start_idx]
            startid = endid
        else:
            # we need to find the next halo to start filling its subhalos
            dbgcount = 0
            while dbgcount < 100:  # find next halo with >0 subhalos
                hidx_start_idx += 1
                if hidx_start_idx >= shcounts.shape[0]:
                    cont = False
                    break
                if shcounts[hidx_start_idx] > 0:
                    break
                dbgcount += 1
            hidx = hidx_start_idx
            if hidx_start_idx >= celloffsets.shape[0] - 1:
                startid = gidx_count
            else:
                count = celloffsets[hidx_start_idx + 1] - celloffsets[hidx_start_idx]
                if hidx < shcounts.shape[0]:
                    shcumsum = np.zeros(shcounts[hidx] + 1, dtype=np.int64)
                    shcumsum[1:] = np.cumsum(
                        shcellcounts[shnumber[hidx] : shnumber[hidx] + shcounts[hidx]]
                    )
                    shcumsum += celloffsets[hidx_start_idx]
                    sidx_start_idx = 0
                    if sidx_start_idx < shcounts[hidx]:
                        count = shcumsum[sidx_start_idx + 1] - shcumsum[sidx_start_idx]
                    startid = celloffsets[hidx_start_idx] - gidx_start
        endid = startid + count
    return res
get_shcounts_shcells(SubhaloGrNr, hlength)

Returns the number of subhalos per halo and the index of each halo's first subhalo.

Parameters:

SubhaloGrNr (required): The group identifier that each subhalo belongs to
hlength (required): The number of halos in the snapshot

Returns:

shcounts (np.ndarray): The number of subhalos per halo
shnumber (np.ndarray): The index of the first subhalo per halo

Source code in src/scida/customs/arepo/dataset.py, lines 1240-1272
@jit(nopython=True)
def get_shcounts_shcells(SubhaloGrNr, hlength):
    """
    Returns the id of the first subhalo and count of subhalos per halo.

    Parameters
    ----------
    SubhaloGrNr: np.ndarray
        The group identifier that each subhalo belongs to respectively
    hlength: int
        The number of halos in the snapshot

    Returns
    -------
    shcounts: np.ndarray
        The number of subhalos per halo
    shnumber: np.ndarray
        The index of the first subhalo per halo
    """
    shcounts = np.zeros(hlength, dtype=np.int32)  # number of subhalos per halo
    shnumber = np.zeros(hlength, dtype=np.int32)  # index of first subhalo per halo
    i = 0
    hid_old = 0
    while i < SubhaloGrNr.shape[0]:
        hid = SubhaloGrNr[i]
        if hid == hid_old:
            shcounts[hid] += 1
        else:
            shnumber[hid] = i
            shcounts[hid] += 1
            hid_old = hid
        i += 1
    return shcounts, shnumber
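
A toy call (not from the source): six subhalos distributed over six halos, with SubhaloGrNr = [0, 0, 2, 2, 2, 5].

import numpy as np

shcounts, shnumber = get_shcounts_shcells(np.array([0, 0, 2, 2, 2, 5]), 6)
print(shcounts)  # [2 0 3 0 0 1]  -> halos 0, 2 and 5 host 2, 3 and 1 subhalos
print(shnumber)  # [0 0 2 0 0 5]  -> index of each halo's first subhalo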
map_group_operation(func, offsets, lengths, arrdict, cpucost_halo=10000.0, nchunks_min=None, chunksize_bytes=None, entry_nbytes_in=4, fieldnames=None, nmax=None, idxlist=None)

Map a function to all halos in a halo catalog.

Parameters:

idxlist (Optional[ndarray], default None): Only process the halos with these indices.
nmax (Optional[int], default None): Only process the first nmax halos.
func (required)
offsets (required): Offset of each group in the particle catalog.
lengths (required): Number of particles per halo.
arrdict (required)
cpucost_halo (default 10000.0)
nchunks_min (Optional[int], default None): Lower bound on the number of chunks the operation is split into.
chunksize_bytes (Optional[int], default None)
entry_nbytes_in (Optional[int], default 4)
fieldnames (Optional[List[str]], default None)

Returns:

da.Array

Source code in src/scida/customs/arepo/dataset.py, lines 1371-1597
def map_group_operation(
    func,
    offsets,
    lengths,
    arrdict,
    cpucost_halo=1e4,
    nchunks_min: Optional[int] = None,
    chunksize_bytes: Optional[int] = None,
    entry_nbytes_in: Optional[int] = 4,
    fieldnames: Optional[List[str]] = None,
    nmax: Optional[int] = None,
    idxlist: Optional[np.ndarray] = None,
) -> da.Array:
    """
    Map a function to all halos in a halo catalog.
    Parameters
    ----------
    idxlist: Optional[np.ndarray]
        Only process the halos with these indices.
    nmax: Optional[int]
        Only process the first nmax halos.
    func
    offsets: np.ndarray
        Offset of each group in the particle catalog.
    lengths: np.ndarray
        Number of particles per halo.
    arrdict
    cpucost_halo
    nchunks_min: Optional[int]
        Lower bound on the number of halos per chunk.
    chunksize_bytes
    entry_nbytes_in
    fieldnames

    Returns
    -------
    None
    """
    if isinstance(func, ChainOps):
        dfltkwargs = func.kwargs
    else:
        dfltkwargs = get_kwargs(func)
    if fieldnames is None:
        fieldnames = dfltkwargs.get("fieldnames", None)
    if fieldnames is None:
        fieldnames = get_args(func)
    units = dfltkwargs.get("units", None)
    shape = dfltkwargs.get("shape", None)
    dtype = dfltkwargs.get("dtype", "float64")
    fill_value = dfltkwargs.get("fill_value", 0)

    if idxlist is not None and nmax is not None:
        raise ValueError("Cannot specify both idxlist and nmax.")

    lengths_all = lengths
    offsets_all = offsets
    if len(lengths) == len(offsets):
        # the offsets array here is one longer here, holding the total number of particles in the last halo.
        offsets_all = np.concatenate([offsets_all, [offsets_all[-1] + lengths[-1]]])

    if nmax is not None:
        lengths = lengths[:nmax]
        offsets = offsets[:nmax]

    if idxlist is not None:
        # make sure idxlist is sorted and unique
        if not np.all(np.diff(idxlist) > 0):
            raise ValueError("idxlist must be sorted and unique.")
        # make sure idxlist is within range
        if np.min(idxlist) < 0 or np.max(idxlist) >= lengths.shape[0]:
            raise ValueError(
                "idxlist elements must be in [%i, %i), but covers range [%i, %i]."
                % (0, lengths.shape[0], np.min(idxlist), np.max(idxlist))
            )
        offsets = offsets[idxlist]
        lengths = lengths[idxlist]

    if len(lengths) == len(offsets):
        # the offsets array here is one longer here, holding the total number of particles in the last halo.
        offsets = np.concatenate([offsets, [offsets[-1] + lengths[-1]]])

    # shape/units inference
    infer_shape = shape is None or (isinstance(shape, str) and shape == "auto")
    infer_units = units is None
    infer = infer_shape or infer_units
    if infer:
        # attempt to determine shape.
        if infer_shape:
            log.debug(
                "No shape specified. Attempting to determine shape of func output."
            )
        if infer_units:
            log.debug(
                "No units specified. Attempting to determine units of func output."
            )
        arrs = [arrdict[f][:1].compute() for f in fieldnames]
        # remove units if present
        # arrs = [arr.magnitude if hasattr(arr, "magnitude") else arr for arr in arrs]
        # arrs = [arr.magnitude for arr in arrs]
        dummyres = None
        try:
            dummyres = func(*arrs)
        except Exception as e:  # noqa
            log.warning("Exception during shape/unit inference: %s." % str(e))
        if dummyres is not None:
            if infer_units and hasattr(dummyres, "units"):
                units = dummyres.units
            log.debug("Shape inference: %s." % str(shape))
        if infer_units and dummyres is None:
            units_present = any([hasattr(arr, "units") for arr in arrs])
            if units_present:
                log.warning("Exception during unit inference. Assuming no units.")
        if dummyres is None and infer_shape:
            # due to https://github.com/hgrecco/pint/issues/1037 innocent np.array operations on unit scalars can fail.
            # we can still attempt to infer shape by removing units prior to calling func.
            arrs = [arr.magnitude if hasattr(arr, "magnitude") else arr for arr in arrs]
            try:
                dummyres = func(*arrs)
            except Exception as e:  # noqa
                # no more logging needed here
                pass
        if dummyres is not None and infer_shape:
            if np.isscalar(dummyres):
                shape = (1,)
            else:
                shape = dummyres.shape
        if infer_shape and dummyres is None and shape is None:
            log.warning("Exception during shape inference. Using shape (1,).")
            shape = ()
    # unit inference

    # Determine chunkedges automatically
    # TODO: very messy and inefficient routine. improve some time.
    # TODO: Set entry_bytes_out
    nbytes_dtype_out = 4  # TODO: hardcode 4 byte output dtype as estimate for now
    entry_nbytes_out = nbytes_dtype_out * np.product(shape)

    # list_chunkedges refers to bounds of index intervals to be processed together
    # if idxlist is specified, then these indices do not have to refer to group indices.
    # if idxlist is given, we enforce that particle data is contiguous
    # by putting each idx from idxlist into its own chunk.
    # in the future, we should optimize this
    if idxlist is not None:
        list_chunkedges = [[idx, idx + 1] for idx in np.arange(len(idxlist))]
    else:
        list_chunkedges = map_group_operation_get_chunkedges(
            lengths,
            entry_nbytes_in,
            entry_nbytes_out,
            cpucost_halo=cpucost_halo,
            nchunks_min=nchunks_min,
            chunksize_bytes=chunksize_bytes,
        )

    minentry = offsets[0]
    maxentry = offsets[-1]  # the last particle that needs to be processed

    # chunks specify the number of groups in each chunk
    chunks = [tuple(np.diff(list_chunkedges, axis=1).flatten())]
    # need to add chunk information for additional output axes if needed
    new_axis = None
    if isinstance(shape, tuple) and shape != (1,):
        chunks += [(s,) for s in shape]
        new_axis = np.arange(1, len(shape) + 1).tolist()

    # slcoffsets = [offsets[chunkedge[0]] for chunkedge in list_chunkedges]
    # the actual length of relevant data in each chunk
    slclengths = [
        offsets[chunkedge[1]] - offsets[chunkedge[0]] for chunkedge in list_chunkedges
    ]
    if idxlist is not None:
        # the chunk length to be fed into map_blocks
        tmplist = np.concatenate([idxlist, [len(lengths_all)]])
        slclengths_map = [
            offsets_all[tmplist[chunkedge[1]]] - offsets_all[tmplist[chunkedge[0]]]
            for chunkedge in list_chunkedges
        ]
        slcoffsets_map = [
            offsets_all[tmplist[chunkedge[0]]] for chunkedge in list_chunkedges
        ]
        slclengths_map[0] = slcoffsets_map[0]
        slcoffsets_map[0] = 0
    else:
        slclengths_map = slclengths

    slcs = [slice(chunkedge[0], chunkedge[1]) for chunkedge in list_chunkedges]
    offsets_in_chunks = [offsets[slc] - offsets[slc.start] for slc in slcs]
    lengths_in_chunks = [lengths[slc] for slc in slcs]
    d_oic = delayed(offsets_in_chunks)
    d_hic = delayed(lengths_in_chunks)

    arrs = [arrdict[f][minentry:maxentry] for f in fieldnames]
    for i, arr in enumerate(arrs):
        arrchunks = ((tuple(slclengths)),)
        if len(arr.shape) > 1:
            arrchunks = arrchunks + (arr.shape[1:],)
        arrs[i] = arr.rechunk(chunks=arrchunks)
    arrdims = np.array([len(arr.shape) for arr in arrs])

    assert np.all(arrdims == arrdims[0])  # Cannot handle different input dims for now

    drop_axis = []
    if arrdims[0] > 1:
        drop_axis = np.arange(1, arrdims[0])

    if dtype is None:
        raise ValueError(
            "dtype must be specified, dask will not be able to automatically determine this here."
        )

    calc = map_blocks(
        wrap_func_scalar,
        func,
        d_oic,
        d_hic,
        *arrs,
        dtype=dtype,
        chunks=chunks,
        new_axis=new_axis,
        drop_axis=drop_axis,
        func_output_shape=shape,
        func_output_dtype=dtype,
        fill_value=fill_value,
        output_units=units,
    )

    return calc
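
A hedged usage sketch of the free function with made-up toy arrays (assuming map_group_operation has been imported from scida.customs.arepo.dataset); the passed function's argument names select which entries of arrdict are fed in per halo:

import dask.array as da
import numpy as np

lengths = np.array([4, 2, 3])        # particles per halo
offsets = np.array([0, 4, 6, 9])     # one more entry than halos
arrdict = {"Masses": da.ones(9, chunks=3)}

def halo_mass(Masses):
    return Masses.sum()

res = map_group_operation(halo_mass, offsets, lengths, arrdict)
print(res.compute())  # expected: [4. 2. 3.], the total mass per halo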
map_group_operation_get_chunkedges(lengths, entry_nbytes_in, entry_nbytes_out, cpucost_halo=1.0, nchunks_min=None, chunksize_bytes=None)

Compute the chunking of a halo operation.

Parameters:

lengths (required): The number of particles per halo.
entry_nbytes_in (required)
entry_nbytes_out (required)
cpucost_halo (default 1.0)
nchunks_min (default None)
chunksize_bytes (default None)

Returns:

The list of chunk edges (pairs of [start, end) halo indices).

Source code in src/scida/customs/arepo/dataset.py, lines 1304-1368
def map_group_operation_get_chunkedges(
    lengths,
    entry_nbytes_in,
    entry_nbytes_out,
    cpucost_halo=1.0,
    nchunks_min=None,
    chunksize_bytes=None,
):
    """
    Compute the chunking of a halo operation.

    Parameters
    ----------
    lengths: np.ndarray
        The number of particles per halo.
    entry_nbytes_in
    entry_nbytes_out
    cpucost_halo
    nchunks_min
    chunksize_bytes

    Returns
    -------
    None
    """
    cpucost_particle = 1.0  # we only care about ratio, so keep particle cost fixed.
    cost = cpucost_particle * lengths + cpucost_halo
    sumcost = cost.cumsum()

    # let's allow a maximal chunksize of 16 times the dask default setting for an individual array [here: multiple]
    if chunksize_bytes is None:
        chunksize_bytes = 16 * parse_humansize(dask.config.get("array.chunk-size"))
    cost_memory = entry_nbytes_in * lengths + entry_nbytes_out

    if not np.max(cost_memory) < chunksize_bytes:
        raise ValueError(
            "Some halo requires more memory than allowed (%i allowed, %i requested). Consider overriding "
            "chunksize_bytes." % (chunksize_bytes, np.max(cost_memory))
        )

    nchunks = int(np.ceil(np.sum(cost_memory) / chunksize_bytes))
    nchunks = int(np.ceil(1.3 * nchunks))  # fudge factor
    if nchunks_min is not None:
        nchunks = max(nchunks_min, nchunks)
    targetcost = sumcost[-1] / nchunks  # chunk target cost = total cost / nchunks

    arr = np.diff(sumcost % targetcost)  # find whenever exceeding modulo target cost
    idx = [0] + list(np.where(arr < 0)[0] + 1)
    if idx[-1] != sumcost.shape[0]:
        idx.append(sumcost.shape[0])
    list_chunkedges = []
    for i in range(len(idx) - 1):
        list_chunkedges.append([idx[i], idx[i + 1]])

    list_chunkedges = np.asarray(
        memorycost_limiter(cost_memory, cost, list_chunkedges, chunksize_bytes)
    )

    # make sure we did not lose any halos.
    assert np.all(
        ~(list_chunkedges.flatten()[2:-1:2] - list_chunkedges.flatten()[1:-1:2]).astype(
            bool
        )
    )
    return list_chunkedges
memorycost_limiter(cost_memory, cost_cpu, list_chunkedges, cost_memory_max)

If a chunk is too memory-expensive, split it into operations of roughly equal CPU expense.

Source code in src/scida/customs/arepo/dataset.py, lines 1275-1301
def memorycost_limiter(cost_memory, cost_cpu, list_chunkedges, cost_memory_max):
    """If a chunk too memory expensive, split into equal cpu expense operations."""
    list_chunkedges_new = []
    for chunkedges in list_chunkedges:
        slc = slice(*chunkedges)
        totcost_mem = np.sum(cost_memory[slc])
        list_chunkedges_new.append(chunkedges)
        if totcost_mem > cost_memory_max:
            sumcost = cost_cpu[slc].cumsum()
            sumcost /= sumcost[-1]
            idx = slc.start + np.argmin(np.abs(sumcost - 0.5))
            if idx == chunkedges[0]:
                idx += 1
            elif idx == chunkedges[-1]:
                idx -= 1
            chunkedges1 = [chunkedges[0], idx]
            chunkedges2 = [idx, chunkedges[1]]
            if idx == chunkedges[0] or idx == chunkedges[1]:
                raise ValueError("This should not happen.")
            list_chunkedges_new.pop()
            list_chunkedges_new += memorycost_limiter(
                cost_memory, cost_cpu, [chunkedges1], cost_memory_max
            )
            list_chunkedges_new += memorycost_limiter(
                cost_memory, cost_cpu, [chunkedges2], cost_memory_max
            )
    return list_chunkedges_new
wrap_func_scalar(func, offsets_in_chunks, lengths_in_chunks, *arrs, block_info=None, block_id=None, func_output_shape=(1,), func_output_dtype='float64', fill_value=0)

Wrapper for applying a function to each halo in the passed chunk.

Parameters:

func (required)
offsets_in_chunks (required)
lengths_in_chunks (required)
arrs (default ())
block_info (default None)
block_id (default None)
func_output_shape (default (1,))
func_output_dtype (default 'float64')
fill_value (default 0)

Source code in src/scida/customs/arepo/dataset.py, lines 957-999
def wrap_func_scalar(
    func,
    offsets_in_chunks,
    lengths_in_chunks,
    *arrs,
    block_info=None,
    block_id=None,
    func_output_shape=(1,),
    func_output_dtype="float64",
    fill_value=0,
):
    """
    Wrapper for applying a function to each halo in the passed chunk.
    Parameters
    ----------
    func
    offsets_in_chunks
    lengths_in_chunks
    arrs
    block_info
    block_id
    func_output_shape
    func_output_dtype
    fill_value

    Returns
    -------

    """
    offsets = offsets_in_chunks[block_id[0]]
    lengths = lengths_in_chunks[block_id[0]]

    res = []
    for i, length in enumerate(lengths):
        o = offsets[i]
        if length == 0:
            res.append(fill_value * np.ones(func_output_shape, dtype=func_output_dtype))
            if func_output_shape == (1,):
                res[-1] = res[-1].item()
            continue
        arrchunks = [arr[o : o + length] for arr in arrs]
        res.append(func(*arrchunks))
    return np.array(res)

helpers

Helper functions for arepo snapshots/simulations.

grp_type_str(gtype)

Map common group-name aliases (e.g. "group", "halos", "subhalo") to the canonical group type string "halo" or "subhalo".

Source code in src/scida/customs/arepo/helpers.py, lines 4-10
def grp_type_str(gtype):
    """Mapping between common group names and numeric group types."""
    if str(gtype).lower() in ["group", "groups", "halo", "halos"]:
        return "halo"
    if str(gtype).lower() in ["subgroup", "subgroups", "subhalo", "subhalos"]:
        return "subhalo"
    raise ValueError("Unknown group type: %s" % gtype)
part_type_num(ptype)

Mapping between common particle names and numeric particle types.

Source code in src/scida/customs/arepo/helpers.py, lines 13-34
def part_type_num(ptype):
    """Mapping between common particle names and numeric particle types."""
    ptype = str(ptype).replace("PartType", "")
    if ptype.isdigit():
        return int(ptype)

    if str(ptype).lower() in ["gas", "cells"]:
        return 0
    if str(ptype).lower() in ["dm", "darkmatter"]:
        return 1
    if str(ptype).lower() in ["dmlowres"]:
        return 2  # only zoom simulations, not present in full periodic boxes
    if str(ptype).lower() in ["tracer", "tracers", "tracermc", "trmc"]:
        return 3
    if str(ptype).lower() in ["star", "stars", "stellar"]:
        return 4  # only those with GFM_StellarFormationTime>0
    if str(ptype).lower() in ["wind"]:
        return 4  # only those with GFM_StellarFormationTime<0
    if str(ptype).lower() in ["bh", "bhs", "blackhole", "blackholes", "black"]:
        return 5
    if str(ptype).lower() in ["all"]:
        return -1

selector

Selector for ArepoSnapshot

ArepoSelector

Bases: Selector

Selector for ArepoSnapshot. Can select for haloID, subhaloID, and unbound particles.

Source code in src/scida/customs/arepo/selector.py, lines 14-122
class ArepoSelector(Selector):
    """Selector for ArepoSnapshot.
    Can select for haloID, subhaloID, and unbound particles."""

    def __init__(self) -> None:
        """
        Initialize the selector.
        """
        super().__init__()
        self.keys = ["haloID", "subhaloID", "unbound"]

    def prepare(self, *args, **kwargs) -> None:
        if all([kwargs.get(k, None) is None for k in self.keys]):
            return  # no specific selection, thus just return
        snap: ArepoSnapshot = args[0]
        halo_id = kwargs.get("haloID", None)
        subhalo_id = kwargs.get("subhaloID", None)
        unbound = kwargs.get("unbound", None)

        if halo_id is not None and subhalo_id is not None:
            raise ValueError("Cannot select for haloID and subhaloID at the same time.")

        if unbound is True and (halo_id is not None or subhalo_id is not None):
            raise ValueError(
                "Cannot select haloID/subhaloID and unbound particles at the same time."
            )

        if snap.catalog is None:
            raise ValueError("Cannot select for haloID without catalog loaded.")

        # select for halo
        idx = subhalo_id if subhalo_id is not None else halo_id
        objtype = "subhalo" if subhalo_id is not None else "halo"
        if idx is not None:
            self.select_group(snap, idx, objtype=objtype)
        elif unbound is True:
            self.select_unbound(snap)

    def select_unbound(self, snap):
        """
        Select unbound particles.

        Parameters
        ----------
        snap: ArepoSnapshot

        Returns
        -------
        None
        """
        lengths = self.data_backup["Group"]["GroupLenType"][-1, :].compute()
        offsets = self.data_backup["Group"]["GroupOffsetsType"][-1, :].compute()
        # for unbound gas, we start after the last halo particles
        offsets = offsets + lengths
        for p in self.data_backup:
            splt = p.split("PartType")
            if len(splt) == 1:
                for k, v in self.data_backup[p].items():
                    self.data[p][k] = v
            else:
                pnum = int(splt[1])
                offset = offsets[pnum]
                if hasattr(offset, "magnitude"):  # hack for issue 59
                    offset = offset.magnitude
                for k, v in self.data_backup[p].items():
                    self.data[p][k] = v[offset:-1]
        snap.data = self.data

    def select_group(self, snap, idx, objtype="Group"):
        """
        Select particles for given group/subhalo index.

        Parameters
        ----------
        snap: ArepoSnapshot
        idx: int
        objtype: str

        Returns
        -------
        None
        """
        # TODO: test whether works for multiple groups via idx list
        objtype = grp_type_str(objtype)
        if objtype == "halo":
            lengths = self.data_backup["Group"]["GroupLenType"][idx, :].compute()
            offsets = self.data_backup["Group"]["GroupOffsetsType"][idx, :].compute()
        elif objtype == "subhalo":
            lengths = {i: snap.get_subhalolengths(i)[idx] for i in range(6)}
            offsets = {i: snap.get_subhalooffsets(i)[idx] for i in range(6)}
        else:
            raise ValueError("Unknown object type: %s" % objtype)

        for p in self.data_backup:
            splt = p.split("PartType")
            if len(splt) == 1:
                for k, v in self.data_backup[p].items():
                    self.data[p][k] = v
            else:
                pnum = int(splt[1])
                offset = offsets[pnum]
                length = lengths[pnum]
                if hasattr(offset, "magnitude"):  # hack for issue 59
                    offset = offset.magnitude
                if hasattr(length, "magnitude"):
                    length = length.magnitude
                for k, v in self.data_backup[p].items():
                    self.data[p][k] = v[offset : offset + length]
        snap.data = self.data
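
Since ArepoSelector decorates ArepoSnapshot.return_data (shown further above), selection is typically triggered via keyword arguments there. A hedged sketch, assuming snap is a loaded ArepoSnapshot with its group catalog attached:

# Hypothetical usage; `snap` is assumed to be a loaded ArepoSnapshot with a catalog.
data_halo = snap.return_data(haloID=42)      # particles bound to halo 42
data_sub = snap.return_data(subhaloID=10)    # particles bound to subhalo 10
data_fuzz = snap.return_data(unbound=True)   # particles not bound to any halo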
__init__()

Initialize the selector.

Source code in src/scida/customs/arepo/selector.py, lines 18-23
def __init__(self) -> None:
    """
    Initialize the selector.
    """
    super().__init__()
    self.keys = ["haloID", "subhaloID", "unbound"]
select_group(snap, idx, objtype='Group')

Select particles for given group/subhalo index.

Parameters:

snap (required)
idx (required)
objtype (default 'Group')

Returns:

None

Source code in src/scida/customs/arepo/selector.py, lines 82-122
def select_group(self, snap, idx, objtype="Group"):
    """
    Select particles for given group/subhalo index.

    Parameters
    ----------
    snap: ArepoSnapshot
    idx: int
    objtype: str

    Returns
    -------
    None
    """
    # TODO: test whether works for multiple groups via idx list
    objtype = grp_type_str(objtype)
    if objtype == "halo":
        lengths = self.data_backup["Group"]["GroupLenType"][idx, :].compute()
        offsets = self.data_backup["Group"]["GroupOffsetsType"][idx, :].compute()
    elif objtype == "subhalo":
        lengths = {i: snap.get_subhalolengths(i)[idx] for i in range(6)}
        offsets = {i: snap.get_subhalooffsets(i)[idx] for i in range(6)}
    else:
        raise ValueError("Unknown object type: %s" % objtype)

    for p in self.data_backup:
        splt = p.split("PartType")
        if len(splt) == 1:
            for k, v in self.data_backup[p].items():
                self.data[p][k] = v
        else:
            pnum = int(splt[1])
            offset = offsets[pnum]
            length = lengths[pnum]
            if hasattr(offset, "magnitude"):  # hack for issue 59
                offset = offset.magnitude
            if hasattr(length, "magnitude"):
                length = length.magnitude
            for k, v in self.data_backup[p].items():
                self.data[p][k] = v[offset : offset + length]
    snap.data = self.data
select_unbound(snap)

Select unbound particles.

Parameters:

snap (required)

Returns:

None

Source code in src/scida/customs/arepo/selector.py, lines 52-80
def select_unbound(self, snap):
    """
    Select unbound particles.

    Parameters
    ----------
    snap: ArepoSnapshot

    Returns
    -------
    None
    """
    lengths = self.data_backup["Group"]["GroupLenType"][-1, :].compute()
    offsets = self.data_backup["Group"]["GroupOffsetsType"][-1, :].compute()
    # for unbound gas, we start after the last halo particles
    offsets = offsets + lengths
    for p in self.data_backup:
        splt = p.split("PartType")
        if len(splt) == 1:
            for k, v in self.data_backup[p].items():
                self.data[p][k] = v
        else:
            pnum = int(splt[1])
            offset = offsets[pnum]
            if hasattr(offset, "magnitude"):  # hack for issue 59
                offset = offset.magnitude
            for k, v in self.data_backup[p].items():
                self.data[p][k] = v[offset:-1]
    snap.data = self.data
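
Unbound particles are stored after the last group's block, so the start index per particle type is the last group's offset plus its length. A toy sketch with made-up values (the sketch reproduces the [offset:-1] slice used in the source):

import numpy as np

# offsets/lengths of the *last* group per particle type (illustrative values)
lengths = np.array([50, 80, 0, 0, 10, 2])
offsets = np.array([900, 4900, 0, 0, 780, 38])
start = offsets + lengths  # the unbound block begins right after the last group

field_pt1 = np.arange(6_000)  # stands in for e.g. PartType1/Coordinates
unbound = field_pt1[start[1]:-1]  # same slice form as in select_unbound
print(unbound.size)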

series

Contains Series class for Arepo simulations.

ArepoSimulation

Bases: GadgetStyleSimulation

A series representing an Arepo simulation.

Source code in src/scida/customs/arepo/series.py (lines 13–76)
class ArepoSimulation(GadgetStyleSimulation):
    """A series representing an Arepo simulation."""

    def __init__(self, path, lazy=True, async_caching=False, **interface_kwargs):
        """
        Initialize an ArepoSimulation object.

        Parameters
        ----------
        path: str
            Path to the simulation folder, should contain "output" folder.
        lazy: bool
            Whether to load data files lazily.
        interface_kwargs: dict
            Additional keyword arguments passed to the interface.
        """
        # choose parent folder as path if we are passed "output" dir
        p = pathlib.Path(path)
        if p.name == "output":
            path = str(p.parent)
        prefix_dict = dict(paths="snapdir", gpaths="group")
        arg_dict = dict(gpaths="catalog")
        super().__init__(
            path,
            prefix_dict=prefix_dict,
            arg_dict=arg_dict,
            lazy=lazy,
            **interface_kwargs
        )

    @classmethod
    def validate_path(cls, path, *args, **kwargs) -> CandidateStatus:
        """
        Validate a path as a candidate for this simulation class.

        Parameters
        ----------
        path: str
            Path to validate.
        args: list
            Additional positional arguments.
        kwargs:
            Additional keyword arguments.

        Returns
        -------
        CandidateStatus
            Whether the path is a candidate for this simulation class.
        """
        valid = CandidateStatus.NO
        if not os.path.isdir(path):
            return CandidateStatus.NO
        fns = os.listdir(path)
        if "gizmo_parameters.txt" in fns:
            return CandidateStatus.NO
        sprefixs = ["snapdir", "snapshot"]
        opath = path
        if "output" in fns:
            opath = join(path, "output")
        folders = os.listdir(opath)
        folders = [f for f in folders if os.path.isdir(join(opath, f))]
        if any([f.startswith(k) for f in folders for k in sprefixs]):
            valid = CandidateStatus.MAYBE
        return valid
__init__(path, lazy=True, async_caching=False, **interface_kwargs)

Initialize an ArepoSimulation object.

Parameters:

Name Type Description Default
path str

Path to the simulation folder; should contain an "output" folder.

required
lazy bool

Whether to load data files lazily.

True
interface_kwargs dict

Additional keyword arguments passed to the interface.

{}
Source code in src/scida/customs/arepo/series.py (lines 16–41)
def __init__(self, path, lazy=True, async_caching=False, **interface_kwargs):
    """
    Initialize an ArepoSimulation object.

    Parameters
    ----------
    path: str
        Path to the simulation folder, should contain "output" folder.
    lazy: bool
        Whether to load data files lazily.
    interface_kwargs: dict
        Additional keyword arguments passed to the interface.
    """
    # choose parent folder as path if we are passed "output" dir
    p = pathlib.Path(path)
    if p.name == "output":
        path = str(p.parent)
    prefix_dict = dict(paths="snapdir", gpaths="group")
    arg_dict = dict(gpaths="catalog")
    super().__init__(
        path,
        prefix_dict=prefix_dict,
        arg_dict=arg_dict,
        lazy=lazy,
        **interface_kwargs
    )
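
A minimal usage sketch; the path is a placeholder pointing at a simulation folder whose "output" directory contains snapdir_* and group catalog folders. Alternatively, scida.load() can be used, which dispatches to this class when validate_path below accepts the folder:

from scida import load
from scida.customs.arepo.series import ArepoSimulation

sim = ArepoSimulation("/data/sims/TNG50-4", lazy=True)  # placeholder path

# or let scida pick the series class automatically:
sim2 = load("/data/sims/TNG50-4")
print(type(sim2).__name__)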
validate_path(path, *args, **kwargs) classmethod

Validate a path as a candidate for this simulation class.

Parameters:

Name Type Description Default
path str

Path to validate.

required
args list

Additional positional arguments.

()
kwargs

Additional keyword arguments.

{}

Returns:

Type Description
CandidateStatus

Whether the path is a candidate for this simulation class.

Source code in src/scida/customs/arepo/series.py (lines 43–76)
@classmethod
def validate_path(cls, path, *args, **kwargs) -> CandidateStatus:
    """
    Validate a path as a candidate for this simulation class.

    Parameters
    ----------
    path: str
        Path to validate.
    args: list
        Additional positional arguments.
    kwargs:
        Additional keyword arguments.

    Returns
    -------
    CandidateStatus
        Whether the path is a candidate for this simulation class.
    """
    valid = CandidateStatus.NO
    if not os.path.isdir(path):
        return CandidateStatus.NO
    fns = os.listdir(path)
    if "gizmo_parameters.txt" in fns:
        return CandidateStatus.NO
    sprefixs = ["snapdir", "snapshot"]
    opath = path
    if "output" in fns:
        opath = join(path, "output")
    folders = os.listdir(opath)
    folders = [f for f in folders if os.path.isdir(join(opath, f))]
    if any([f.startswith(k) for f in folders for k in sprefixs]):
        valid = CandidateStatus.MAYBE
    return valid
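
This classmethod is the hook queried by scida's automatic type detection; it can also be called directly to check whether a folder looks like an Arepo run (the path is a placeholder):

from scida.customs.arepo.series import ArepoSimulation

status = ArepoSimulation.validate_path("/data/sims/TNG50-4")
print(status.name)  # "MAYBE" if snapdir_*/snapshot_* folders are found, otherwise "NO"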

gadgetstyle

dataset

Defines the GadgetStyleSnapshot class, mostly used for deriving subclasses for related codes/simulations.

GadgetStyleSnapshot

Bases: Dataset

A dataset representing a Gadget-style snapshot.

Source code in src/scida/customs/gadgetstyle/dataset.py (lines 19–243)
class GadgetStyleSnapshot(Dataset):
    """A dataset representing a Gadget-style snapshot."""

    def __init__(self, path, chunksize="auto", virtualcache=True, **kwargs) -> None:
        """We define gadget-style snapshots as nbody/hydrodynamical simulation snapshots that follow
        the common /PartType0, /PartType1 grouping scheme."""
        self.boxsize = np.nan
        super().__init__(path, chunksize=chunksize, virtualcache=virtualcache, **kwargs)

        defaultattributes = ["config", "header", "parameters"]
        for k in self._metadata_raw:
            name = k.strip("/").lower()
            if name in defaultattributes:
                self.__dict__[name] = self._metadata_raw[k]
                if "BoxSize" in self.__dict__[name]:
                    self.boxsize = self.__dict__[name]["BoxSize"]
                elif "Boxsize" in self.__dict__[name]:
                    self.boxsize = self.__dict__[name]["Boxsize"]

        sanity_check = kwargs.get("sanity_check", False)
        key_nparts = "NumPart_Total"
        key_nparts_hw = "NumPart_Total_HighWord"
        if sanity_check and key_nparts in self.header and key_nparts_hw in self.header:
            nparts = self.header[key_nparts_hw] * 2**32 + self.header[key_nparts]
            for i, n in enumerate(nparts):
                pkey = "PartType%i" % i
                if pkey in self.data:
                    pdata = self.data[pkey]
                    fkey = next(iter(pdata.keys()))
                    nparts_loaded = pdata[fkey].shape[0]
                    if nparts_loaded != n:
                        raise ValueError(
                            "Number of particles in header (%i) does not match number of particles loaded (%i) "
                            "for particle type %i" % (n, nparts_loaded, i)
                        )

    @classmethod
    def _get_fileprefix(cls, path: Union[str, os.PathLike], **kwargs) -> str:
        """
        Get the fileprefix used to identify files belonging to given dataset.
        Parameters
        ----------
        path: str, os.PathLike
            path to check
        kwargs

        Returns
        -------
        str
        """
        if os.path.isfile(path):
            return ""  # nothing to do, we have a single file, not a directory
        # order matters: groups will be taken before fof_subhalo, requires py>3.7 for dict order
        prfxs = ["groups", "fof_subhalo", "snap"]
        prfxs_prfx_sim = dict.fromkeys(prfxs)
        files = sorted(os.listdir(path))
        prfxs_lst = []
        for fn in files:
            s = re.search(r"^(\w*)_(\d*)", fn)
            if s is not None:
                prfxs_lst.append(s.group(1))
        prfxs_lst = [p for s in prfxs_prfx_sim for p in prfxs_lst if p.startswith(s)]
        prfxs = dict.fromkeys(prfxs_lst)
        prfxs = list(prfxs.keys())
        if len(prfxs) > 1:
            log.debug("We have more than one prefix avail: %s" % prfxs)
        elif len(prfxs) == 0:
            return ""
        if set(prfxs) == {"groups", "fof_subhalo_tab"}:
            return "groups"  # "groups" over "fof_subhalo_tab"
        return prfxs[0]

    @classmethod
    def validate_path(
        cls, path: Union[str, os.PathLike], *args, expect_grp=False, **kwargs
    ) -> CandidateStatus:
        """
        Check if path is valid for this interface.
        Parameters
        ----------
        path: str, os.PathLike
            path to check
        args
        kwargs

        Returns
        -------
        bool
        """
        path = str(path)
        possibly_valid = CandidateStatus.NO
        iszarr = path.rstrip("/").endswith(".zarr")
        if path.endswith(".hdf5") or iszarr:
            possibly_valid = CandidateStatus.MAYBE
        if os.path.isdir(path):
            files = os.listdir(path)
            sufxs = [f.split(".")[-1] for f in files]
            if not iszarr and len(set(sufxs)) > 1:
                possibly_valid = CandidateStatus.NO
            if sufxs[0] == "hdf5":
                possibly_valid = CandidateStatus.MAYBE
        if possibly_valid != CandidateStatus.NO:
            metadata_raw = load_metadata(path, **kwargs)
            # need some silly combination of attributes to be sure
            if all([k in metadata_raw for k in ["/Header"]]):
                # identifying snapshot or group catalog
                is_snap = all(
                    [
                        k in metadata_raw["/Header"]
                        for k in ["NumPart_ThisFile", "NumPart_Total"]
                    ]
                )
                is_grp = all(
                    [
                        k in metadata_raw["/Header"]
                        for k in ["Ngroups_ThisFile", "Ngroups_Total"]
                    ]
                )
                if is_grp:
                    return CandidateStatus.MAYBE
                if is_snap and not expect_grp:
                    return CandidateStatus.MAYBE
        return CandidateStatus.NO

    def register_field(self, parttype, name=None, description=""):
        """
        Register a field for a given particle type by returning through decorator.

        Parameters
        ----------
        parttype: Optional[Union[str, List[str]]]
            Particle type name to register with. If None, register for the base field container.
        name: Optional[str]
            Name of the field to register.
        description: Optional[str]
            Description of the field to register.

        Returns
        -------
        callable

        """
        res = self.data.register_field(parttype, name=name, description=description)
        return res

    def merge_data(
        self, secondobj, fieldname_suffix="", root_group: Optional[str] = None
    ):
        """
        Merge data from other snapshot into self.data.

        Parameters
        ----------
        secondobj: GadgetStyleSnapshot
        fieldname_suffix: str
        root_group: Optional[str]

        Returns
        -------
        None

        """
        data = self.data
        if root_group is not None:
            if root_group not in data._containers:
                data.add_container(root_group)
            data = self.data[root_group]
        for k in secondobj.data:
            key = k + fieldname_suffix
            if key not in data:
                data[key] = secondobj.data[k]
            else:
                log.debug("Not overwriting field '%s' during merge_data." % key)
            secondobj.data.fieldrecipes_kwargs["snap"] = self

    def merge_hints(self, secondobj):
        """
        Merge hints from other snapshot into self.hints.

        Parameters
        ----------
        secondobj: GadgetStyleSnapshot
            Other snapshot to merge hints from.

        Returns
        -------
        None
        """
        # merge hints from snap and catalog
        for h in secondobj.hints:
            if h not in self.hints:
                self.hints[h] = secondobj.hints[h]
            elif isinstance(self.hints[h], dict):
                # merge dicts
                for k in secondobj.hints[h]:
                    if k not in self.hints[h]:
                        self.hints[h][k] = secondobj.hints[h][k]
            else:
                pass  # nothing to do; we do not overwrite with catalog props

    @classmethod
    def _clean_metadata_from_raw(cls, rawmetadata):
        """
        Set metadata from raw metadata.
        """
        metadata = dict()
        if "/Header" in rawmetadata:
            header = rawmetadata["/Header"]
            if "Redshift" in header:
                metadata["redshift"] = float(header["Redshift"])
                metadata["z"] = metadata["redshift"]
            if "BoxSize" in header:
                # can be scalar or array
                metadata["boxsize"] = header["BoxSize"]
            if "Time" in header:
                metadata["time"] = float(header["Time"])
                metadata["t"] = metadata["time"]
        return metadata

    def _set_metadata(self):
        """
        Set metadata from header and config.
        """
        md = self._clean_metadata_from_raw(self._metadata_raw)
        self.metadata = md
__init__(path, chunksize='auto', virtualcache=True, **kwargs)

We define gadget-style snapshots as N-body/hydrodynamical simulation snapshots that follow the common /PartType0, /PartType1 grouping scheme.

Source code in src/scida/customs/gadgetstyle/dataset.py (lines 22–53)
def __init__(self, path, chunksize="auto", virtualcache=True, **kwargs) -> None:
    """We define gadget-style snapshots as nbody/hydrodynamical simulation snapshots that follow
    the common /PartType0, /PartType1 grouping scheme."""
    self.boxsize = np.nan
    super().__init__(path, chunksize=chunksize, virtualcache=virtualcache, **kwargs)

    defaultattributes = ["config", "header", "parameters"]
    for k in self._metadata_raw:
        name = k.strip("/").lower()
        if name in defaultattributes:
            self.__dict__[name] = self._metadata_raw[k]
            if "BoxSize" in self.__dict__[name]:
                self.boxsize = self.__dict__[name]["BoxSize"]
            elif "Boxsize" in self.__dict__[name]:
                self.boxsize = self.__dict__[name]["Boxsize"]

    sanity_check = kwargs.get("sanity_check", False)
    key_nparts = "NumPart_Total"
    key_nparts_hw = "NumPart_Total_HighWord"
    if sanity_check and key_nparts in self.header and key_nparts_hw in self.header:
        nparts = self.header[key_nparts_hw] * 2**32 + self.header[key_nparts]
        for i, n in enumerate(nparts):
            pkey = "PartType%i" % i
            if pkey in self.data:
                pdata = self.data[pkey]
                fkey = next(iter(pdata.keys()))
                nparts_loaded = pdata[fkey].shape[0]
                if nparts_loaded != n:
                    raise ValueError(
                        "Number of particles in header (%i) does not match number of particles loaded (%i) "
                        "for particle type %i" % (n, nparts_loaded, i)
                    )
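
A minimal sketch of constructing a snapshot directly and inspecting the attributes populated here; the path is a placeholder, and sanity_check mirrors the optional keyword consumed above:

from scida.customs.gadgetstyle.dataset import GadgetStyleSnapshot

snap = GadgetStyleSnapshot("/data/output/snapdir_099", sanity_check=True)  # placeholder path
print(snap.boxsize)                  # filled from the header's BoxSize/Boxsize entry
print(snap.header["NumPart_Total"])  # header/config/parameters become instance attributes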
merge_data(secondobj, fieldname_suffix='', root_group=None)

Merge data from other snapshot into self.data.

Parameters:

Name Type Description Default
secondobj GadgetStyleSnapshot
required
fieldname_suffix str
''
root_group Optional[str]
None

Returns:

Type Description
None
Source code in src/scida/customs/gadgetstyle/dataset.py (lines 164–192)
def merge_data(
    self, secondobj, fieldname_suffix="", root_group: Optional[str] = None
):
    """
    Merge data from other snapshot into self.data.

    Parameters
    ----------
    secondobj: GadgetStyleSnapshot
    fieldname_suffix: str
    root_group: Optional[str]

    Returns
    -------
    None

    """
    data = self.data
    if root_group is not None:
        if root_group not in data._containers:
            data.add_container(root_group)
        data = self.data[root_group]
    for k in secondobj.data:
        key = k + fieldname_suffix
        if key not in data:
            data[key] = secondobj.data[k]
        else:
            log.debug("Not overwriting field '%s' during merge_data." % key)
        secondobj.data.fieldrecipes_kwargs["snap"] = self
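
This is how, for example, group-catalog fields can be attached alongside the particle data of a snapshot. A sketch with placeholder paths, assuming both files load as GadgetStyleSnapshot and using a hypothetical root group name:

from scida.customs.gadgetstyle.dataset import GadgetStyleSnapshot

snap = GadgetStyleSnapshot("/data/output/snapdir_099")    # placeholder paths
catalog = GadgetStyleSnapshot("/data/output/groups_099")

snap.merge_data(catalog, root_group="Catalog")  # attach catalog containers under "Catalog"
print(list(snap.data["Catalog"].keys()))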
merge_hints(secondobj)

Merge hints from other snapshot into self.hints.

Parameters:

Name Type Description Default
secondobj GadgetStyleSnapshot

Other snapshot to merge hints from.

required

Returns:

Type Description
None
Source code in src/scida/customs/gadgetstyle/dataset.py (lines 194–217)
def merge_hints(self, secondobj):
    """
    Merge hints from other snapshot into self.hints.

    Parameters
    ----------
    secondobj: GadgetStyleSnapshot
        Other snapshot to merge hints from.

    Returns
    -------
    None
    """
    # merge hints from snap and catalog
    for h in secondobj.hints:
        if h not in self.hints:
            self.hints[h] = secondobj.hints[h]
        elif isinstance(self.hints[h], dict):
            # merge dicts
            for k in secondobj.hints[h]:
                if k not in self.hints[h]:
                    self.hints[h][k] = secondobj.hints[h][k]
        else:
            pass  # nothing to do; we do not overwrite with catalog props
register_field(parttype, name=None, description='')

Register a field for a given particle type by returning a decorator.

Parameters:

Name Type Description Default
parttype Optional[Union[str, List[str]]]

Particle type name to register with. If None, register for the base field container.

required
name Optional[str]

Name of the field to register.

None
description Optional[str]

Description of the field to register.

''

Returns:

Type Description
callable
Source code in src/scida/customs/gadgetstyle/dataset.py (lines 143–162)
def register_field(self, parttype, name=None, description=""):
    """
    Register a field for a given particle type by returning through decorator.

    Parameters
    ----------
    parttype: Optional[Union[str, List[str]]]
        Particle type name to register with. If None, register for the base field container.
    name: Optional[str]
        Name of the field to register.
    description: Optional[str]
        Description of the field to register.

    Returns
    -------
    callable

    """
    res = self.data.register_field(parttype, name=name, description=description)
    return res
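
Because this forwards to the field container's register_field, the usual pattern is to decorate a recipe that derives a new field from existing ones. A sketch with a placeholder path; the recipe signature (receiving the particle-type field container as its first argument) and the field names are assumptions:

from scida.customs.gadgetstyle.dataset import GadgetStyleSnapshot

snap = GadgetStyleSnapshot("/data/output/snapdir_099")  # placeholder path

@snap.register_field("PartType0", name="CellVolume", description="Masses / Density")
def cell_volume(arrs, **kwargs):
    # assumption: arrs is the PartType0 field container
    return arrs["Masses"] / arrs["Density"]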
validate_path(path, *args, expect_grp=False, **kwargs) classmethod

Check if path is valid for this interface.

Parameters:

Name Type Description Default
path Union[str, PathLike]

path to check

required
args
()
kwargs
{}

Returns:

Type Description
CandidateStatus
Source code in src/scida/customs/gadgetstyle/dataset.py (lines 91–141)
@classmethod
def validate_path(
    cls, path: Union[str, os.PathLike], *args, expect_grp=False, **kwargs
) -> CandidateStatus:
    """
    Check if path is valid for this interface.
    Parameters
    ----------
    path: str, os.PathLike
        path to check
    args
    kwargs

    Returns
    -------
    bool
    """
    path = str(path)
    possibly_valid = CandidateStatus.NO
    iszarr = path.rstrip("/").endswith(".zarr")
    if path.endswith(".hdf5") or iszarr:
        possibly_valid = CandidateStatus.MAYBE
    if os.path.isdir(path):
        files = os.listdir(path)
        sufxs = [f.split(".")[-1] for f in files]
        if not iszarr and len(set(sufxs)) > 1:
            possibly_valid = CandidateStatus.NO
        if sufxs[0] == "hdf5":
            possibly_valid = CandidateStatus.MAYBE
    if possibly_valid != CandidateStatus.NO:
        metadata_raw = load_metadata(path, **kwargs)
        # need some silly combination of attributes to be sure
        if all([k in metadata_raw for k in ["/Header"]]):
            # identifying snapshot or group catalog
            is_snap = all(
                [
                    k in metadata_raw["/Header"]
                    for k in ["NumPart_ThisFile", "NumPart_Total"]
                ]
            )
            is_grp = all(
                [
                    k in metadata_raw["/Header"]
                    for k in ["Ngroups_ThisFile", "Ngroups_Total"]
                ]
            )
            if is_grp:
                return CandidateStatus.MAYBE
            if is_snap and not expect_grp:
                return CandidateStatus.MAYBE
    return CandidateStatus.NO
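
As with the series class, this hook drives automatic type detection but can be queried directly, for example to distinguish a snapshot from a group catalog via expect_grp (paths are placeholders):

from scida.customs.gadgetstyle.dataset import GadgetStyleSnapshot

print(GadgetStyleSnapshot.validate_path("/data/output/snapdir_099").name)
print(GadgetStyleSnapshot.validate_path("/data/output/groups_099", expect_grp=True).name)
# each call returns CandidateStatus.MAYBE or CandidateStatus.NO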

series

Defines a series representing a Gadget-style simulation.

GadgetStyleSimulation

Bases: DatasetSeries

A series representing a Gadget-style simulation.

Source code in src/scida/customs/gadgetstyle/series.py (lines 14–103)
class GadgetStyleSimulation(DatasetSeries):
    """A series representing a Gadget-style simulation."""

    def __init__(
        self,
        path,
        prefix_dict: Optional[Dict] = None,
        subpath_dict: Optional[Dict] = None,
        arg_dict: Optional[Dict] = None,
        lazy=True,
        **interface_kwargs
    ):
        """
        Initialize a GadgetStyleSimulation object.

        Parameters
        ----------
        path: str
            Path to the simulation folder, should contain "output" folder.
        prefix_dict: dict
        subpath_dict: dict
        arg_dict: dict
        lazy: bool
        interface_kwargs: dict
        """
        self.path = path
        self.name = os.path.basename(path)
        if prefix_dict is None:
            prefix_dict = dict()
        if subpath_dict is None:
            subpath_dict = dict()
        if arg_dict is None:
            arg_dict = dict()
        p = Path(path)
        if not (p.exists()):
            raise ValueError("Specified path '%s' does not exist." % path)
        paths_dict = dict()
        keys = []
        for d in [prefix_dict, subpath_dict, arg_dict]:
            keys.extend(list(d.keys()))
        keys = set(keys)
        for k in keys:
            subpath = subpath_dict.get(k, "output")
            sp = p / subpath

            prefix = _get_snapshotfolder_prefix(sp)
            prefix = prefix_dict.get(k, prefix)
            if not sp.exists():
                if k != "paths":
                    continue  # do not require optional sources
                raise ValueError("Specified path '%s' does not exist." % (p / subpath))
            fns = os.listdir(sp)
            prfxs = set([f.split("_")[0] for f in fns if f.startswith(prefix)])
            if len(prfxs) == 0:
                raise ValueError(
                    "Could not find any files with prefix '%s' in '%s'." % (prefix, sp)
                )
            prfx = prfxs.pop()

            paths = sorted([p for p in sp.glob(prfx + "_*")])
            # sometimes there are backup folders with different suffix, exclude those.
            paths = [
                p
                for p in paths
                if str(p).split("_")[-1].isdigit() or str(p).endswith(".hdf5")
            ]
            paths_dict[k] = paths

        # make sure we have the same amount of paths respectively
        length = None
        for k in paths_dict.keys():
            paths = paths_dict[k]
            if length is None:
                length = len(paths)
            else:
                assert length == len(paths)

        paths = paths_dict.pop("paths", None<