Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions src/modelarrayio/cli/cifti_to_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
from tqdm import tqdm

from modelarrayio.cli import utils as cli_utils
from modelarrayio.cli.parser_utils import add_scalar_columns_arg, add_to_modelarray_args
from modelarrayio.cli.parser_utils import add_to_modelarray_args
from modelarrayio.utils.cifti import (
_build_scalar_sources,
_cohort_to_long_dataframe,
_load_cohort_cifti,
brain_names_to_dataframe,
extract_cifti_scalar_data,
load_cohort_cifti,
)
from modelarrayio.utils.misc import build_scalar_sources, cohort_to_long_dataframe

logger = logging.getLogger(__name__)

Expand All @@ -47,7 +46,7 @@ def cifti_to_h5(
Path to a csv with demographic info and paths to data
backend : :obj:`str`
Backend to use for storage (``'hdf5'`` or ``'tiledb'``)
output : :obj:`str`
output : :obj:`pathlib.Path`
Output path. For the hdf5 backend, path to an .h5 file;
for the tiledb backend, path to a .tdb directory.
storage_dtype : :obj:`str`
Expand Down Expand Up @@ -77,19 +76,18 @@ def cifti_to_h5(
0 if successful, 1 if failed.
"""
cohort_df = pd.read_csv(cohort_file)
cohort_long = _cohort_to_long_dataframe(cohort_df, scalar_columns=scalar_columns)
output_path = Path(output)
cohort_long = cohort_to_long_dataframe(cohort_df, scalar_columns=scalar_columns)
if cohort_long.empty:
raise ValueError('Cohort file does not contain any scalar entries after normalization.')
scalar_sources = _build_scalar_sources(cohort_long)
scalar_sources = build_scalar_sources(cohort_long)
if not scalar_sources:
raise ValueError('Unable to derive scalar sources from cohort file.')

if backend == 'hdf5':
scalars, last_brain_names = _load_cohort_cifti(cohort_long, s3_workers)
scalars, last_brain_names = load_cohort_cifti(cohort_long, s3_workers)
greyordinate_table, structure_names = brain_names_to_dataframe(last_brain_names)
output_path = cli_utils.prepare_output_parent(output_path)
with h5py.File(output_path, 'w') as h5_file:
output = cli_utils.prepare_output_parent(output)
with h5py.File(output, 'w') as h5_file:
cli_utils.write_table_dataset(
h5_file,
'greyordinates',
Expand All @@ -107,9 +105,9 @@ def cifti_to_h5(
chunk_voxels=chunk_voxels,
target_chunk_mb=target_chunk_mb,
)
return int(not output_path.exists())
return int(not output.exists())

output_path.mkdir(parents=True, exist_ok=True)
output.mkdir(parents=True, exist_ok=True)
if not scalar_sources:
return 0

Expand All @@ -127,7 +125,7 @@ def _process_scalar_job(scalar_name, source_files):

if rows:
cli_utils.write_tiledb_scalar_matrices(
output_path,
output,
{scalar_name: rows},
{scalar_name: source_files},
storage_dtype=storage_dtype,
Expand Down Expand Up @@ -178,5 +176,4 @@ def _parse_cifti_to_h5():
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
add_to_modelarray_args(parser, default_output='greyordinatearray.h5')
add_scalar_columns_arg(parser)
return parser
2 changes: 1 addition & 1 deletion src/modelarrayio/cli/h5_to_mif.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from modelarrayio.cli import utils as cli_utils
from modelarrayio.cli.parser_utils import _is_file, add_from_modelarray_args, add_log_level_arg
from modelarrayio.utils.fixels import mif_to_nifti2, nifti2_to_mif
from modelarrayio.utils.mif import mif_to_nifti2, nifti2_to_mif

logger = logging.getLogger(__name__)

Expand Down
77 changes: 49 additions & 28 deletions src/modelarrayio/cli/mif_to_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import argparse
import logging
from collections import defaultdict
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial
from pathlib import Path

Expand All @@ -14,7 +15,8 @@

from modelarrayio.cli import utils as cli_utils
from modelarrayio.cli.parser_utils import _is_file, add_to_modelarray_args
from modelarrayio.utils.fixels import gather_fixels, mif_to_nifti2
from modelarrayio.utils.mif import gather_fixels, load_cohort_mif
from modelarrayio.utils.misc import cohort_to_long_dataframe

logger = logging.getLogger(__name__)

Expand All @@ -33,6 +35,7 @@ def mif_to_h5(
target_chunk_mb=2.0,
workers=None,
s3_workers=1,
scalar_columns=None,
):
"""Load all fixeldb data and write to an HDF5 or TileDB file.

Expand Down Expand Up @@ -75,25 +78,20 @@ def mif_to_h5(
"""
# gather fixel data
fixel_table, voxel_table = gather_fixels(index_file, directions_file)
output_path = Path(output)

# gather cohort data
cohort_df = pd.read_csv(cohort_file)
cohort_long = cohort_to_long_dataframe(cohort_df, scalar_columns=scalar_columns)
if cohort_long.empty:
raise ValueError('Cohort file does not contain any scalar entries after normalization.')

# upload each cohort's data
scalars = defaultdict(list)
sources_lists = defaultdict(list)
logger.info('Extracting .mif data...')
for row in tqdm(cohort_df.itertuples(index=False), total=cohort_df.shape[0]):
scalar_file = row.source_file
_scalar_img, scalar_data = mif_to_nifti2(scalar_file)
scalars[row.scalar_name].append(scalar_data)
sources_lists[row.scalar_name].append(row.source_file)
scalars, sources_lists = load_cohort_mif(cohort_long, s3_workers)
if not sources_lists:
raise ValueError('Unable to derive scalar sources from cohort file.')

# Write the output
if backend == 'hdf5':
output_path = cli_utils.prepare_output_parent(output_path)
with h5py.File(output_path, 'w') as h5_file:
output = cli_utils.prepare_output_parent(output)
with h5py.File(output, 'w') as h5_file:
cli_utils.write_table_dataset(h5_file, 'fixels', fixel_table)
cli_utils.write_table_dataset(h5_file, 'voxels', voxel_table)
cli_utils.write_hdf5_scalar_matrices(
Expand All @@ -107,19 +105,42 @@ def mif_to_h5(
chunk_voxels=chunk_voxels,
target_chunk_mb=target_chunk_mb,
)
return int(not output_path.exists())

cli_utils.write_tiledb_scalar_matrices(
output_path,
scalars,
sources_lists,
storage_dtype=storage_dtype,
compression=compression,
compression_level=compression_level,
shuffle=shuffle,
chunk_voxels=chunk_voxels,
target_chunk_mb=target_chunk_mb,
)
return int(not output.exists())

output.mkdir(parents=True, exist_ok=True)

scalar_names = list(sources_lists.keys())
worker_count = workers if isinstance(workers, int) and workers > 0 else None
if worker_count is None:
cpu_count = os.cpu_count() or 1
worker_count = min(len(scalar_names), max(1, cpu_count))
else:
worker_count = min(len(scalar_names), worker_count)

def _write_scalar_job(scalar_name):
cli_utils.write_tiledb_scalar_matrices(
output,
{scalar_name: scalars[scalar_name]},
{scalar_name: sources_lists[scalar_name]},
storage_dtype=storage_dtype,
compression=compression,
compression_level=compression_level,
shuffle=shuffle,
chunk_voxels=chunk_voxels,
target_chunk_mb=target_chunk_mb,
)

if worker_count <= 1:
for scalar_name in scalar_names:
_write_scalar_job(scalar_name)
else:
with ThreadPoolExecutor(max_workers=worker_count) as executor:
futures = {
executor.submit(_write_scalar_job, scalar_name): scalar_name
for scalar_name in scalar_names
}
for future in tqdm(as_completed(futures), total=len(futures), desc='TileDB scalars'):
future.result()
return 0


Expand Down
73 changes: 52 additions & 21 deletions src/modelarrayio/cli/nifti_to_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@

import argparse
import logging
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial
from pathlib import Path

import h5py
import nibabel as nb
import numpy as np
import pandas as pd
from tqdm import tqdm

from modelarrayio.cli import utils as cli_utils
from modelarrayio.cli.parser_utils import _is_file, add_to_modelarray_args
from modelarrayio.utils.voxels import _load_cohort_voxels
from modelarrayio.utils.misc import cohort_to_long_dataframe
from modelarrayio.utils.nifti import load_cohort_voxels

logger = logging.getLogger(__name__)

Expand All @@ -32,6 +36,7 @@ def nifti_to_h5(
target_chunk_mb=2.0,
workers=None,
s3_workers=1,
scalar_columns=None,
):
"""Load all volume data and write to an HDF5 or TileDB file.

Expand All @@ -43,7 +48,7 @@ def nifti_to_h5(
Path to a CSV with demographic info and paths to data.
backend : :obj:`str`
Storage backend (``'hdf5'`` or ``'tiledb'``).
output : :obj:`str`
output : :obj:`pathlib.Path`
Output path. For the hdf5 backend, path to an .h5 file;
for the tiledb backend, path to a .tdb directory.
storage_dtype : :obj:`str`
Expand All @@ -65,13 +70,14 @@ def nifti_to_h5(
s3_workers : :obj:`int`
Number of parallel workers for S3 downloads. Default 1.
"""
cohort_df = pd.read_csv(cohort_file)
output_path = Path(output)

group_mask_img = nb.load(group_mask_file)
group_mask_matrix = group_mask_img.get_fdata() > 0
voxel_coords = np.column_stack(np.nonzero(group_mask_matrix))

cohort_df = pd.read_csv(cohort_file)
cohort_long = cohort_to_long_dataframe(cohort_df, scalar_columns=scalar_columns)
if cohort_long.empty:
raise ValueError('Cohort file does not contain any scalar entries after normalization.')
voxel_table = pd.DataFrame(
{
'voxel_id': np.arange(voxel_coords.shape[0]),
Expand All @@ -82,11 +88,13 @@ def nifti_to_h5(
)

logger.info('Extracting NIfTI data...')
scalars, sources_lists = _load_cohort_voxels(cohort_df, group_mask_matrix, s3_workers)
scalars, sources_lists = load_cohort_voxels(cohort_long, group_mask_matrix, s3_workers)
if not sources_lists:
raise ValueError('Unable to derive scalar sources from cohort file.')

if backend == 'hdf5':
output_path = cli_utils.prepare_output_parent(output_path)
with h5py.File(output_path, 'w') as h5_file:
output = cli_utils.prepare_output_parent(output)
with h5py.File(output, 'w') as h5_file:
cli_utils.write_table_dataset(h5_file, 'voxels', voxel_table)
cli_utils.write_hdf5_scalar_matrices(
h5_file,
Expand All @@ -99,19 +107,42 @@ def nifti_to_h5(
chunk_voxels=chunk_voxels,
target_chunk_mb=target_chunk_mb,
)
return int(not output_path.exists())

cli_utils.write_tiledb_scalar_matrices(
output_path,
scalars,
sources_lists,
storage_dtype=storage_dtype,
compression=compression,
compression_level=compression_level,
shuffle=shuffle,
chunk_voxels=chunk_voxels,
target_chunk_mb=target_chunk_mb,
)
return int(not output.exists())

output.mkdir(parents=True, exist_ok=True)

scalar_names = list(sources_lists.keys())
worker_count = workers if isinstance(workers, int) and workers > 0 else None
if worker_count is None:
cpu_count = os.cpu_count() or 1
worker_count = min(len(scalar_names), max(1, cpu_count))
else:
worker_count = min(len(scalar_names), worker_count)

def _write_scalar_job(scalar_name):
cli_utils.write_tiledb_scalar_matrices(
output,
{scalar_name: scalars[scalar_name]},
{scalar_name: sources_lists[scalar_name]},
storage_dtype=storage_dtype,
compression=compression,
compression_level=compression_level,
shuffle=shuffle,
chunk_voxels=chunk_voxels,
target_chunk_mb=target_chunk_mb,
)

if worker_count <= 1:
for scalar_name in scalar_names:
_write_scalar_job(scalar_name)
else:
with ThreadPoolExecutor(max_workers=worker_count) as executor:
futures = {
executor.submit(_write_scalar_job, scalar_name): scalar_name
for scalar_name in scalar_names
}
for future in tqdm(as_completed(futures), total=len(futures), desc='TileDB scalars'):
future.result()
return 0


Expand Down
23 changes: 10 additions & 13 deletions src/modelarrayio/cli/parser_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ def add_to_modelarray_args(parser, default_output='output.h5'):
'for the tiledb backend, path to a .tdb directory.'
),
default=default_output,
type=Path,
)
parser.add_argument(
'--scalar-columns',
'--scalar_columns',
nargs='+',
help=(
'Column names containing scalar file paths when the cohort table is in wide format. '
'If omitted, the cohort file must include "scalar_name" and "source_file" columns.'
),
)
parser.add_argument(
'--backend',
Expand Down Expand Up @@ -110,19 +120,6 @@ def add_to_modelarray_args(parser, default_output='output.h5'):
return parser


def add_scalar_columns_arg(parser):
    """Register the ``--scalar-columns`` option on *parser* and return the parser.

    The option accepts one or more column names (``nargs='+'``) identifying
    which cohort-table columns hold scalar file paths when the table is in
    wide format. Both ``--scalar-columns`` and ``--scalar_columns`` spellings
    are accepted; the value lands on ``args.scalar_columns``.
    """
    help_text = (
        'Column names containing scalar file paths when the cohort table is in wide format. '
        "If omitted, the cohort file must include 'scalar_name' and 'source_file' columns."
    )
    parser.add_argument('--scalar-columns', '--scalar_columns', nargs='+', help=help_text)
    return parser


def add_log_level_arg(parser):
parser.add_argument(
'--log-level',
Expand Down
Loading
Loading