From 10f0cc0fc57c32ae5eb7f87bcad9c17fc3cff89b Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Wed, 27 Feb 2019 00:11:22 +0100 Subject: [PATCH 01/15] Added the possibility to merge multiple spiketrains into an existing one with spiketrain.merge while preserving the previous functionality for merging a pair of spiketrains. --- neo/core/baseneo.py | 30 ++++++------ neo/core/spiketrain.py | 73 ++++++++++++++++------------ neo/test/coretest/test_spiketrain.py | 4 +- 3 files changed, 57 insertions(+), 50 deletions(-) diff --git a/neo/core/baseneo.py b/neo/core/baseneo.py index fe44905e6..aa739df3a 100644 --- a/neo/core/baseneo.py +++ b/neo/core/baseneo.py @@ -90,7 +90,7 @@ def merge_annotation(a, b): return a -def merge_annotations(A, B): +def merge_annotations(A, *Bs): """ Merge two sets of annotations. @@ -102,21 +102,19 @@ def merge_annotations(A, B): For strings: concatenate with ';' Otherwise: warn if the annotations are not equal """ - merged = {} - for name in A: - if name in B: - try: - merged[name] = merge_annotation(A[name], B[name]) - except BaseException as exc: - # exc.args += ('key %s' % name,) - # raise - merged[name] = "MERGE CONFLICT" # temporary hack - else: - merged[name] = A[name] - for name in B: - if name not in merged: - merged[name] = B[name] - logger.debug("Merging annotations: A=%s B=%s merged=%s", A, B, merged) + merged = A.copy() + for B in Bs: + for name in B: + if name not in merged: + merged[name] = B[name] + else: + try: + merged[name] = merge_annotation(merged[name], B[name]) + except BaseException as exc: + # exc.args += ('key %s' % name,) + # raise + merged[name] = "MERGE CONFLICT" # temporary hack + logger.debug("Merging annotations: A=%s Bs=%s merged=%s", A, Bs, merged) return merged diff --git a/neo/core/spiketrain.py b/neo/core/spiketrain.py index 0a59b6ecc..f28d801cd 100644 --- a/neo/core/spiketrain.py +++ b/neo/core/spiketrain.py @@ -618,45 +618,54 @@ def time_slice(self, t_start, t_stop): return new_st - def merge(self, other): + def merge(self, *others): ''' - Merge another :class:`SpikeTrain` into this one. + Merge other :class:`SpikeTrain` objects into this one. The times of the :class:`SpikeTrain` objects combined in one array and sorted. - If the attributes of the two :class:`SpikeTrain` are not + If the attributes of the :class:`SpikeTrain` objects are not compatible, an Exception is raised. ''' - if self.sampling_rate != other.sampling_rate: - raise MergeError("Cannot merge, different sampling rates") - if self.t_start != other.t_start: - raise MergeError("Cannot merge, different t_start") - if self.t_stop != other.t_stop: - raise MemoryError("Cannot merge, different t_stop") - if self.left_sweep != other.left_sweep: - raise MemoryError("Cannot merge, different left_sweep") - if self.segment != other.segment: - raise MergeError("Cannot merge these two signals as they belong to" - " different segments.") + # TODO: unittests for merging multiple spiketrains + # TODO: maybe allow for lists of spiketrains as input? would be convenient in many use cases + # TODO: improve! better way to loop over self AND others? + for other in others: + if self.sampling_rate != other.sampling_rate: + raise MergeError("Cannot merge, different sampling rates") + if self.t_start != other.t_start: + raise MergeError("Cannot merge, different t_start") + if self.t_stop != other.t_stop: + raise MemoryError("Cannot merge, different t_stop") + if self.left_sweep != other.left_sweep: + raise MemoryError("Cannot merge, different left_sweep") + if self.segment != other.segment: + raise MergeError("Cannot merge these signals as they belong to" + " different segments.") + if other.units != self.units: + other = other.rescale(self.units) if hasattr(self, "lazy_shape"): - if hasattr(other, "lazy_shape"): - merged_lazy_shape = (self.lazy_shape[0] + other.lazy_shape[0]) - else: - raise MergeError("Cannot merge a lazy object with a real" - " object.") - if other.units != self.units: - other = other.rescale(self.units) - wfs = [self.waveforms is not None, other.waveforms is not None] + merged_lazy_shape = self.lazy_shape[0] + for other in others: + if hasattr(other, "lazy_shape"): + merged_lazy_shape[0] += other.lazy_shape[0] + else: + raise MergeError("Cannot merge a lazy object with a real" + " object.") + all_spiketrains = [self] + all_spiketrains.extend(others) + wfs = [st.waveforms is not None for st in all_spiketrains] if any(wfs) and not all(wfs): raise MergeError("Cannot merge signal with waveform and signal " "without waveform.") - stack = np.concatenate((np.asarray(self), np.asarray(other))) + stack = np.concatenate([np.asarray(st) for st in all_spiketrains]) sorting = np.argsort(stack) stack = stack[sorting] + kwargs = {} - kwargs['array_annotations'] = self._merge_array_annotations(other, sorting=sorting) + kwargs['array_annotations'] = self._merge_array_annotations(others, sorting=sorting) for name in ("name", "description", "file_origin"): attr_self = getattr(self, name) @@ -672,7 +681,7 @@ def merge(self, other): t_start=self.t_start, t_stop=self.t_stop, sampling_rate=self.sampling_rate, left_sweep=self.left_sweep, **kwargs) if all(wfs): - wfs_stack = np.vstack((self.waveforms, other.waveforms)) + wfs_stack = np.vstack([st.waveforms for st in all_spiketrains]) wfs_stack = wfs_stack[sorting] train.waveforms = wfs_stack train.segment = self.segment @@ -683,12 +692,12 @@ def merge(self, other): train.lazy_shape = merged_lazy_shape return train - def _merge_array_annotations(self, other, sorting=None): + def _merge_array_annotations(self, others, sorting=None): ''' - Merges array annotations of 2 different objects. + Merges array annotations of multiple different objects. The merge happens in such a way that the result fits the merged data - In general this means concatenating the arrays from the 2 objects. - If an annotation is only present in one of the objects, it will be omitted. + In general this means concatenating the arrays from the objects. + If an annotation is not present in one of the objects, it will be omitted. Apart from that the array_annotations need to be sorted according to the sorting of the spikes. :return Merged array_annotations @@ -704,7 +713,7 @@ def _merge_array_annotations(self, other, sorting=None): for key in keys: try: self_ann = copy.deepcopy(self.array_annotations[key]) - other_ann = copy.deepcopy(other.array_annotations[key]) + other_ann = np.concatenate([copy.deepcopy(other.array_annotations[key]) for other in others]) if isinstance(self_ann, pq.Quantity): other_ann.rescale(self_ann.units) arr_ann = np.concatenate([self_ann, other_ann]) * self_ann.units @@ -717,13 +726,13 @@ def _merge_array_annotations(self, other, sorting=None): omitted_keys_self.append(key) continue - omitted_keys_other = [key for key in other.array_annotations if + omitted_keys_other = [key for key in np.unique([key for other in others for key in other.array_annotations]) if key not in self.array_annotations] if omitted_keys_self or omitted_keys_other: warnings.warn("The following array annotations were omitted, because they were only " "present in one of the merged objects: {} from the one that was merged " - "into and {} from the one that was merged into the other" + "into and {} from the ones that were merged into it." "".format(omitted_keys_self, omitted_keys_other), UserWarning) return merged_array_annotations diff --git a/neo/test/coretest/test_spiketrain.py b/neo/test/coretest/test_spiketrain.py index dc754109f..7b949bbf4 100644 --- a/neo/test/coretest/test_spiketrain.py +++ b/neo/test/coretest/test_spiketrain.py @@ -1138,8 +1138,8 @@ def test_merge_typical(self): "omitted, because they were only present" " in one of the merged objects: " "['label'] from the one that was merged " - "into and ['label2'] from the one that " - "was merged into the other") + "into and ['label2'] from the ones that " + "were merged into it.") assert_neo_object_is_compliant(result) From e684346901df10db254f0d529eaf74987b66f6c5 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Thu, 28 Feb 2019 13:27:09 +0100 Subject: [PATCH 02/15] Fixed typos resulting in undesired error types --- neo/core/spiketrain.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/neo/core/spiketrain.py b/neo/core/spiketrain.py index f28d801cd..7e2d13468 100644 --- a/neo/core/spiketrain.py +++ b/neo/core/spiketrain.py @@ -637,9 +637,9 @@ def merge(self, *others): if self.t_start != other.t_start: raise MergeError("Cannot merge, different t_start") if self.t_stop != other.t_stop: - raise MemoryError("Cannot merge, different t_stop") + raise MergeError("Cannot merge, different t_stop") if self.left_sweep != other.left_sweep: - raise MemoryError("Cannot merge, different left_sweep") + raise MergeError("Cannot merge, different left_sweep") if self.segment != other.segment: raise MergeError("Cannot merge these signals as they belong to" " different segments.") From 945e9c3262b99f7583dc5b07ff2af6c0f16635e5 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Thu, 28 Feb 2019 13:28:07 +0100 Subject: [PATCH 03/15] Fixed waveform units being reset to dimensionless --- neo/core/spiketrain.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/neo/core/spiketrain.py b/neo/core/spiketrain.py index 7e2d13468..e60ceea8e 100644 --- a/neo/core/spiketrain.py +++ b/neo/core/spiketrain.py @@ -681,8 +681,8 @@ def merge(self, *others): t_start=self.t_start, t_stop=self.t_stop, sampling_rate=self.sampling_rate, left_sweep=self.left_sweep, **kwargs) if all(wfs): - wfs_stack = np.vstack([st.waveforms for st in all_spiketrains]) - wfs_stack = wfs_stack[sorting] + wfs_stack = np.vstack([st.waveforms.rescale(self.waveforms.units) for st in all_spiketrains]) + wfs_stack = wfs_stack[sorting] * self.waveforms.units train.waveforms = wfs_stack train.segment = self.segment if train.segment is not None: From 7f30f88f2cfc18cf9852aab6517124ea049b9131 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Thu, 28 Feb 2019 13:29:26 +0100 Subject: [PATCH 04/15] Added functionality to merge multiple BaseNeo objects and to merge annotations of multiple objects --- neo/core/baseneo.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/neo/core/baseneo.py b/neo/core/baseneo.py index aa739df3a..5352b2b20 100644 --- a/neo/core/baseneo.py +++ b/neo/core/baseneo.py @@ -365,7 +365,7 @@ def _all_attrs(self): """ return self._necessary_attrs + self._recommended_attrs - def merge_annotations(self, other): + def merge_annotations(self, *others): """ Merge annotations from the other object into this one. @@ -377,14 +377,15 @@ def merge_annotations(self, other): For strings: concatenate with ';' Otherwise: fail if the annotations are not equal """ + other_annotations = [other.annotations for other in others] merged_annotations = merge_annotations(self.annotations, - other.annotations) + *other_annotations) self.annotations.update(merged_annotations) - def merge(self, other): + def merge(self, *others): """ Merge the contents of another object into this one. See :meth:`merge_annotations` for details of the merge operation. """ - self.merge_annotations(other) + self.merge_annotations(*others) From 13859c1755db2a49e1dfa4863415a6fdb8da72d2 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Thu, 28 Feb 2019 13:30:42 +0100 Subject: [PATCH 05/15] Added unittests for merging multiple BaseNeo objects and for merging annotations of multiple neo objects --- neo/test/coretest/test_base.py | 190 +++++++++++++++++++++++++++++++++ 1 file changed, 190 insertions(+) diff --git a/neo/test/coretest/test_base.py b/neo/test/coretest/test_base.py index 0544537f8..49f1e5a3c 100644 --- a/neo/test/coretest/test_base.py +++ b/neo/test/coretest/test_base.py @@ -139,10 +139,13 @@ class Test_BaseNeo_merge_annotations_merge(unittest.TestCase): def setUp(self): self.name1 = 'a base 1' self.name2 = 'a base 2' + self.name3 = 'a base 3' self.description1 = 'this is a test 1' self.description2 = 'this is a test 2' + self.description3 = 'this is a test 3' self.base1 = BaseNeo(name=self.name1, description=self.description1) self.base2 = BaseNeo(name=self.name2, description=self.description2) + self.base3 = BaseNeo(name=self.name3, description=self.description3) def test_merge_annotations__dict(self): self.base1.annotations = {'val0': 'val0', 'val1': 1, @@ -184,6 +187,57 @@ def test_merge_annotations__dict(self): self.assertEqual(self.description1, self.base1.description) self.assertEqual(self.description2, self.base2.description) + def test_merge_multiple_annotations__dict(self): + self.base1.annotations = {'val0': 'val0', 'val1': 1, + 'val2': 2.2, 'val3': 'test1', + 'val4': [.4], 'val5': {0: 0, 1: {0: 0}}, + 'val6': np.array([0, 1, 2])} + self.base2.annotations = {'val2': 2.2, 'val3': 'test2', + 'val4': [4, 4.4], 'val5': {1: {1: 1}, 2: 2}, + 'val6': np.array([4, 5, 6]), 'val7': True} + self.base3.annotations = {'val2': 2.2, 'val3': 'test3', + 'val4': [44], 'val5': {1: {2: 2}, 2: 2, 3:3}, + 'val6': np.array([8, 9, 10]), 'val8': False} + + ann1 = self.base1.annotations + ann2 = self.base2.annotations + ann3 = self.base3.annotations + ann1c = self.base1.annotations.copy() + ann2c = self.base2.annotations.copy() + ann3c = self.base3.annotations.copy() + + targ = {'val0': 'val0', 'val1': 1, 'val2': 2.2, 'val3': 'test1;test2;test3', + 'val4': [.4, 4, 4.4, 44], 'val5': {0: 0, 1: {0: 0, 1: 1, 2: 2}, 2: 2, 3: 3}, + 'val7': True, 'val8': False} + + self.base1.merge_annotations(self.base2, self.base3) + + val6t = np.array([0, 1, 2, 4, 5, 6, 8, 9, 10]) + val61 = ann1.pop('val6') + val61c = ann1c.pop('val6') + val62 = ann2.pop('val6') + val62c = ann2c.pop('val6') + val63 = ann3.pop('val6') + val63c = ann3c.pop('val6') + + self.assertEqual(ann1, self.base1.annotations) + self.assertNotEqual(ann1c, self.base1.annotations) + self.assertEqual(ann2c, self.base2.annotations) + self.assertEqual(ann3c, self.base3.annotations) + self.assertEqual(targ, self.base1.annotations) + + assert_arrays_equal(val61, val6t) + self.assertRaises(AssertionError, assert_arrays_equal, val61c, val6t) + assert_arrays_equal(val62, val62c) + assert_arrays_equal(val63, val63c) + + self.assertEqual(self.name1, self.base1.name) + self.assertEqual(self.name2, self.base2.name) + self.assertEqual(self.name3, self.base3.name) + self.assertEqual(self.description1, self.base1.description) + self.assertEqual(self.description2, self.base2.description) + self.assertEqual(self.description3, self.base3.description) + def test_merge_annotations__func__dict(self): ann1 = {'val0': 'val0', 'val1': 1, 'val2': 2.2, 'val3': 'test1', 'val4': [.4], 'val5': {0: 0, 1: {0: 0}}, @@ -217,6 +271,47 @@ def test_merge_annotations__func__dict(self): assert_arrays_equal(val61, val61c) assert_arrays_equal(val62, val62c) + def test_merge_multiple_annotations__func__dict(self): + ann1 = {'val0': 'val0', 'val1': 1, 'val2': 2.2, 'val3': 'test1', + 'val4': [.4], 'val5': {0: 0, 1: {0: 0}}, + 'val6': np.array([0, 1, 2])} + ann2 = {'val2': 2.2, 'val3': 'test2', + 'val4': [4, 4.4], 'val5': {1: {1: 1}, 2: 2}, + 'val6': np.array([4, 5, 6]), 'val7': True} + ann3 = {'val2': 2.2, 'val3': 'test3', + 'val4': [44], 'val5': {1: {2: 2}, 2: 2, 3: 3}, + 'val6': np.array([8, 9, 10]), 'val8': False} + + ann1c = ann1.copy() + ann2c = ann2.copy() + ann3c = ann3.copy() + + targ = {'val0': 'val0', 'val1': 1, 'val2': 2.2, 'val3': 'test1;test2;test3', + 'val4': [.4, 4, 4.4, 44], 'val5': {0: 0, 1: {0: 0, 1: 1, 2: 2}, 2: 2, 3: 3}, + 'val7': True, 'val8': False} + + res = merge_annotations(ann1, ann2, ann3) + + val6t = np.array([0, 1, 2, 4, 5, 6, 8, 9, 10]) + val6r = res.pop('val6') + val61 = ann1.pop('val6') + val61c = ann1c.pop('val6') + val62 = ann2.pop('val6') + val62c = ann2c.pop('val6') + val63 = ann3.pop('val6') + val63c = ann3c.pop('val6') + + self.assertEqual(ann1, ann1c) + self.assertEqual(ann2, ann2c) + self.assertEqual(ann3, ann3c) + self.assertEqual(res, targ) + + assert_arrays_equal(val6r, val6t) + self.assertRaises(AssertionError, assert_arrays_equal, val61, val6t) + assert_arrays_equal(val61, val61c) + assert_arrays_equal(val62, val62c) + assert_arrays_equal(val63, val63c) + def test_merge_annotation__func__str(self): ann1 = 'test1' ann2 = 'test2' @@ -338,6 +433,37 @@ def test_merge__dict(self): self.assertEqual(self.description1, self.base1.description) self.assertEqual(self.description2, self.base2.description) + def test_merge_multiple__dict(self): + self.base1.annotations = {'val0': 'val0', 'val1': 1, + 'val2': 2.2, 'val3': 'test1'} + self.base2.annotations = {'val2': 2.2, 'val3': 'test2', + 'val4': [4, 4.4], 'val5': True} + self.base3.annotations = {'val2': 2.2, 'val3': 'test3', + 'val4': [44], 'val5': True, 'val6': False} + + ann1 = self.base1.annotations + ann1c = self.base1.annotations.copy() + ann2c = self.base2.annotations.copy() + ann3c = self.base3.annotations.copy() + + targ = {'val0': 'val0', 'val1': 1, 'val2': 2.2, 'val3': 'test1;test2;test3', + 'val4': [4, 4.4, 44], 'val5': True, 'val6': False} + + self.base1.merge(self.base2, self.base3) + + self.assertEqual(ann1, self.base1.annotations) + self.assertNotEqual(ann1c, self.base1.annotations) + self.assertEqual(ann2c, self.base2.annotations) + self.assertEqual(ann3c, self.base3.annotations) + self.assertEqual(targ, self.base1.annotations) + + self.assertEqual(self.name1, self.base1.name) + self.assertEqual(self.name2, self.base2.name) + self.assertEqual(self.name3, self.base3.name) + self.assertEqual(self.description1, self.base1.description) + self.assertEqual(self.description2, self.base2.description) + self.assertEqual(self.description3, self.base3.description) + def test_merge_annotations__different_type_AssertionError(self): self.base1.annotations = {'val1': 1, 'val2': 2.2, 'val3': 'tester'} self.base2.annotations = {'val3': False, 'val4': [4, 4.4], @@ -350,6 +476,22 @@ def test_merge_annotations__different_type_AssertionError(self): 'val4': [4, 4.4], 'val5': True}) + def test_merge_multiple_annotations__different_type_AssertionError(self): + self.base1.annotations = {'val1': 1, 'val2': 2.2, 'val3': 'tester'} + self.base2.annotations = {'val3': False, 'val4': [4, 4.4], + 'val5': True} + self.base3.annotations = {'val5': 1, 'val6': 79, + 'val7': True} + self.base1.merge_annotations(self.base2, self.base3) + self.assertEqual(self.base1.annotations, + {'val1': 1, + 'val2': 2.2, + 'val3': 'MERGE CONFLICT', + 'val4': [4, 4.4], + 'val5': 'MERGE CONFLICT', + 'val6': 79, + 'val7': True}) + def test_merge__different_type_AssertionError(self): self.base1.annotations = {'val1': 1, 'val2': 2.2, 'val3': 'tester'} self.base2.annotations = {'val3': False, 'val4': [4, 4.4], @@ -362,6 +504,22 @@ def test_merge__different_type_AssertionError(self): 'val4': [4, 4.4], 'val5': True}) + def test_merge_multiple__different_type_AssertionError(self): + self.base1.annotations = {'val1': 1, 'val2': 2.2, 'val3': 'tester'} + self.base2.annotations = {'val3': False, 'val4': [4, 4.4], + 'val5': True} + self.base3.annotations = {'val5': 3.1, 'val6': False, + 'val7': 'val7'} + self.base1.merge(self.base2, self.base3) + self.assertEqual(self.base1.annotations, + {'val1': 1, + 'val2': 2.2, + 'val3': 'MERGE CONFLICT', + 'val4': [4, 4.4], + 'val5': 'MERGE CONFLICT', + 'val6': False, + 'val7': 'val7'}) + def test_merge_annotations__unmergable_unequal_AssertionError(self): self.base1.annotations = {'val1': 1, 'val2': 2.2, 'val3': True} self.base2.annotations = {'val3': False, 'val4': [4, 4.4], @@ -374,6 +532,22 @@ def test_merge_annotations__unmergable_unequal_AssertionError(self): 'val4': [4, 4.4], 'val5': True}) + def test_merge_multiple_annotations__unmergable_unequal_AssertionError(self): + self.base1.annotations = {'val1': 1, 'val2': 2.2, 'val3': True} + self.base2.annotations = {'val3': False, 'val4': [4, 4.4], + 'val5': 3.5} + self.base3.annotations = {'val5': 3.4, 'val6': [4, 4.4], + 'val7': True} + self.base1.merge_annotations(self.base2, self.base3) + self.assertEqual(self.base1.annotations, + {'val1': 1, + 'val2': 2.2, + 'val3': 'MERGE CONFLICT', + 'val4': [4, 4.4], + 'val5': 'MERGE CONFLICT', + 'val6': [4, 4.4], + 'val7': True}) + def test_merge__unmergable_unequal_AssertionError(self): self.base1.annotations = {'val1': 1, 'val2': 2.2, 'val3': True} self.base2.annotations = {'val3': False, 'val4': [4, 4.4], @@ -386,6 +560,22 @@ def test_merge__unmergable_unequal_AssertionError(self): 'val4': [4, 4.4], 'val5': True}) + def test_merge_multiple__unmergable_unequal_AssertionError(self): + self.base1.annotations = {'val1': 1, 'val2': 2.2, 'val3': True} + self.base2.annotations = {'val3': False, 'val4': [4, 4.4], + 'val5': True} + self.base3.annotations = {'val5': 3.4, 'val6': [4, 4.4], + 'val7': True} + self.base1.merge(self.base2, self.base3) + self.assertEqual(self.base1.annotations, + {'val1': 1, + 'val2': 2.2, + 'val3': 'MERGE CONFLICT', + 'val4': [4, 4.4], + 'val5': 'MERGE CONFLICT', + 'val6': [4, 4.4], + 'val7': True}) + class TestBaseNeoCoreTypes(unittest.TestCase): ''' From 712eb80f6d0e1cef224edce06a58af659d8910cf Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Thu, 28 Feb 2019 13:31:18 +0100 Subject: [PATCH 06/15] Added unittests for merging multiple spiketrains --- neo/test/coretest/test_spiketrain.py | 123 +++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) diff --git a/neo/test/coretest/test_spiketrain.py b/neo/test/coretest/test_spiketrain.py index 7b949bbf4..0df9245b4 100644 --- a/neo/test/coretest/test_spiketrain.py +++ b/neo/test/coretest/test_spiketrain.py @@ -1150,6 +1150,47 @@ def test_merge_typical(self): np.array([1, 101, 2, 102, 3, 103, 4, 104, 5, 105, 6, 106])) self.assertIsInstance(result.array_annotations, ArrayDict) + def test_merge_multiple(self): + self.train1.waveforms = None + + train3 = self.train1.duplicate_with_new_data(self.train1.times.magnitude * pq.microsecond) + train3.segment = self.train1.segment + train3.array_annotate(index=np.arange(301, 307)) + + train4 = self.train1.duplicate_with_new_data(self.train1.times / 2) + train4.segment = self.train1.segment + train4.array_annotate(index=np.arange(401, 407)) + + # Array annotations merge warning was already tested, can be ignored now + with warnings.catch_warnings(record=True) as w: + result = self.train1.merge(train3, train4) + self.assertEqual(len(w), 1) + self.assertTrue("array annotations" in str(w[0].message)) + + assert_neo_object_is_compliant(result) + + self.assertEqual(len(result.shape), 1) + self.assertEqual(result.shape[0], sum(len(st) + for st in (self.train1, train3, train4))) + + self.assertEqual(self.train1.sampling_rate, result.sampling_rate) + + time_unit = result.units + + expected = np.concatenate((self.train1.rescale(time_unit).times, + train3.rescale(time_unit).times, train4.rescale(time_unit).times)) + expected *= time_unit + sorting = np.argsort(expected) + expected = expected[sorting] + np.testing.assert_array_equal(result.times, expected) + + # Make sure array annotations are merged correctly + self.assertTrue('label' not in result.array_annotations) + assert_arrays_equal(result.array_annotations['index'], + np.concatenate([st.array_annotations['index'] + for st in (self.train1, train3, train4)])[sorting]) + self.assertIsInstance(result.array_annotations, ArrayDict) + def test_merge_with_waveforms(self): # Array annotations merge warning was already tested, can be ignored now with warnings.catch_warnings(record=True) as w: @@ -1158,6 +1199,38 @@ def test_merge_with_waveforms(self): self.assertTrue("array annotations" in str(w[0].message)) assert_neo_object_is_compliant(result) + def test_merge_multiple_with_waveforms(self): + train3 = self.train1.duplicate_with_new_data(self.train1.times.magnitude * pq.microsecond) + train3.segment = self.train1.segment + train3.array_annotate(index=np.arange(301, 307)) + train3.waveforms = self.train1.waveforms / 10 + + train4 = self.train1.duplicate_with_new_data(self.train1.times / 2) + train4.segment = self.train1.segment + train4.array_annotate(index=np.arange(401, 407)) + train4.waveforms = self.train1.waveforms / 2 + + # Array annotations merge warning was already tested, can be ignored now + with warnings.catch_warnings(record=True) as w: + result = self.train1.merge(train3, train4) + self.assertEqual(len(w), 1) + self.assertTrue("array annotations" in str(w[0].message)) + + assert_neo_object_is_compliant(result) + self.assertEqual(len(result.shape), 1) + self.assertEqual(result.shape[0], sum(len(st) for st in (self.train1, train3, train4))) + + time_unit = result.units + + expected = np.concatenate((self.train1.rescale(time_unit).times, + train3.rescale(time_unit).times, train4.rescale(time_unit).times)) + sorting = np.argsort(expected) + + assert_arrays_equal(result.waveforms, + np.vstack([st.waveforms.rescale(self.train1.waveforms.units) + for st in (self.train1, train3, train4)])[sorting] * self.train1.waveforms.units + ) + def test_correct_shape(self): # Array annotations merge warning was already tested, can be ignored now with warnings.catch_warnings(record=True) as w: @@ -1237,6 +1310,56 @@ def test_incompatible_t_start(self): with self.assertRaises(MergeError): self.train2.merge(train3) + def test_merge_multiple_raise_merge_errors(self): + # different t_start + train3 = self.train1.duplicate_with_new_data(self.train1, t_start=-1 * pq.s) + train3.segment = self.train1.segment + with self.assertRaises(MergeError): + train3.merge(self.train2, self.train1) + with self.assertRaises(MergeError): + self.train2.merge(train3, self.train1) + + # different t_stop + train3 = self.train1.duplicate_with_new_data(self.train1, t_stop=133 * pq.s) + train3.segment = self.train1.segment + with self.assertRaises(MergeError): + train3.merge(self.train2, self.train1) + with self.assertRaises(MergeError): + self.train2.merge(train3, self.train1) + + # different segment + train3 = self.train1.duplicate_with_new_data(self.train1) + seg = Segment() + train3.segment = seg + with self.assertRaises(MergeError): + train3.merge(self.train2, self.train1) + with self.assertRaises(MergeError): + self.train2.merge(train3, self.train1) + + # missing waveforms + train3 = self.train1.duplicate_with_new_data(self.train1) + train3.waveforms = None + with self.assertRaises(MergeError): + train3.merge(self.train2, self.train1) + with self.assertRaises(MergeError): + self.train2.merge(train3, self.train1) + + # different sampling rate + train3 = self.train1.duplicate_with_new_data(self.train1) + train3.sampling_rate = 1 * pq.s + with self.assertRaises(MergeError): + train3.merge(self.train2, self.train1) + with self.assertRaises(MergeError): + self.train2.merge(train3, self.train1) + + # different left sweep + train3 = self.train1.duplicate_with_new_data(self.train1) + train3.left_sweep = 1 * pq.s + with self.assertRaises(MergeError): + train3.merge(self.train2, self.train1) + with self.assertRaises(MergeError): + self.train2.merge(train3, self.train1) + class TestDuplicateWithNewData(unittest.TestCase): def setUp(self): From 77f99a129900fab4568a6203e27e09e6ef3d4f8c Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Thu, 28 Feb 2019 13:31:34 +0100 Subject: [PATCH 07/15] Cleanup --- neo/core/spiketrain.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/neo/core/spiketrain.py b/neo/core/spiketrain.py index e60ceea8e..0ac3d6b71 100644 --- a/neo/core/spiketrain.py +++ b/neo/core/spiketrain.py @@ -628,9 +628,6 @@ def merge(self, *others): If the attributes of the :class:`SpikeTrain` objects are not compatible, an Exception is raised. ''' - # TODO: unittests for merging multiple spiketrains - # TODO: maybe allow for lists of spiketrains as input? would be convenient in many use cases - # TODO: improve! better way to loop over self AND others? for other in others: if self.sampling_rate != other.sampling_rate: raise MergeError("Cannot merge, different sampling rates") From 2aa32f986b3a98e85bb220bacf3aa08a51f6f882 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Thu, 28 Feb 2019 14:04:08 +0100 Subject: [PATCH 08/15] pep8 --- neo/core/spiketrain.py | 11 +++++++---- neo/test/coretest/test_base.py | 2 +- neo/test/coretest/test_spiketrain.py | 11 ++++++----- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/neo/core/spiketrain.py b/neo/core/spiketrain.py index 0ac3d6b71..cb881b6a4 100644 --- a/neo/core/spiketrain.py +++ b/neo/core/spiketrain.py @@ -678,7 +678,8 @@ def merge(self, *others): t_start=self.t_start, t_stop=self.t_stop, sampling_rate=self.sampling_rate, left_sweep=self.left_sweep, **kwargs) if all(wfs): - wfs_stack = np.vstack([st.waveforms.rescale(self.waveforms.units) for st in all_spiketrains]) + wfs_stack = np.vstack([st.waveforms.rescale(self.waveforms.units) + for st in all_spiketrains]) wfs_stack = wfs_stack[sorting] * self.waveforms.units train.waveforms = wfs_stack train.segment = self.segment @@ -710,7 +711,8 @@ def _merge_array_annotations(self, others, sorting=None): for key in keys: try: self_ann = copy.deepcopy(self.array_annotations[key]) - other_ann = np.concatenate([copy.deepcopy(other.array_annotations[key]) for other in others]) + other_ann = np.concatenate([copy.deepcopy(other.array_annotations[key]) + for other in others]) if isinstance(self_ann, pq.Quantity): other_ann.rescale(self_ann.units) arr_ann = np.concatenate([self_ann, other_ann]) * self_ann.units @@ -723,8 +725,9 @@ def _merge_array_annotations(self, others, sorting=None): omitted_keys_self.append(key) continue - omitted_keys_other = [key for key in np.unique([key for other in others for key in other.array_annotations]) if - key not in self.array_annotations] + omitted_keys_other = [key for key in np.unique([key for other in others + for key in other.array_annotations]) + if key not in self.array_annotations] if omitted_keys_self or omitted_keys_other: warnings.warn("The following array annotations were omitted, because they were only " diff --git a/neo/test/coretest/test_base.py b/neo/test/coretest/test_base.py index 49f1e5a3c..f32082a6a 100644 --- a/neo/test/coretest/test_base.py +++ b/neo/test/coretest/test_base.py @@ -196,7 +196,7 @@ def test_merge_multiple_annotations__dict(self): 'val4': [4, 4.4], 'val5': {1: {1: 1}, 2: 2}, 'val6': np.array([4, 5, 6]), 'val7': True} self.base3.annotations = {'val2': 2.2, 'val3': 'test3', - 'val4': [44], 'val5': {1: {2: 2}, 2: 2, 3:3}, + 'val4': [44], 'val5': {1: {2: 2}, 2: 2, 3: 3}, 'val6': np.array([8, 9, 10]), 'val8': False} ann1 = self.base1.annotations diff --git a/neo/test/coretest/test_spiketrain.py b/neo/test/coretest/test_spiketrain.py index 0df9245b4..91a67512f 100644 --- a/neo/test/coretest/test_spiketrain.py +++ b/neo/test/coretest/test_spiketrain.py @@ -1178,7 +1178,8 @@ def test_merge_multiple(self): time_unit = result.units expected = np.concatenate((self.train1.rescale(time_unit).times, - train3.rescale(time_unit).times, train4.rescale(time_unit).times)) + train3.rescale(time_unit).times, + train4.rescale(time_unit).times)) expected *= time_unit sorting = np.argsort(expected) expected = expected[sorting] @@ -1223,13 +1224,13 @@ def test_merge_multiple_with_waveforms(self): time_unit = result.units expected = np.concatenate((self.train1.rescale(time_unit).times, - train3.rescale(time_unit).times, train4.rescale(time_unit).times)) + train3.rescale(time_unit).times, + train4.rescale(time_unit).times)) sorting = np.argsort(expected) - assert_arrays_equal(result.waveforms, + assert_arrays_equal(result.waveforms, self.train1.waveforms.units * np.vstack([st.waveforms.rescale(self.train1.waveforms.units) - for st in (self.train1, train3, train4)])[sorting] * self.train1.waveforms.units - ) + for st in (self.train1, train3, train4)])[sorting]) def test_correct_shape(self): # Array annotations merge warning was already tested, can be ignored now From 2ffbc0620bdac018d0cbcea72e6e1e8143c1a8f4 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Thu, 28 Feb 2019 14:35:56 +0100 Subject: [PATCH 09/15] wrong pep8 version --- neo/test/coretest/test_spiketrain.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/neo/test/coretest/test_spiketrain.py b/neo/test/coretest/test_spiketrain.py index 91a67512f..5600437dd 100644 --- a/neo/test/coretest/test_spiketrain.py +++ b/neo/test/coretest/test_spiketrain.py @@ -1228,9 +1228,10 @@ def test_merge_multiple_with_waveforms(self): train4.rescale(time_unit).times)) sorting = np.argsort(expected) - assert_arrays_equal(result.waveforms, self.train1.waveforms.units * + assert_arrays_equal(result.waveforms, np.vstack([st.waveforms.rescale(self.train1.waveforms.units) - for st in (self.train1, train3, train4)])[sorting]) + for st in (self.train1, train3, train4)])[sorting] + * self.train1.waveforms.units) def test_correct_shape(self): # Array annotations merge warning was already tested, can be ignored now From 488276e25cbd976ed38ed6ecb4463e05da1bbea5 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Fri, 1 Mar 2019 11:16:58 +0100 Subject: [PATCH 10/15] Fixed name, description and file_origin attributes --- neo/core/spiketrain.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/neo/core/spiketrain.py b/neo/core/spiketrain.py index cb881b6a4..d5ee9fec9 100644 --- a/neo/core/spiketrain.py +++ b/neo/core/spiketrain.py @@ -665,12 +665,24 @@ def merge(self, *others): kwargs['array_annotations'] = self._merge_array_annotations(others, sorting=sorting) for name in ("name", "description", "file_origin"): - attr_self = getattr(self, name) - attr_other = getattr(other, name) - if attr_self == attr_other: - kwargs[name] = attr_self - else: - kwargs[name] = "merge(%s, %s)" % (attr_self, attr_other) + attr = getattr(self, name) + for other in others: + attr_other = getattr(other, name) + if attr is None and attr_other is None: + continue + elif attr is None or attr_other is None: + attr = str(attr) + attr_other = str(attr_other) + if attr_other not in attr: + attr += ', ' + attr_other + if 'merge' not in attr: + attr = 'merge(' + attr + if attr is None: + pass + elif 'merge' in attr: + attr += ')' + kwargs[name] = attr + merged_annotations = merge_annotations(self.annotations, other.annotations) kwargs.update(merged_annotations) From 380aa74cfa4a1efe88c13f375e34c42540b59ab2 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Sun, 3 Mar 2019 22:52:18 +0100 Subject: [PATCH 11/15] Ensured that the merging of the name, description and file_origin attributes always results in merge(unique list) or one single attribute if it is the same for all spiketrains. Made sure this still works if the involved spiketrains are already merged spiketrains. --- neo/core/spiketrain.py | 50 +++++++++++++++++++++--- neo/test/coretest/test_spiketrain.py | 57 ++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+), 6 deletions(-) diff --git a/neo/core/spiketrain.py b/neo/core/spiketrain.py index cb881b6a4..92d0f3051 100644 --- a/neo/core/spiketrain.py +++ b/neo/core/spiketrain.py @@ -665,12 +665,50 @@ def merge(self, *others): kwargs['array_annotations'] = self._merge_array_annotations(others, sorting=sorting) for name in ("name", "description", "file_origin"): - attr_self = getattr(self, name) - attr_other = getattr(other, name) - if attr_self == attr_other: - kwargs[name] = attr_self - else: - kwargs[name] = "merge(%s, %s)" % (attr_self, attr_other) + attr = getattr(self, name) + + # check if self is already a merged spiketrain + # if it is, get rid of the bracket at the end to append more attributes + if attr is not None: + if attr.startswith('merge(') and attr.endswith(')'): + attr = attr[:-1] + + for other in others: + attr_other = getattr(other, name) + + # both attributes are None --> nothing to do + if attr is None and attr_other is None: + continue + + # one of the attributes is None --> convert to string in order to merge them + elif attr is None or attr_other is None: + attr = str(attr) + attr_other = str(attr_other) + + # check if the other spiketrain is already a merged spiketrain + # if it is, append all of its merged attributes that aren't already in attr + if attr_other.startswith('merge(') and attr_other.endswith(')'): + for subattr in attr_other[6:-1].split('; '): + if subattr not in attr: + attr += '; ' + subattr + if not attr.startswith('merge('): + attr = 'merge(' + attr + + # if the other attribute is not in the list --> append + # if attr doesn't already start with merge add merge( in the beginning + elif attr_other not in attr: + attr += '; ' + attr_other + if not attr.startswith('merge('): + attr = 'merge(' + attr + + # close the bracket of merge(...) if necessary + if attr is not None: + if attr.startswith('merge('): + attr += ')' + + # write attr into kwargs dict + kwargs[name] = attr + merged_annotations = merge_annotations(self.annotations, other.annotations) kwargs.update(merged_annotations) diff --git a/neo/test/coretest/test_spiketrain.py b/neo/test/coretest/test_spiketrain.py index 5600437dd..a61683ac5 100644 --- a/neo/test/coretest/test_spiketrain.py +++ b/neo/test/coretest/test_spiketrain.py @@ -1280,6 +1280,63 @@ def test_rescaling_units(self): np.array([1, 2, 3, 4, 5, 6, 101, 102, 103, 104, 105, 106])) self.assertIsInstance(result.array_annotations, ArrayDict) + def test_name_file_origin_description(self): + self.train1.waveforms = None + self.train2.waveforms = None + self.train1.name = 'name1' + self.train1.description = 'desc1' + self.train1.file_origin = 'file1' + self.train2.name = 'name2' + self.train2.description = 'desc2' + self.train2.file_origin = 'file2' + + train3 = self.train1.duplicate_with_new_data(self.train1.times.magnitude * pq.microsecond) + train3.segment = self.train1.segment + train3.name = 'name3' + train3.description = 'desc3' + train3.file_origin = 'file3' + + train4 = self.train1.duplicate_with_new_data(self.train1.times / 2) + train4.segment = self.train1.segment + train4.name = 'name3' + train4.description = 'desc3' + train4.file_origin = 'file3' + + # merge two spiketrains with different attributes + merge1 = self.train1.merge(self.train2) + + self.assertEqual(merge1.name, 'merge(name1; name2)') + self.assertEqual(merge1.description, 'merge(desc1; desc2)') + self.assertEqual(merge1.file_origin, 'merge(file1; file2)') + + # merge a merged spiketrain with a regular one + merge2 = merge1.merge(train3) + + self.assertEqual(merge2.name, 'merge(name1; name2; name3)') + self.assertEqual(merge2.description, 'merge(desc1; desc2; desc3)') + self.assertEqual(merge2.file_origin, 'merge(file1; file2; file3)') + + # merge two merged spiketrains + merge3 = merge1.merge(merge2) + + self.assertEqual(merge3.name, 'merge(name1; name2; name3)') + self.assertEqual(merge3.description, 'merge(desc1; desc2; desc3)') + self.assertEqual(merge3.file_origin, 'merge(file1; file2; file3)') + + # merge two spiketrains with identical attributes + merge4 = train3.merge(train4) + + self.assertEqual(merge4.name, 'name3') + self.assertEqual(merge4.description, 'desc3') + self.assertEqual(merge4.file_origin, 'file3') + + # merge a reqular spiketrain with a merged spiketrain + merge5 = train3.merge(merge1) + + self.assertEqual(merge5.name, 'merge(name3; name1; name2)') + self.assertEqual(merge5.description, 'merge(desc3; desc1; desc2)') + self.assertEqual(merge5.file_origin, 'merge(file3; file1; file2)') + def test_sampling_rate(self): # Array annotations merge warning was already tested, can be ignored now with warnings.catch_warnings(record=True) as w: From e352665dac4c80e873153baf5b36bef6640f7e65 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Sun, 3 Mar 2019 23:27:27 +0100 Subject: [PATCH 12/15] Added the possibility to merge SpikeTrainProxy objects into regular SpikeTrain objects, added a unittest for this case --- neo/core/spiketrain.py | 24 ++++++++++---------- neo/test/coretest/test_spiketrain.py | 33 +++++++++++++++++++++++++++- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/neo/core/spiketrain.py b/neo/core/spiketrain.py index 92d0f3051..af8d0ee27 100644 --- a/neo/core/spiketrain.py +++ b/neo/core/spiketrain.py @@ -20,6 +20,8 @@ # needed for python 3 compatibility from __future__ import absolute_import, division, print_function + +import neo import sys import copy @@ -629,6 +631,9 @@ def merge(self, *others): compatible, an Exception is raised. ''' for other in others: + if type(other) not in [SpikeTrain, neo.io.proxyobjects.SpikeTrainProxy]: + raise MergeError("Cannot merge, only SpikeTrain and SpikeTrainProxy objects" + "can be merged into a SpikeTrain.") if self.sampling_rate != other.sampling_rate: raise MergeError("Cannot merge, different sampling rates") if self.t_start != other.t_start: @@ -640,18 +645,13 @@ def merge(self, *others): if self.segment != other.segment: raise MergeError("Cannot merge these signals as they belong to" " different segments.") - if other.units != self.units: - other = other.rescale(self.units) - if hasattr(self, "lazy_shape"): - merged_lazy_shape = self.lazy_shape[0] - for other in others: - if hasattr(other, "lazy_shape"): - merged_lazy_shape[0] += other.lazy_shape[0] - else: - raise MergeError("Cannot merge a lazy object with a real" - " object.") + all_spiketrains = [self] - all_spiketrains.extend(others) + all_spiketrains.extend([st.rescale(self.units) if type(st) is SpikeTrain else + st.load(load_waveforms=self.waveforms + is not None).rescale(self.units) + for st in others]) + wfs = [st.waveforms is not None for st in all_spiketrains] if any(wfs) and not all(wfs): raise MergeError("Cannot merge signal with waveform and signal " @@ -724,8 +724,6 @@ def merge(self, *others): if train.segment is not None: self.segment.spiketrains.append(train) - if hasattr(self, "lazy_shape"): - train.lazy_shape = merged_lazy_shape return train def _merge_array_annotations(self, others, sorting=None): diff --git a/neo/test/coretest/test_spiketrain.py b/neo/test/coretest/test_spiketrain.py index a61683ac5..5f7877bd8 100644 --- a/neo/test/coretest/test_spiketrain.py +++ b/neo/test/coretest/test_spiketrain.py @@ -24,12 +24,15 @@ else: HAVE_IPYTHON = True +from neo.rawio.examplerawio import ExampleRawIO +from neo.io.proxyobjects import SpikeTrainProxy + from neo.core.spiketrain import (check_has_dimensions_time, SpikeTrain, _check_time_in_range, _new_spiketrain) from neo.core import Segment, Unit from neo.core.baseneo import MergeError from neo.test.tools import (assert_arrays_equal, assert_arrays_almost_equal, - assert_neo_object_is_compliant) + assert_neo_object_is_compliant, assert_same_attributes) from neo.test.generate_datasets import (get_fake_value, get_fake_values, fake_neo, TEST_ANNOTATIONS) @@ -1337,6 +1340,34 @@ def test_name_file_origin_description(self): self.assertEqual(merge5.description, 'merge(desc3; desc1; desc2)') self.assertEqual(merge5.file_origin, 'merge(file3; file1; file2)') + def test_merge_with_proxy(self): + self.train1.waveforms = None + self.train2.waveforms = None + + reader = ExampleRawIO(filename='my_filename.fake') + reader.parse_header() + + proxy_sptr = SpikeTrainProxy(rawio=reader, unit_index=0, + block_index=0, seg_index=0) + + # change all attributes that have to be the same in order to merge the spiketrains + proxy_sptr.segment = self.train1.segment + proxy_sptr.sampling_rate = self.train1.sampling_rate + proxy_sptr.left_sweep = self.train1.left_sweep + + self.train1.t_stop = proxy_sptr.t_stop + self.train2.t_stop = proxy_sptr.t_stop + + loaded_sptr = proxy_sptr.load(load_waveforms=False) + loaded_sptr.segment = self.train1.segment + + merge_proxy = self.train1.merge(self.train2, proxy_sptr) + merge_loaded = self.train1.merge(self.train2, loaded_sptr) + + assert_neo_object_is_compliant(merge_proxy) + + assert_same_attributes(merge_proxy, merge_loaded) + def test_sampling_rate(self): # Array annotations merge warning was already tested, can be ignored now with warnings.catch_warnings(record=True) as w: From 5547919c35877fb670ecd0d8ee66878ceee966c2 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Wed, 10 Apr 2019 11:05:19 +0200 Subject: [PATCH 13/15] Change type check to isinstance --- neo/core/spiketrain.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neo/core/spiketrain.py b/neo/core/spiketrain.py index af8d0ee27..51e9ef2e2 100644 --- a/neo/core/spiketrain.py +++ b/neo/core/spiketrain.py @@ -631,7 +631,7 @@ def merge(self, *others): compatible, an Exception is raised. ''' for other in others: - if type(other) not in [SpikeTrain, neo.io.proxyobjects.SpikeTrainProxy]: + if not isinstance(other, (SpikeTrain, neo.io.proxyobjects.SpikeTrainProxy)): raise MergeError("Cannot merge, only SpikeTrain and SpikeTrainProxy objects" "can be merged into a SpikeTrain.") if self.sampling_rate != other.sampling_rate: From 7d9dad9db1fd6f816db2e4b8a7e6b670e83b3209 Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Mon, 16 Sep 2019 16:00:46 +0200 Subject: [PATCH 14/15] Remove ability to merge proxies into regular spiketrains --- neo/core/spiketrain.py | 12 ++++++------ neo/test/coretest/test_spiketrain.py | 28 ---------------------------- 2 files changed, 6 insertions(+), 34 deletions(-) diff --git a/neo/core/spiketrain.py b/neo/core/spiketrain.py index 51e9ef2e2..95f71d095 100644 --- a/neo/core/spiketrain.py +++ b/neo/core/spiketrain.py @@ -631,8 +631,11 @@ def merge(self, *others): compatible, an Exception is raised. ''' for other in others: - if not isinstance(other, (SpikeTrain, neo.io.proxyobjects.SpikeTrainProxy)): - raise MergeError("Cannot merge, only SpikeTrain and SpikeTrainProxy objects" + if isinstance(other, neo.io.proxyobjects.SpikeTrainProxy): + raise MergeError("Cannot merge, SpikeTrainProxy objects cannot be merged" + "into regular SpikeTrain objects, please load them first.") + elif not isinstance(other, SpikeTrain): + raise MergeError("Cannot merge, only SpikeTrain" "can be merged into a SpikeTrain.") if self.sampling_rate != other.sampling_rate: raise MergeError("Cannot merge, different sampling rates") @@ -647,10 +650,7 @@ def merge(self, *others): " different segments.") all_spiketrains = [self] - all_spiketrains.extend([st.rescale(self.units) if type(st) is SpikeTrain else - st.load(load_waveforms=self.waveforms - is not None).rescale(self.units) - for st in others]) + all_spiketrains.extend([st.rescale(self.units) for st in others]) wfs = [st.waveforms is not None for st in all_spiketrains] if any(wfs) and not all(wfs): diff --git a/neo/test/coretest/test_spiketrain.py b/neo/test/coretest/test_spiketrain.py index 5f7877bd8..825cb9ff5 100644 --- a/neo/test/coretest/test_spiketrain.py +++ b/neo/test/coretest/test_spiketrain.py @@ -1340,34 +1340,6 @@ def test_name_file_origin_description(self): self.assertEqual(merge5.description, 'merge(desc3; desc1; desc2)') self.assertEqual(merge5.file_origin, 'merge(file3; file1; file2)') - def test_merge_with_proxy(self): - self.train1.waveforms = None - self.train2.waveforms = None - - reader = ExampleRawIO(filename='my_filename.fake') - reader.parse_header() - - proxy_sptr = SpikeTrainProxy(rawio=reader, unit_index=0, - block_index=0, seg_index=0) - - # change all attributes that have to be the same in order to merge the spiketrains - proxy_sptr.segment = self.train1.segment - proxy_sptr.sampling_rate = self.train1.sampling_rate - proxy_sptr.left_sweep = self.train1.left_sweep - - self.train1.t_stop = proxy_sptr.t_stop - self.train2.t_stop = proxy_sptr.t_stop - - loaded_sptr = proxy_sptr.load(load_waveforms=False) - loaded_sptr.segment = self.train1.segment - - merge_proxy = self.train1.merge(self.train2, proxy_sptr) - merge_loaded = self.train1.merge(self.train2, loaded_sptr) - - assert_neo_object_is_compliant(merge_proxy) - - assert_same_attributes(merge_proxy, merge_loaded) - def test_sampling_rate(self): # Array annotations merge warning was already tested, can be ignored now with warnings.catch_warnings(record=True) as w: From 9885e132fd6f032d8078c18164b3db2bc296223a Mon Sep 17 00:00:00 2001 From: kleinjohann Date: Fri, 7 Feb 2020 08:38:39 +0100 Subject: [PATCH 15/15] Fix not all spiketrains' annotations being merged --- neo/core/spiketrain.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/neo/core/spiketrain.py b/neo/core/spiketrain.py index 95f71d095..8aff822f7 100644 --- a/neo/core/spiketrain.py +++ b/neo/core/spiketrain.py @@ -709,7 +709,8 @@ def merge(self, *others): # write attr into kwargs dict kwargs[name] = attr - merged_annotations = merge_annotations(self.annotations, other.annotations) + merged_annotations = merge_annotations(*(st.annotations for st in + all_spiketrains)) kwargs.update(merged_annotations) train = SpikeTrain(stack, units=self.units, dtype=self.dtype, copy=False,