Skip to content
This repository was archived by the owner on Jun 6, 2023. It is now read-only.

Commit 0949df4

Browse files
authored
Merge pull request #41 from SpikeInterface/msc_dumpable
Unified file creation/deletion
2 parents ca78597 + a9b47c5 commit 0949df4

File tree

3 files changed

+16
-27
lines changed

3 files changed

+16
-27
lines changed

spikecomparison/groundtruthstudy.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,11 @@ def aggregate_dataframes(self, copy_into_folder=True, **karg_thresh):
157157
perfs = self.aggregate_performance_by_units()
158158

159159
dataframes['perf_by_units'] = perfs.reset_index()
160-
# dataframes['perf_pooled_with_average'] = perfs.reset_index().groupby(['rec_name', 'sorter_name']).mean().reset_index()
161160
dataframes['count_units'] = self.aggregate_count_units(**karg_thresh).reset_index()
162161

163162
if copy_into_folder:
164163
tables_folder = self.study_folder / 'tables'
165-
if not os.path.exists(tables_folder):
166-
os.makedirs(str(tables_folder))
164+
tables_folder.mkdir(parents=True, exist_ok=True)
167165

168166
for name, df in dataframes.items():
169167
df.to_csv(str(tables_folder / (name + '.csv')), sep='\t', index=False)
@@ -190,11 +188,11 @@ def get_units_snr(self, rec_name=None, **snr_kargs):
190188
rec_name = self._check_rec_name(rec_name)
191189

192190
metrics_folder = self.study_folder / 'metrics'
193-
if not (os.path.exists(metrics_folder)):
194-
os.makedirs(str(metrics_folder))
191+
metrics_folder.mkdir(parents=True, exist_ok=True)
192+
195193
filename = metrics_folder / ('SNR ' + rec_name + '.txt')
196194

197-
if os.path.exists(filename):
195+
if filename.is_file():
198196
snr = pd.read_csv(filename, sep='\t', index_col=None)
199197
snr = snr.set_index('gt_unit_id')
200198
else:

spikecomparison/multisortingcomparison.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,7 @@ def compute_subgraphs(self):
7979

8080
def dump(self, save_folder):
8181
save_folder = Path(save_folder)
82-
if not save_folder.is_dir():
83-
os.makedirs(str(save_folder))
82+
save_folder.mkdir(parents=True, exist_ok=True)
8483
filename = str(save_folder / 'multicomparison.gpickle')
8584
nx.write_gpickle(self.graph, filename)
8685
kwargs = {'delta_time': self.delta_time, 'sampling_frequency': self.sampling_frequency,

spikecomparison/studytools.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,11 @@
2424

2525
from spikesorters.sorterlist import sorter_dict
2626

27-
# TODO change this when sorters will be remove from toolkit
2827
from spikesorters import run_sorters, iter_output_folders, iter_sorting_output
2928
from .comparisontools import _perf_keys
3029
from .groundtruthcomparison import compare_sorter_to_ground_truth
3130

3231

33-
3432
def setup_comparison_study(study_folder, gt_dict):
3533
"""
3634
Based on a dict of (recording, sorting) create the study folder.
@@ -46,13 +44,13 @@ def setup_comparison_study(study_folder, gt_dict):
4644
"""
4745

4846
study_folder = Path(study_folder)
49-
assert not os.path.exists(study_folder), 'study_folder already exists'
47+
assert not study_folder.is_dir(), "'study_folder' already exists. Please remove it"
5048

51-
os.makedirs(str(study_folder))
52-
os.makedirs(str(study_folder / 'raw_files'))
53-
os.makedirs(str(study_folder / 'ground_truth'))
54-
os.makedirs(str(study_folder / 'sortings'))
55-
os.makedirs(str(study_folder / 'sortings/run_log'))
49+
study_folder.mkdir(parents=True, exist_ok=True)
50+
(study_folder / 'raw_files').mkdir(parents=True, exist_ok=True)
51+
(study_folder / 'ground_truth').mkdir(parents=True, exist_ok=True)
52+
(study_folder / 'sortings').mkdir(parents=True, exist_ok=True)
53+
(study_folder / 'sortings' / 'run_log').mkdir(parents=True, exist_ok=True)
5654

5755
for rec_name, (recording, sorting_gt) in gt_dict.items():
5856
# write recording as binary format + json + prb
@@ -123,7 +121,6 @@ def get_one_recording(study_folder, rec_name):
123121
info = json.load(f)
124122
rec = se.BinDatRecordingExtractor(raw_filename, info['sample_rate'], info['num_chan'],
125123
info['dtype'], time_axis=info['time_axis'])
126-
# rec = rec.load_probe_file(prb_filename)
127124
load_probe_file_inplace(rec, prb_filename)
128125

129126
return rec
@@ -250,8 +247,7 @@ def copy_sortings_to_npz(study_folder):
250247
sorter_folders = study_folder / 'sorter_folders'
251248
sorting_folders = study_folder / 'sortings'
252249

253-
if not os.path.exists(sorting_folders):
254-
os.makedirs(str(sorting_folders))
250+
sorting_folders.mkdir(parents=True, exist_ok=True)
255251

256252
for rec_name, sorter_name, output_folder in iter_output_folders(sorter_folders):
257253
SorterClass = sorter_dict[sorter_name]
@@ -261,13 +257,12 @@ def copy_sortings_to_npz(study_folder):
261257
sorting = SorterClass.get_result_from_folder(output_folder)
262258
se.NpzSortingExtractor.write_sorting(sorting, npz_filename)
263259
except:
264-
if os.path.exists(npz_filename):
265-
os.remove(npz_filename)
266-
if os.path.exists(output_folder / 'spikeinterface_log.json'):
260+
if npz_filename.is_file():
261+
npz_filename.unlink()
262+
if (output_folder / 'spikeinterface_log.json').is_file():
267263
shutil.copyfile(output_folder / 'spikeinterface_log.json', sorting_folders / 'run_log' / (fname + '.json'))
268264

269265

270-
271266
def iter_computed_names(study_folder):
272267
sorting_folder = Path(study_folder) / 'sortings'
273268
for filename in os.listdir(sorting_folder):
@@ -299,8 +294,7 @@ def collect_run_times(study_folder):
299294
log_folder = sorting_folders / 'run_log'
300295
tables_folder = study_folder / 'tables'
301296

302-
if not os.path.exists(tables_folder):
303-
os.makedirs(str(tables_folder))
297+
tables_folder.mkdir(parents=True, exist_ok=True)
304298

305299
run_times = []
306300
for filename in os.listdir(log_folder):
@@ -441,8 +435,6 @@ def aggregate_performances_table(study_folder, exhaustive_gt=False, **karg_thres
441435
return dataframes
442436

443437

444-
445-
446438
def load_probe_file_inplace(recording, probe_file):
447439
'''
448440
This is a locally modified version of spikeextractor.extraction_tools.load_probe_file.

0 commit comments

Comments
 (0)