diff --git a/neo/rawio/neuralynxrawio/ncssections.py b/neo/rawio/neuralynxrawio/ncssections.py index 135f04d74..016f49c89 100644 --- a/neo/rawio/neuralynxrawio/ncssections.py +++ b/neo/rawio/neuralynxrawio/ncssections.py @@ -1,4 +1,5 @@ import math +import numpy as np class NcsSections: @@ -7,7 +8,7 @@ class NcsSections: Methods of NcsSectionsFactory perform parsing of this information from an Ncs file and produce these where the sections are discontiguous in time and in temporal order. - TODO: This class will likely need __eq__, __ne__, and __hash__ to be useful in + TODO: This class will likely need __ne__ to be useful in more sophisticated segment construction algorithms. """ @@ -16,6 +17,16 @@ def __init__(self): self.sampFreqUsed = 0 # actual sampling frequency of samples self.microsPerSampUsed = 0 # microseconds per sample + def __eq__(self, other): + samp_eq = self.sampFreqUsed == other.sampFreqUsed + micros_eq = self.microsPerSampUsed == other.microsPerSampUsed + sects_eq = self.sects == other.sects + return (samp_eq and micros_eq and sects_eq) + + def __hash__(self): + return (f'{self.sampFreqUsed};{self.microsPerSampUsed};' + f'{[s.__hash__() for s in self.sects]}').__hash__() + class NcsSection: """ @@ -37,11 +48,23 @@ def __init__(self): self.endTime = -1 # end time of last record, that is, the end time of the last # sampling period contained in the last record of the section - def __init__(self, sb, st, eb, et): + def __init__(self, sb, st, eb, et, ns): self.startRec = sb self.startTime = st self.endRec = eb self.endTime = et + self.n_samples = ns + + def __eq__(self, other): + return (self.startRec == other.startRec + and self.startTime == other.startTime + and self.endRec == other.endRec + and self.endTime == other.endTime + and self.n_samples == other.n_samples) + + def __hash__(self): + s = f'{self.startRec};{self.startTime};{self.endRec};{self.endTime};{self.n_samples}' + return s.__hash__() def before_time(self, rhb): """ @@ -124,32 +147,38 @@ def _parseGivenActualFrequency(ncsMemMap, ncsSects, chanNum, reqFreq, blkOnePred NcsSections object with block locations marked """ startBlockPredTime = blkOnePredTime - blkLen = 0 + blk_len = 0 curBlock = ncsSects.sects[0] for recn in range(1, ncsMemMap.shape[0]): - if ncsMemMap['channel_id'][recn] != chanNum or \ - ncsMemMap['sample_rate'][recn] != reqFreq: + timestamp = ncsMemMap['timestamp'][recn] + channel_id = ncsMemMap['channel_id'][recn] + sample_rate = ncsMemMap['sample_rate'][recn] + nb_valid = ncsMemMap['nb_valid'][recn] + + if channel_id != chanNum or sample_rate != reqFreq: raise IOError('Channel number or sampling frequency changed in ' + 'records within file') predTime = NcsSectionsFactory.calc_sample_time(ncsSects.sampFreqUsed, - startBlockPredTime, blkLen) - ts = ncsMemMap['timestamp'][recn] - nValidSamps = ncsMemMap['nb_valid'][recn] - if ts != predTime: + startBlockPredTime, blk_len) + nValidSamps = nb_valid + if timestamp != predTime: curBlock.endRec = recn - 1 curBlock.endTime = predTime - curBlock = NcsSection(recn, ts, -1, -1) + curBlock.n_samples = blk_len + curBlock = NcsSection(recn, timestamp, -1, -1, -1) ncsSects.sects.append(curBlock) startBlockPredTime = NcsSectionsFactory.calc_sample_time( - ncsSects.sampFreqUsed, ts, nValidSamps) - blkLen = 0 + ncsSects.sampFreqUsed, + timestamp, + nValidSamps) + blk_len = 0 else: - blkLen += nValidSamps + blk_len += nValidSamps curBlock.endRec = ncsMemMap.shape[0] - 1 endTime = NcsSectionsFactory.calc_sample_time(ncsSects.sampFreqUsed, startBlockPredTime, - blkLen) + blk_len) curBlock.endTime = endTime return ncsSects @@ -199,7 +228,8 @@ def _buildGivenActualFrequency(ncsMemMap, actualSampFreq, reqFreq): ncsMemMap['sample_rate'][lastBlkI] == reqFreq and \ lts == predLastBlockStartTime: lastBlkEndTime = NcsSectionsFactory.calc_sample_time(actualSampFreq, lts, lnb) - curBlock = NcsSection(0, ts0, lastBlkI, lastBlkEndTime) + n_samples = NcsSection._RECORD_SIZE * lastBlkI + curBlock = NcsSection(0, ts0, lastBlkI, lastBlkEndTime, n_samples) nb.sects.append(curBlock) return nb @@ -207,7 +237,7 @@ def _buildGivenActualFrequency(ncsMemMap, actualSampFreq, reqFreq): # otherwise need to scan looking for breaks else: blkOnePredTime = NcsSectionsFactory.calc_sample_time(actualSampFreq, ts0, nb0) - curBlock = NcsSection(0, ts0, -1, -1) + curBlock = NcsSection(0, ts0, -1, -1, -1) nb.sects.append(curBlock) return NcsSectionsFactory._parseGivenActualFrequency(ncsMemMap, nb, chanNum, reqFreq, blkOnePredTime) @@ -233,60 +263,72 @@ def _parseForMaxGap(ncsMemMap, ncsSects, maxGapLen): largest block """ - # track frequency of each block and use estimate with longest block - maxBlkLen = 0 - maxBlkFreqEstimate = 0 - - # Parse the record sequence, finding blocks of continuous time with no more than - # maxGapLength and same channel number chanNum = ncsMemMap['channel_id'][0] - - startBlockTime = ncsMemMap['timestamp'][0] - blkLen = ncsMemMap['nb_valid'][0] - lastRecTime = startBlockTime - lastRecNumSamps = blkLen recFreq = ncsMemMap['sample_rate'][0] - curBlock = NcsSection(0, startBlockTime, -1, -1) - ncsSects.sects.append(curBlock) - for recn in range(1, ncsMemMap.shape[0]): - if ncsMemMap['channel_id'][recn] != chanNum or \ - ncsMemMap['sample_rate'][recn] != recFreq: - raise IOError('Channel number or sampling frequency changed in ' + - 'records within file') - predTime = NcsSectionsFactory.calc_sample_time(ncsSects.sampFreqUsed, lastRecTime, - lastRecNumSamps) - ts = ncsMemMap['timestamp'][recn] - nb = ncsMemMap['nb_valid'][recn] - if abs(ts - predTime) > maxGapLen: - curBlock.endRec = recn - 1 - curBlock.endTime = predTime - curBlock = NcsSection(recn, ts, -1, -1) - ncsSects.sects.append(curBlock) - if blkLen > maxBlkLen: - maxBlkLen = blkLen - maxBlkFreqEstimate = (blkLen - lastRecNumSamps) * 1e6 / \ - (lastRecTime - startBlockTime) - startBlockTime = ts - blkLen = nb - else: - blkLen += nb - lastRecTime = ts - lastRecNumSamps = nb - - if blkLen > maxBlkLen: - maxBlkFreqEstimate = (blkLen - lastRecNumSamps) * 1e6 / \ - (lastRecTime - startBlockTime) - - curBlock.endRec = ncsMemMap.shape[0] - 1 - endTime = NcsSectionsFactory.calc_sample_time(ncsSects.sampFreqUsed, lastRecTime, - lastRecNumSamps) - curBlock.endTime = endTime + # check for consistent channel_ids and sampling rates + ncsMemMap['channel_id'] + if not (ncsMemMap['channel_id'] == chanNum).all(): + raise IOError('Channel number changed in records within file') + + if not all(ncsMemMap['sample_rate'] == recFreq): + raise IOError('Sampling frequency changed in records within file') + + # find most frequent number of samples + exp_nb_valid = np.argmax(np.bincount(ncsMemMap['nb_valid'])) + # detect records with incomplete number of samples + gap_rec_ids = list(np.where(ncsMemMap['nb_valid'] != exp_nb_valid)[0]) + + rec_duration = 1e6 / ncsSects.sampFreqUsed * ncsMemMap['nb_valid'] + pred_times = np.rint(ncsMemMap['timestamp'] + rec_duration).astype(np.int64) + max_pred_times = pred_times + maxGapLen + # data records that start later than the predicted time (including the + # maximal accepted gap length) are considered delayed and a gap is + # registered. + delayed_recs = list(np.where(max_pred_times[:-1] < ncsMemMap['timestamp'][1:])[0]) + gap_rec_ids.extend(delayed_recs) + + # cleaning extracted gap ids + # last record can not be the beginning of a gap + last_rec_id = len(ncsMemMap['timestamp']) - 1 + if last_rec_id in gap_rec_ids: + gap_rec_ids.remove(last_rec_id) + + # gap ids can only be listed once + gap_rec_ids = sorted(set(gap_rec_ids)) + + # create recording segments from identified gaps + ncsSects.sects.append(NcsSection(0, ncsMemMap['timestamp'][0], -1, -1, -1)) + for gap_rec_id in gap_rec_ids: + curr_sec = ncsSects.sects[-1] + curr_sec.endRec = gap_rec_id + curr_sec.endTime = pred_times[gap_rec_id] + n_samples = np.sum(ncsMemMap['nb_valid'][curr_sec.startRec:gap_rec_id + 1]) + curr_sec.n_samples = n_samples + + next_sec = NcsSection(gap_rec_id + 1, + ncsMemMap['timestamp'][gap_rec_id + 1], -1, -1, -1) + ncsSects.sects.append(next_sec) + + curr_sec = ncsSects.sects[-1] + curr_sec.endRec = len(ncsMemMap['timestamp']) - 1 + curr_sec.endTime = pred_times[-1] + n_samples = np.sum(ncsMemMap['nb_valid'][curr_sec.startRec:]) + curr_sec.n_samples = n_samples + + # calculate the estimated frequency of the block with the most samples + max_blk_idx = np.argmax([bl.endRec - bl.startRec for bl in ncsSects.sects]) + max_blk = ncsSects.sects[max_blk_idx] + + maxBlkFreqEstimate = (max_blk.n_samples - ncsMemMap['nb_valid'][max_blk.endRec]) * 1e6 / \ + (ncsMemMap['timestamp'][max_blk.endRec] - max_blk.startTime) ncsSects.sampFreqUsed = maxBlkFreqEstimate ncsSects.microsPerSampUsed = NcsSectionsFactory.get_micros_per_samp_for_freq( maxBlkFreqEstimate) - + # free memory that is unnecessarily occupied by the memmap + # (see https://github.com/numpy/numpy/issues/19340) + del ncsMemMap return ncsSects @staticmethod @@ -325,7 +367,7 @@ def _buildForMaxGap(ncsMemMap, nomFreq): freqInFile = math.floor(nomFreq) if lts - predLastBlockStartTime == 0 and lcid == chanNum and lsr == freqInFile: endTime = NcsSectionsFactory.calc_sample_time(nomFreq, lts, lnb) - curBlock = NcsSection(0, ts0, lastBlkI, endTime) + curBlock = NcsSection(0, ts0, lastBlkI, endTime, numSampsForPred) nb.sects.append(curBlock) nb.sampFreqUsed = numSampsForPred / (lts - ts0) * 1e6 nb.microsPerSampUsed = NcsSectionsFactory.get_micros_per_samp_for_freq(nb.sampFreqUsed) diff --git a/neo/rawio/neuralynxrawio/neuralynxrawio.py b/neo/rawio/neuralynxrawio/neuralynxrawio.py index 4d8a37341..1ddd47de9 100644 --- a/neo/rawio/neuralynxrawio/neuralynxrawio.py +++ b/neo/rawio/neuralynxrawio/neuralynxrawio.py @@ -47,6 +47,8 @@ import numpy as np import os +import pathlib +import copy from collections import (namedtuple, OrderedDict) from neo.rawio.neuralynxrawio.ncssections import (NcsSection, NcsSectionsFactory) @@ -110,6 +112,9 @@ def _source_name(self): else: return self.dirname + # from memory_profiler import profile + # + # @profile() def _parse_header(self): stream_channels = [] @@ -139,6 +144,11 @@ def _parse_header(self): filenames = sorted(os.listdir(self.dirname)) dirname = self.dirname else: + if not os.path.isfile(self.filename): + raise ValueError(f'Provided Filename is not a file: ' + f'{self.filename}. If you want to provide a ' + f'directory use the `dirname` keyword') + dirname, fname = os.path.split(self.filename) filenames = [fname] @@ -209,15 +219,7 @@ def _parse_header(self): 'Several nse or ntt files have the same unit_id!!!' self.nse_ntt_filenames[chan_uid] = filename - dtype = get_nse_or_ntt_dtype(info, ext) - - if os.path.getsize(filename) <= NlxHeader.HEADER_SIZE: - self._empty_nse_ntt.append(filename) - data = np.zeros((0,), dtype=dtype) - else: - data = np.memmap(filename, dtype=dtype, mode='r', - offset=NlxHeader.HEADER_SIZE) - + data = self._get_file_map(filename) self._spike_memmap[chan_uid] = data unit_ids = np.unique(data['unit_id']) @@ -249,8 +251,7 @@ def _parse_header(self): data = np.zeros((0,), dtype=nev_dtype) internal_ids = [] else: - data = np.memmap(filename, dtype=nev_dtype, mode='r', - offset=NlxHeader.HEADER_SIZE) + data = self._get_file_map(filename) internal_ids = np.unique(data[['event_id', 'ttl_input']]).tolist() for internal_event_id in internal_ids: if internal_event_id not in self.internal_event_ids: @@ -378,6 +379,37 @@ def _parse_header(self): # ~ ev_ann['digital_marker'] = # ~ ev_ann['analog_marker'] = + def _get_file_map(self, filename): + """ + Create memory maps when needed + see also https://github.com/numpy/numpy/issues/19340 + """ + filename = pathlib.Path(filename) + suffix = filename.suffix.lower()[1:] + + if suffix == 'ncs': + return np.memmap(filename, dtype=self._ncs_dtype, mode='r', + offset=NlxHeader.HEADER_SIZE) + + elif suffix in ['nse', 'ntt']: + info = NlxHeader(filename) + dtype = get_nse_or_ntt_dtype(info, suffix) + + # return empty map if file does not contain data + if os.path.getsize(filename) <= NlxHeader.HEADER_SIZE: + self._empty_nse_ntt.append(filename) + return np.zeros((0,), dtype=dtype) + + return np.memmap(filename, dtype=dtype, mode='r', + offset=NlxHeader.HEADER_SIZE) + + elif suffix == 'nev': + return np.memmap(filename, dtype=nev_dtype, mode='r', + offset=NlxHeader.HEADER_SIZE) + + else: + raise ValueError(f'Unknown file suffix {suffix}') + # Accessors for segment times which are offset by appropriate global start time def _segment_t_start(self, block_index, seg_index): return self._seg_t_starts[seg_index] - self.global_t_start @@ -565,16 +597,15 @@ def scan_ncs_files(self, ncs_filenames): chanSectMap = dict() for chan_uid, ncs_filename in self.ncs_filenames.items(): - data = np.memmap(ncs_filename, dtype=self._ncs_dtype, mode='r', - offset=NlxHeader.HEADER_SIZE) + data = self._get_file_map(ncs_filename) nlxHeader = NlxHeader(ncs_filename) if not chanSectMap or (chanSectMap and not NcsSectionsFactory._verifySectionsStructure(data, lastNcsSections)): lastNcsSections = NcsSectionsFactory.build_for_ncs_file(data, nlxHeader) - - chanSectMap[chan_uid] = [lastNcsSections, nlxHeader, data] + chanSectMap[chan_uid] = [lastNcsSections, nlxHeader, ncs_filename] + del data # Construct an inverse dictionary from NcsSections to list of associated chan_uids revSectMap = dict() @@ -584,8 +615,8 @@ def scan_ncs_files(self, ncs_filenames): # If there is only one NcsSections structure in the set of ncs files, there should only # be one entry. Otherwise this is presently unsupported. if len(revSectMap) > 1: - raise IOError('ncs files have {} different sections structures. Unsupported.'.format( - len(revSectMap))) + raise IOError(f'ncs files have {len(revSectMap)} different sections ' + f'structures. Unsupported configuration.') seg_time_limits = SegmentTimeLimits(nb_segment=len(lastNcsSections.sects), t_start=[], t_stop=[], length=[], @@ -595,7 +626,7 @@ def scan_ncs_files(self, ncs_filenames): # create segment with subdata block/t_start/t_stop/length for each channel for i, fileEntry in enumerate(self.ncs_filenames.items()): chan_uid = fileEntry[0] - data = chanSectMap[chan_uid][2] + data = self._get_file_map(chanSectMap[chan_uid][2]) # create a memmap for each record section of the current file curSects = chanSectMap[chan_uid][0] diff --git a/neo/test/iotest/test_nixio_fr.py b/neo/test/iotest/test_nixio_fr.py index 9adc3073b..d57b18ae2 100644 --- a/neo/test/iotest/test_nixio_fr.py +++ b/neo/test/iotest/test_nixio_fr.py @@ -112,10 +112,10 @@ def test_annotations(self): sp = SpikeTrain([3, 4, 5]* s, t_stop=10.0) sp.annotations['railway'] = 'hello train' ev = Event(np.arange(0, 30, 10)*pq.Hz, - labels=np.array(['trig0', 'trig1', 'trig2'], dtype='S')) + labels=np.array(['trig0', 'trig1', 'trig2'], dtype='U')) ev.annotations['venue'] = 'hello event' ev2 = Event(np.arange(0, 30, 10) * pq.Hz, - labels=np.array(['trig0', 'trig1', 'trig2'], dtype='S')) + labels=np.array(['trig0', 'trig1', 'trig2'], dtype='U')) ev2.annotations['evven'] = 'hello ev' seg.spiketrains.append(sp) seg.events.append(ev) diff --git a/neo/test/rawiotest/test_neuralynxrawio.py b/neo/test/rawiotest/test_neuralynxrawio.py index e2895a25c..c18268ad8 100644 --- a/neo/test/rawiotest/test_neuralynxrawio.py +++ b/neo/test/rawiotest/test_neuralynxrawio.py @@ -4,7 +4,8 @@ from neo.rawio.neuralynxrawio.neuralynxrawio import NeuralynxRawIO from neo.rawio.neuralynxrawio.nlxheader import NlxHeader -from neo.rawio.neuralynxrawio.ncssections import (NcsSections, NcsSectionsFactory) +from neo.rawio.neuralynxrawio.ncssections import (NcsSection, NcsSections, + NcsSectionsFactory) from neo.test.rawiotest.common_rawio_test import BaseTestRawIO import logging @@ -278,5 +279,42 @@ def test_block_verify(self): self.assertTrue(NcsSectionsFactory._verifySectionsStructure(data1, nb1)) +class TestNcsSections(TestNeuralynxRawIO, unittest.TestCase): + """ + Test building NcsBlocks for files of different revisions. + """ + entities_to_test = [] + + def test_equality(self): + ns0 = NcsSections() + ns1 = NcsSections() + + ns0.microsPerSampUsed = 1 + ns1.microsPerSampUsed = 1 + ns0.sampFreqUsed = 300 + ns1.sampFreqUsed = 300 + + self.assertEqual(ns0, ns1) + + # add sections + ns0.sects = [NcsSection(0, 0, 100, 100, 10)] + ns1.sects = [NcsSection(0, 0, 100, 100, 10)] + + self.assertEqual(ns0, ns1) + + # check inequality for different attributes + # different number of sections + ns0.sects.append(NcsSection(0, 0, 100, 100, 10)) + self.assertNotEqual(ns0, ns1) + + # different section attributes + ns0.sects = [NcsSection(0, 0, 200, 200, 10)] + self.assertNotEqual(ns0, ns1) + + # different attributes + ns0.sampFreqUsed = 400 + self.assertNotEqual(ns0, ns1) + + if __name__ == "__main__": unittest.main()