Skip to content

Commit f1c405e

Browse files
authored
Merge pull request #27 from remoteoss/add-files-to-custom
fix: add files to custom (add an optional `files` parameter to `copy_custom` and `merge_custom`)
2 parents eab26bb + cd78b7f commit f1c405e

File tree

2 files changed

+92
-2
lines changed

2 files changed

+92
-2
lines changed

snowflake_utils/models/table.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,6 @@ def copy_into(
200200
stage: str | None = None,
201201
files: list[str] | None = None,
202202
) -> None:
203-
if stage:
204-
path = f"@{stage}/{path}"
205203
col_str = f"({', '.join(target_columns)})" if target_columns else ""
206204
files_clause = ""
207205
if files:
@@ -226,6 +224,7 @@ def copy_into(
226224
storage_integration,
227225
full_refresh,
228226
sync_tags,
227+
stage,
229228
)
230229
with connect() as connection:
231230
cursor = connection.cursor()
@@ -244,6 +243,7 @@ def copy_into(
244243
storage_integration,
245244
full_refresh,
246245
sync_tags,
246+
stage,
247247
)
248248

249249
def create_table(self, full_refresh: bool, execute_statement: callable) -> None:
@@ -560,15 +560,23 @@ def copy_custom(
560560
full_refresh: bool = False,
561561
sync_tags: bool = False,
562562
stage: str | None = None,
563+
files: list[str] | None = None,
563564
) -> None:
564565
column_names = ", ".join(column_definitions.keys())
565566
definitions = ", ".join(column_definitions.values())
566567

568+
files_clause = ""
569+
if files:
570+
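# Build the Snowflake FILES clause as a single-quoted, comma-separated list, e.g. FILES = ('a.parquet', 'b.parquet'); file names are inserted as-is (no quote escaping)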
571+
files_str = "', '".join(files)
572+
files_clause = f"FILES = ('{files_str}')"
573+
567574
query = f"""
568575
COPY INTO {self.fqn} ({column_names})
569576
FROM
570577
(select {definitions} from @{stage}/{path})
571578
FILE_FORMAT = ( FORMAT_NAME ='{{file_format}}')
579+
{files_clause}
572580
"""
573581
return self._copy(
574582
query,
@@ -589,6 +597,7 @@ def merge_custom(
589597
replication_keys: list[str] | None = None,
590598
storage_integration: str | None = None,
591599
qualify: bool = False,
600+
files: list[str] | None = None,
592601
) -> None:
593602
def copy_callable(table: Table, sync_tags: bool) -> None:
594603
return table.copy_custom(
@@ -598,6 +607,7 @@ def copy_callable(table: Table, sync_tags: bool) -> None:
598607
file_format=file_format,
599608
full_refresh=True,
600609
sync_tags=sync_tags,
610+
files=files,
601611
)
602612

603613
return self._merge(copy_callable, primary_keys, replication_keys, qualify)

tests/test_models.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,86 @@ def test_merge_custom(mock_connect, mock_merge, mock_copy, mock_drop):
566566
mock_drop.assert_called_once()
567567

568568

569+
@patch.object(Table, "_copy")
570+
def test_copy_custom_with_files(mock_copy) -> None:
571+
"""Test copy_custom method with specific files parameter."""
572+
# Set up the mock to return the expected result
573+
mock_copy.return_value = [("test_file.parquet", "LOADED")]
574+
575+
# Create a table instance
576+
table = Table(name="TEST_TABLE", schema_name="TEST_SCHEMA")
577+
578+
# Define column definitions
579+
column_definitions = {
580+
"id": "$1:id",
581+
"name": "$1:name",
582+
"last_name": "$1:last_name",
583+
}
584+
585+
# Test with files parameter
586+
result = table.copy_custom(
587+
column_definitions=column_definitions,
588+
path="s3://test-bucket/path",
589+
file_format=parquet_file_format,
590+
files=["test_file.parquet", "another_file.parquet"],
591+
)
592+
593+
# Verify the result
594+
assert result[0][1] == "LOADED"
595+
596+
# Verify the _copy method was called with the correct query containing FILES clause
597+
mock_copy.assert_called()
598+
call_args = mock_copy.call_args
599+
query = call_args[0][0] # First positional argument is the query
600+
assert "FILES = ('test_file.parquet', 'another_file.parquet')" in query
601+
602+
603+
@patch.object(Table, "_merge")
604+
def test_merge_custom_with_files(mock_merge) -> None:
605+
"""Test merge_custom method with specific files parameter."""
606+
# Create a table instance
607+
table = Table(name="TEST_TABLE", schema_name="TEST_SCHEMA")
608+
609+
# Define column definitions
610+
column_definitions = {
611+
"id": "$1:id",
612+
"name": "$1:name",
613+
"last_name": "$1:last_name",
614+
}
615+
616+
# Test merge_custom with files parameter
617+
table.merge_custom(
618+
column_definitions=column_definitions,
619+
path="s3://test-bucket/path",
620+
file_format=parquet_file_format,
621+
primary_keys=["id"],
622+
files=["test_file.parquet"],
623+
)
624+
625+
# Verify the _merge method was called
626+
mock_merge.assert_called_once()
627+
628+
# Get the copy_callable that was passed to _merge
629+
call_args = mock_merge.call_args
630+
copy_callable = call_args[0][0] # First positional argument is the copy_callable
631+
632+
# Now test the copy_callable to verify it passes the files parameter correctly
633+
with patch.object(Table, "_copy") as mock_copy:
634+
mock_copy.return_value = [("test_file.parquet", "LOADED")]
635+
636+
# Create a temporary table to test the copy_callable
637+
temp_table = Table(name="TEMP_TABLE", schema_name="TEST_SCHEMA")
638+
639+
# Call the copy_callable (this simulates what happens inside _merge)
640+
copy_callable(temp_table, sync_tags=False)
641+
642+
# Verify the _copy method was called with the correct query containing FILES clause
643+
mock_copy.assert_called()
644+
call_args = mock_copy.call_args
645+
query = call_args[0][0] # First positional argument is the query
646+
assert "FILES = ('test_file.parquet')" in query
647+
648+
569649
@patch.object(Table, "_copy")
570650
def test_copy_with_files(mock_copy) -> None:
571651
"""Test copy_into method with specific files parameter."""

0 commit comments

Comments
 (0)