
Commit 9839b98

6 - Allow for stripping from static name lists (#10)
* 6 - initial work
* 6 - add tests and cleanup.
* 6 - make the replacer case insensitive.
* 6 - throw an exit code of 1 if any of the processes fail and throw an exception if the data source is not a csv or tsv.
1 parent dd463fa commit 9839b98

File tree: 9 files changed (+522, −58 lines)

README.md

Lines changed: 4 additions & 2 deletions
@@ -22,7 +22,9 @@ The scripts are easily run via the Dockerfile included in this repo.
 4. Get ready for the transcription process by running `python scripts/create_transcription_process.py /data`
 5. Use the `transcribe_pdfs.py` script to transcribe the files needed.
    - `python /workspace/scripts/transcribe_pdfs.py /workspace/output`
-6. Use the `pii_scrubber.py` script to remove PII from the text.
+6. Use the `/workspace/scripts/replace_strings.py` script to replace text in the files.
+   - `python /workspace/scripts/replace_strings.py /path/to/dog_profile.tsv "subject_id" "<ID>" /workspace/output`
+7. Use the `pii_scrubber.py` script to remove PII from the text.
    - `python /workspace/scripts/scrubbers/pii_scrubber.py /workspace/output`
-7. Use the scripts in the `scripts/metadata_miners` directory to find data in the text.
+8. Use the scripts in the `scripts/metadata_miners` directory to find data in the text.
    - `python /workspace/scripts/metadata_miners/visit_date_miner.py /workspace/output --visit_date_tsv=/path/to/vet_visits.tsv --dog_profile_tsv=/path/to/dog_profile.tsv`
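For context on the new step 6: `replace_strings.py` reads every value in the named column of the CSV/TSV and replaces each occurrence of those values in the transcribed text with the replacement string (matching is case insensitive, per the commit message). A quick sketch with a hypothetical `dog_profile.tsv` (tab-separated):

```
subject_id	name	breed
D-1001	Rex	Beagle
D-1002	Fido	Terrier
```

Running the command above against this file would rewrite a transcribed line such as `Patient D-1001 returned for a follow-up.` to `Patient <ID> returned for a follow-up.`; a second pass with `"name"` as the key column could strip the dog names the same way.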

scripts/metadata_miners/visit_date_miner.py

Lines changed: 2 additions & 4 deletions
@@ -49,8 +49,7 @@ def get_date_pairs_within_days(
     return result_pairs


-def get_dog_dates(parsed_args: argparse.Namespace, subject_id: str) -> Tuple[
-        Optional[datetime.date], Optional[datetime.date]]:
+def get_dog_dates(parsed_args: argparse.Namespace, subject_id: str) -> Tuple[Optional[datetime.date], Optional[datetime.date]]:
     """
     Retrieves the dog's birth and death dates from the TSV files.

@@ -136,8 +135,7 @@ def update_existing_records(session: sqlalchemy.orm.session.Session, subject_id:


 def get_existing_date_pairs(session: sqlalchemy.orm.session.Session, subject_id: str, input_id: int,
-                            date_pairs: set[tuple[datetime, datetime]]) -> Set[
-                                Tuple[datetime.date, datetime.date]]:
+                            date_pairs: set[tuple[datetime, datetime]]) -> Set[Tuple[datetime.date, datetime.date]]:
     """
     Retrieves existing date pairs from the database.

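One thing the joined signatures make visible is the file's mixed typing styles: the parameter uses the builtin-generic `set[tuple[datetime, datetime]]` while the return type uses `typing.Set`/`typing.Tuple`. Should the project ever standardize (a hypothetical cleanup, not part of this commit), the modern equivalent on Python 3.10+ would be:

```python
import argparse
import datetime

# Builtin generics (Python 3.9+) and unions (3.10+) in place of Tuple/Optional:
def get_dog_dates(parsed_args: argparse.Namespace, subject_id: str) -> tuple[datetime.date | None, datetime.date | None]:
    ...
```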

scripts/replace_strings.py

Lines changed: 237 additions & 0 deletions
import argparse
import csv
import os
import sys
import traceback
from concurrent.futures import ProcessPoolExecutor
from typing import List, Optional

from vmrt_tesseract_utilities.database import (TranscriptionInput,
                                               TranscriptionOutput,
                                               get_database_session)
from vmrt_tesseract_utilities.logging import stdout_logger
from vmrt_tesseract_utilities.string_replacer import StringReplacer

"""
Replaces strings in scrubbed text files from a CSV or TSV.
"""


def read_target_strings(data_file: str, key_column: str) -> List[str]:
    """
    Reads target strings from a CSV or TSV file.

    Parameters
    ----------
    data_file : str
        The path to the CSV or TSV file.
    key_column : str
        The name of the column containing the target strings.

    Returns
    -------
    list
        A list of strings extracted from the specified column.

    Raises
    ------
    ValueError
        If the provided file is not a CSV or TSV, or if the specified key column is not found.
    FileNotFoundError
        If the specified data file does not exist.
    """
    strings = []
    # Detect file type based on extension
    if data_file.endswith(".csv"):
        delimiter = ","
    elif data_file.endswith(".tsv"):
        delimiter = "\t"
    else:
        raise ValueError("Invalid data file format. The data_file must be a CSV or TSV.")
    try:
        # Read the target strings from the specified column.
        with open(data_file, "r", newline="") as csvfile:
            reader = csv.DictReader(csvfile, delimiter=delimiter)
            if key_column not in reader.fieldnames:
                raise ValueError(f"Key column '{key_column}' not found in data file.")
            strings.extend([row[key_column] for row in reader])
            return strings
    except FileNotFoundError as e:
        stdout_logger.error(f"Data file not found: {e}")
        raise
    except Exception as e:
        stdout_logger.error(f"An error occurred in read_target_strings: {e}\n{traceback.format_exc()}")
        raise

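A usage sketch for `read_target_strings` (the file and column names here are hypothetical): `csv.DictReader` keys each row by the header line, so only the requested column's values are collected, in file order.

```python
# Given a names.tsv containing (tab-separated):
#   subject_id	name
#   D-1001	Rex
#   D-1002	Fido
targets = read_target_strings("names.tsv", "subject_id")
print(targets)  # ['D-1001', 'D-1002']
```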
def process_file(output_log, strings_to_replace: List[str], parsed_args: argparse.Namespace) -> Optional[object]:
    """
    Processes a single file, replacing text and writing outputs.

    Parameters
    ----------
    output_log : TranscriptionOutput
        An object containing file data.
    strings_to_replace : list
        The list of strings to search for.
    parsed_args : argparse.Namespace
        The parsed args.

    Returns
    -------
    TranscriptionOutput
        The updated file data object.
    """
    input_file = None
    try:
        if hasattr(output_log, 'list_replacement_output_file') and output_log.list_replacement_output_file:
            input_file = output_log.list_replacement_output_file
            output_file = input_file
        else:
            input_file = output_log.ocr_output_file
            input_filename = os.path.basename(str(input_file))
            filename_without_extension = os.path.splitext(input_filename)[0]
            output_dir = f"{parsed_args.output_dir}/list_replacement_output_file/{parsed_args.document_type}"
            os.makedirs(output_dir, exist_ok=True)
            output_file = f"{output_dir}/{filename_without_extension}.txt"

        with open(str(input_file), "r") as f:
            orig_text = f.read()

        replacer = StringReplacer(strings_to_replace, parsed_args.replacement_string)
        scrubbed_text = replacer.replace(orig_text)

        with open(output_file, "w") as outfile:
            outfile.write(scrubbed_text)
        stdout_logger.info(f"Scrubbed file written to {output_file}")

        output_log.list_replacement_output_file = output_file
        return output_log
    except Exception as e:
        stdout_logger.error(f"An error occurred while processing {input_file}: {e}\n{traceback.format_exc()}")
        return None

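`process_file` delegates the actual matching to `StringReplacer`, which lives in `vmrt_tesseract_utilities` and is not shown in this diff. Going only by the commit message ("make the replacer case insensitive"), its behavior is presumably along these lines; this is an illustrative stand-in, not the project's implementation:

```python
import re
from typing import List


class CaseInsensitiveReplacer:
    """Replaces any of the target strings with one replacement, ignoring case."""

    def __init__(self, targets: List[str], replacement: str) -> None:
        # Assumes a non-empty target list. Escape targets so they match
        # literally; longest first so e.g. "Rexford" is not half-replaced
        # by a shorter target "Rex".
        escaped = sorted((re.escape(t) for t in targets if t), key=len, reverse=True)
        self._pattern = re.compile("|".join(escaped), re.IGNORECASE)
        self._replacement = replacement

    def replace(self, text: str) -> str:
        return self._pattern.sub(self._replacement, text)


print(CaseInsensitiveReplacer(["Rex"], "<ID>").replace("REX saw rex"))  # <ID> saw <ID>
```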
def scrub_and_write_files(process_filepath_data: List[object], strings_to_replace: List[str], parsed_args: argparse.Namespace, use_multiprocessing: bool = True) -> None:
    """
    Processes a list of files, replacing text and writing outputs in parallel.

    Parameters
    ----------
    process_filepath_data : list
        A list of file data objects.
    strings_to_replace : list
        The list of strings to search for.
    parsed_args : argparse.Namespace
        The parsed args.
    use_multiprocessing : bool
        Whether to use multiprocessing or not.
    """
    session_maker = get_database_session(echo=parsed_args.debug_sql)
    batch_size = parsed_args.chunk_size
    success = True
    for i in range(0, len(process_filepath_data), batch_size):
        batch = process_filepath_data[i:i + batch_size]
        with session_maker.begin() as session:
            if use_multiprocessing:
                with ProcessPoolExecutor(max_workers=parsed_args.max_workers) as executor:
                    results = executor.map(process_file_wrapper, batch, [strings_to_replace] * len(batch), [parsed_args] * len(batch))
                    for output_log in results:
                        if output_log:
                            session.add(output_log)
                        else:
                            success = False
            else:
                for output_log in batch:
                    result = process_file_wrapper(output_log, strings_to_replace, parsed_args)
                    if result:
                        session.add(result)
                    else:
                        success = False
    if not success:
        # If any files failed to process, exit with an error code.
        sys.exit(1)

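A note on the `executor.map` call above: `map` takes one iterable per positional parameter of the target function and zips them, which is why the constant arguments are tiled with `[...] * len(batch)`. A standalone illustration of that calling convention:

```python
from concurrent.futures import ProcessPoolExecutor


def tag(item: str, strings: list, options: str) -> str:
    return f"{item} <- {strings} ({options})"


if __name__ == "__main__":
    batch = ["a.txt", "b.txt"]
    with ProcessPoolExecutor(max_workers=2) as executor:
        # Zipped calls: tag("a.txt", ["Rex"], "opts"), tag("b.txt", ["Rex"], "opts")
        for line in executor.map(tag, batch, [["Rex"]] * len(batch), ["opts"] * len(batch)):
            print(line)
```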
def process_file_wrapper(output_log: object, strings_to_replace: List[str], parsed_args: argparse.Namespace) -> Optional[object]:
    """
    Wrapper function to create a new session for each process.

    Parameters
    ----------
    output_log : TranscriptionOutput
        An object containing file data.
    strings_to_replace : list
        The list of strings to search for.
    parsed_args : argparse.Namespace
        The parsed args.

    Returns
    -------
    TranscriptionOutput
        The updated file data object.
    """
    session_maker = get_database_session(echo=parsed_args.debug_sql)
    with session_maker.begin() as session:
        session.add(output_log)
        return process_file(output_log, strings_to_replace, parsed_args)

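`process_file_wrapper` gives each worker its own session because SQLAlchemy engines and sessions are not safe to share across process boundaries. The general shape of that pattern (SQLAlchemy 1.4+), sketched against a throwaway SQLite database; the URL and the body are placeholders, not this project's code:

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker


def work(record_id: int) -> int:
    # Build the engine and session inside the worker, never before the fork.
    engine = create_engine("sqlite:///example.db")  # placeholder URL
    session_factory = sessionmaker(bind=engine)
    with session_factory.begin() as session:  # commits on exit, rolls back on error
        pass  # load and update the record here
    return record_id
```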
def get_files_to_process(args: argparse.Namespace) -> list:
    """
    Gets a list of input files to process.

    Parameters
    ----------
    args: argparse.Namespace
        The parsed args.

    Returns
    -------
    results: list
        The list of input files.
    """
    sessionmaker = get_database_session(echo=args.debug_sql)
    with sessionmaker.begin() as session:
        query = (session.query(TranscriptionOutput)
                 .outerjoin(TranscriptionInput.assets)
                 .where(TranscriptionInput.document_type == args.document_type)
                 .where(TranscriptionOutput.ocr_output_file != None)  # noqa: E711
                 .limit(args.chunk_size)
                 .offset(args.offset))
        return query.all()

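The `# noqa: E711` comment silences flake8, which flags `!= None`; the comparison is intentional here, since SQLAlchemy overloads it to emit `IS NOT NULL`. On SQLAlchemy 1.4+ the same filter can be written without the noqa, e.g.:

```python
from sqlalchemy import select

from vmrt_tesseract_utilities.database import TranscriptionOutput

# is_not() renders the same IS NOT NULL SQL as `!= None`, minus the lint warning.
stmt = select(TranscriptionOutput).where(
    TranscriptionOutput.ocr_output_file.is_not(None))
```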
def parse_args() -> argparse.Namespace:
    """
    Parses the required args.

    Returns
    -------
    args : argparse.Namespace
        The parsed args.
    """
    parser = argparse.ArgumentParser(description="Replace strings in scrubbed text files from a CSV or TSV.")
    parser.add_argument("data_file", help="Path to the CSV or TSV file containing the strings.")
    parser.add_argument("key_column", help="Name of the column in the CSV/TSV containing the keys.")
    parser.add_argument("replacement_string", help="The string to replace the keys with.")
    parser.add_argument("output_dir", help="Path to the output directory.")
    parser.add_argument("--document_type", type=str, default="document", help="The document type we want to produce: document, page or block.")
    parser.add_argument("--chunk_size", type=int, default=1000, help="The number of records to process.")
    parser.add_argument("--offset", type=int, default=0, help="The number of records to skip before beginning processing.")
    parser.add_argument("--debug-sql", action="store_true", help="Enable SQL debugging")
    parser.add_argument("--no-multiprocessing", action="store_true", help="Disable multiprocessing for debugging")
    parser.add_argument("--max-workers", type=int, default=4, help="Maximum number of worker processes for multiprocessing")
    return parser.parse_args()

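Putting the arguments together, a typical invocation (using the README's example paths) would be:

- `python /workspace/scripts/replace_strings.py /path/to/dog_profile.tsv "subject_id" "<ID>" /workspace/output --document_type document --max-workers 4`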
if __name__ == "__main__":
    args = parse_args()
    try:
        results = get_files_to_process(args)
        target_strings = read_target_strings(args.data_file, args.key_column)
        scrub_and_write_files(results, target_strings, args, not args.no_multiprocessing)
    except Exception as e:
        stdout_logger.error(f"Error in main execution: {e}\n{traceback.format_exc()}")
        # If an error occurs, exit with an error code.
        sys.exit(1)

scripts/scrubbers/pii_scrubber.py

Lines changed: 21 additions & 51 deletions
@@ -10,6 +10,7 @@
                                                TranscriptionOutput,
                                                get_database_session)
 from vmrt_tesseract_utilities.logging import stdout_logger
+from vmrt_tesseract_utilities.scrubbing_utils import write_scrubbed_txt

 """
 Leverages presidio to attempt automatic PII stripping.
@@ -83,26 +84,6 @@ def scrub_pii(text: str, analyzer: AnalyzerEngine, threshold: float) -> tuple[st
         raise


-def write_scrubbed_txt(output_filename: str, anonymized_text: str) -> None:
-    """
-    Writes the anonymized text to an output file.
-
-    Parameters
-    ----------
-    output_filename : str
-        The path to the file.
-    anonymized_text : str
-        The anonymized text.
-    """
-    try:
-        if anonymized_text:
-            with open(output_filename, 'w') as f:
-                f.write(anonymized_text)
-    except Exception as e:
-        stdout_logger.error(f'Error writing scrubbed output: {e}')
-        raise
-
-
 def write_confidence_record(filename: str, filtered_results: list, original_text: str) -> None:
     """
     Writes the filtered results to a JSON file.
@@ -133,31 +114,6 @@ def write_confidence_record(filename: str, filtered_results: list, original_text
         raise


-def get_output_strategy_from_path(file_path: str) -> str:
-    """
-    Determines the type of path based on path segments.
-
-    Parameters
-    ----------
-    file_path : str
-        The file_path to check.
-
-    Returns
-    -------
-    str
-        The extracted path type.
-    """
-    parts = file_path.split(os.sep)
-    if 'doc' in parts:
-        return 'doc'
-    elif 'page' in parts:
-        return 'page'
-    elif "unstructured_text" in parts:
-        return parts[parts.index("unstructured_text") + 1]
-    else:
-        return 'page'  # Default to page
-
-
 def process_files(process_filepath_data: list, analyzer: AnalyzerEngine,
                   output_dir: str, threshold: float) -> None:
     """
@@ -175,24 +131,38 @@
         The confidence threshold for PII detection.
     """
     sessionmaker = get_database_session(echo=args.debug_sql)
-    with sessionmaker.begin() as session:
+    with sessionmaker() as session:
         for output_log in process_filepath_data:
-            with open(str(output_log.ocr_output_file), 'r') as f:
+            session.add(output_log)  # Ensure the instance is bound to the session
+            # Use the replacement file if it exists, otherwise use the OCR output.
+            if hasattr(output_log, 'list_replacement_output_file') and output_log.list_replacement_output_file:
+                input_filepath = output_log.list_replacement_output_file
+            else:
+                input_filepath = output_log.ocr_output_file
+            # Read the original text from the input file.
+            with open(str(input_filepath), 'r') as f:
                 orig_text = f.read()
+            stdout_logger.info(f"Scrubbing {input_filepath}")
+            # Scrub the PII from the text.
             scrubbed_text, result_output = scrub_pii(orig_text, analyzer, threshold)
-            input_filename = os.path.basename(str(output_log.ocr_output_file))
+            # Write the scrubbed text to a file.
+            input_filename = os.path.basename(str(input_filepath))
             filename_without_extension = os.path.splitext(input_filename)[0]
             scrubbed_dir = f'{output_dir}/scrubbed_text/{args.document_type}/scrubbed_{args.document_type}'
             os.makedirs(scrubbed_dir, exist_ok=True)  # Create directory if needed
             output_file = f'{scrubbed_dir}/{filename_without_extension}.txt'
             write_scrubbed_txt(output_file, scrubbed_text)
+            stdout_logger.info(f"Scrubbed file written to {output_file}")
+            # Write the scrubbed confidence data to a file.
             output_log.pii_scrubber_output_file = output_file
             confidence_dir = f'{output_dir}/scrubbed_text/{args.document_type}/scrubbed_confidence'
             os.makedirs(confidence_dir, exist_ok=True)
             confidence_file = f'{confidence_dir}/confidence-{filename_without_extension}.json'
             output_log.pii_scrubber_confidence_file = confidence_file
             write_confidence_record(confidence_file, result_output, orig_text)
+            # Log the changes to the database.
             session.add(output_log)
+            session.commit()


 def get_files_to_process(args: argparse.Namespace) -> list:
@@ -214,8 +184,8 @@ def get_files_to_process(args: argparse.Namespace) -> list:
         query = (session.query(TranscriptionOutput)
                  .outerjoin(TranscriptionInput.assets)
                  .where(TranscriptionInput.document_type == args.document_type)
-                 .where(TranscriptionOutput.ocr_output_file != None)
-                 .where(TranscriptionOutput.pii_scrubber_output_file == None)
+                 .where(TranscriptionOutput.ocr_output_file != None)  # noqa: E711
+                 .where(TranscriptionOutput.pii_scrubber_output_file == None)  # noqa: E711
                  .limit(args.chunk_size)
                  .offset(args.offset))
         return query.all()
@@ -252,4 +222,4 @@ def parse_args() -> argparse.Namespace:
     args = parse_args()
     db_output_logs = get_files_to_process(args)
     nlp_engine = create_nlp_engine(args.config)
-    process_files(db_output_logs, nlp_engine, args.output_to, args.threshold)
+    process_files(db_output_logs, nlp_engine, args.output_to, args.threshold)
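The switch in `process_files` from `sessionmaker.begin()` to a plain `sessionmaker()` call, plus the explicit `session.commit()` inside the loop, changes the transaction granularity: one failure no longer rolls back every file already scrubbed in the run. A minimal sketch of the two patterns (in-memory SQLite, illustration only):

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

session_factory = sessionmaker(bind=create_engine("sqlite://"))

# Before: begin() wraps the whole block in one transaction, committed on exit.
with session_factory.begin() as session:
    pass  # session.add(...) here; everything commits (or rolls back) together

# After: a plain session with explicit commits, persisting each record as it finishes.
with session_factory() as session:
    pass  # session.add(...) then session.commit() per record
```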
