"""
Replaces strings in scrubbed text files from a CSV or TSV.
"""

import argparse
import csv
import os
import sys
import traceback
from concurrent.futures import ProcessPoolExecutor
from typing import List, Optional

from vmrt_tesseract_utilities.database import (TranscriptionInput,
                                               TranscriptionOutput,
                                               get_database_session)
from vmrt_tesseract_utilities.logging import stdout_logger
from vmrt_tesseract_utilities.string_replacer import StringReplacer

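# Example command-line invocation (illustrative only; the data file, column
# name, paths, and script name are hypothetical and depend on how this module
# is saved and deployed):
#
#     python replace_from_list.py subjects.tsv name "[REDACTED]" /data/output \
#         --document_type page --chunk_size 500 --max-workers 8
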
def read_target_strings(data_file: str, key_column: str) -> List[str]:
    """
    Reads target strings from a CSV or TSV file.

    Parameters
    ----------
    data_file : str
        The path to the CSV or TSV file.
    key_column : str
        The name of the column containing the target strings.

    Returns
    -------
    list
        A list of strings extracted from the specified column.

    Raises
    ------
    ValueError
        If the provided file is not a CSV or TSV, or if the specified key column is not found.
    FileNotFoundError
        If the specified data file does not exist.
    """
    strings = []
    # Detect file type based on extension
    if data_file.endswith(".csv"):
        delimiter = ","
    elif data_file.endswith(".tsv"):
        delimiter = "\t"
    else:
        raise ValueError("Invalid data file format. The data_file must be a CSV or TSV.")
    try:
        # Read the target strings from the specified column.
        with open(data_file, "r", newline="") as csvfile:
            reader = csv.DictReader(csvfile, delimiter=delimiter)
            if not reader.fieldnames or key_column not in reader.fieldnames:
                raise ValueError(f"Key column '{key_column}' not found in data file.")
            strings.extend([row[key_column] for row in reader])
        return strings
    except FileNotFoundError as e:
        stdout_logger.error(f"Data file not found: {e}")
        raise
    except Exception as e:
        stdout_logger.error(f"An error occurred in read_target_strings: {e}\n{traceback.format_exc()}")
        raise
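
# Illustrative usage of read_target_strings (the file and column names are
# hypothetical): given a TSV whose "name" column lists the strings to remove,
#
#     targets = read_target_strings("subjects.tsv", "name")
#
# returns every value in that column as a list, e.g. ["John Doe", "Jane Roe"].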


def process_file(output_log, strings_to_replace: List[str], parsed_args: argparse.Namespace) -> Optional[object]:
    """
    Processes a single file, replacing text and writing outputs.

    Parameters
    ----------
    output_log : TranscriptionOutput
        An object containing file data.
    strings_to_replace : list
        The list of strings to search for.
    parsed_args : argparse.Namespace
        The parsed args.

    Returns
    -------
    TranscriptionOutput
        The updated file data object.
    """
    input_file = None
    try:
        # Re-scrub an existing replacement output in place if one exists;
        # otherwise derive a new output path from the OCR output file.
        if hasattr(output_log, 'list_replacement_output_file') and output_log.list_replacement_output_file:
            input_file = output_log.list_replacement_output_file
            output_file = input_file
        else:
            input_file = output_log.ocr_output_file
            input_filename = os.path.basename(str(input_file))
            filename_without_extension = os.path.splitext(input_filename)[0]
            output_dir = f"{parsed_args.output_dir}/list_replacement_output_file/{parsed_args.document_type}"
            os.makedirs(output_dir, exist_ok=True)
            output_file = f"{output_dir}/{filename_without_extension}.txt"
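            # For illustration (hypothetical values): an OCR output file named
            # "report_001.txt" with --output_dir /data/out and
            # --document_type page is written to
            # /data/out/list_replacement_output_file/page/report_001.txt.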

        with open(str(input_file), "r") as f:
            orig_text = f.read()

        replacer = StringReplacer(strings_to_replace, parsed_args.replacement_string)
        scrubbed_text = replacer.replace(orig_text)

        with open(output_file, "w") as outfile:
            outfile.write(scrubbed_text)
        stdout_logger.info(f"Scrubbed file written to {output_file}")

        output_log.list_replacement_output_file = output_file
        return output_log
    except Exception as e:
        stdout_logger.error(f"An error occurred while processing {input_file}: {e}\n{traceback.format_exc()}")
        return None


def scrub_and_write_files(process_filepath_data: List[object], strings_to_replace: List[str],
                          parsed_args: argparse.Namespace, use_multiprocessing: bool = True) -> None:
    """
    Processes a list of files, replacing text and writing outputs, optionally in parallel.

    Parameters
    ----------
    process_filepath_data : list
        A list of file data objects.
    strings_to_replace : list
        The list of strings to search for.
    parsed_args : argparse.Namespace
        The parsed args.
    use_multiprocessing : bool
        Whether to use multiprocessing or not.
    """
    session_maker = get_database_session(echo=parsed_args.debug_sql)
    batch_size = parsed_args.chunk_size
    success = True
    # Work through the records in batches of chunk_size, one session per batch.
    for i in range(0, len(process_filepath_data), batch_size):
        batch = process_filepath_data[i:i + batch_size]
        with session_maker.begin() as session:
            if use_multiprocessing:
                with ProcessPoolExecutor(max_workers=parsed_args.max_workers) as executor:
                    results = executor.map(process_file_wrapper, batch, [strings_to_replace] * len(batch), [parsed_args] * len(batch))
                    for output_log in results:
                        if output_log:
                            session.add(output_log)
                        else:
                            success = False
            else:
                for output_log in batch:
                    result = process_file_wrapper(output_log, strings_to_replace, parsed_args)
                    if result:
                        session.add(result)
                    else:
                        success = False
    if not success:
        # If any files failed to process, exit with an error code.
        sys.exit(1)


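# Design note: scrub_and_write_files hands work to the ProcessPoolExecutor via
# this wrapper rather than calling process_file directly, and the wrapper opens
# its own database session. SQLAlchemy sessions and their underlying
# connections are not safely shared across processes, so (presumably for that
# reason) each worker re-creates one here.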
def process_file_wrapper(output_log: object, strings_to_replace: List[str], parsed_args: argparse.Namespace) -> Optional[object]:
    """
    Wrapper function to create a new session for each process.

    Parameters
    ----------
    output_log : TranscriptionOutput
        An object containing file data.
    strings_to_replace : list
        The list of strings to search for.
    parsed_args : argparse.Namespace
        The parsed args.

    Returns
    -------
    TranscriptionOutput
        The updated file data object.
    """
    session_maker = get_database_session(echo=parsed_args.debug_sql)
    with session_maker.begin() as session:
        # Attach the (detached) ORM object to a session owned by this process.
        session.add(output_log)
    return process_file(output_log, strings_to_replace, parsed_args)


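# Note: the query below returns at most --chunk_size rows starting at --offset,
# so covering a large table presumably requires repeated runs with increasing
# offsets.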
def get_files_to_process(args: argparse.Namespace) -> list:
    """
    Gets a list of input files to process.

    Parameters
    ----------
    args : argparse.Namespace
        The parsed args.

    Returns
    -------
    results : list
        The list of input files.
    """
    session_maker = get_database_session(echo=args.debug_sql)
    with session_maker.begin() as session:
        query = (session.query(TranscriptionOutput)
                 .outerjoin(TranscriptionInput.assets)
                 .where(TranscriptionInput.document_type == args.document_type)
                 .where(TranscriptionOutput.ocr_output_file != None)  # noqa: E711
                 .limit(args.chunk_size)
                 .offset(args.offset))
        return query.all()


def parse_args() -> argparse.Namespace:
    """
    Parses the required args.

    Returns
    -------
    args : argparse.Namespace
        The parsed args.
    """
    parser = argparse.ArgumentParser(description="Replace strings in scrubbed text files from a CSV or TSV.")
    parser.add_argument("data_file", help="Path to the CSV or TSV file containing the strings.")
    parser.add_argument("key_column", help="Name of the column in the CSV/TSV containing the keys.")
    parser.add_argument("replacement_string", help="The string to replace the keys with.")
    parser.add_argument("output_dir", help="Path to the output directory.")
    parser.add_argument("--document_type", type=str, default="document", help="The document type to produce: document, page or block.")
    parser.add_argument("--chunk_size", type=int, default=1000, help="The number of records to process.")
    parser.add_argument("--offset", type=int, default=0, help="The number of records to skip before beginning processing.")
    parser.add_argument("--debug-sql", action="store_true", help="Enable SQL debugging.")
    parser.add_argument("--no-multiprocessing", action="store_true", help="Disable multiprocessing for debugging.")
    parser.add_argument("--max-workers", type=int, default=4, help="Maximum number of worker processes for multiprocessing.")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    try:
        results = get_files_to_process(args)
        target_strings = read_target_strings(args.data_file, args.key_column)
        scrub_and_write_files(results, target_strings, args, not args.no_multiprocessing)
    except Exception as e:
        stdout_logger.error(f"Error in main execution: {e}\n{traceback.format_exc()}")
        # If an error occurs, exit with an error code.
        sys.exit(1)