Skip to content

Commit dc6c558

Browse files
committed
feat: add files clause
1 parent eab26bb commit dc6c558

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

snowflake_utils/models/table.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,15 +560,23 @@ def copy_custom(
560560
full_refresh: bool = False,
561561
sync_tags: bool = False,
562562
stage: str | None = None,
563+
files: list[str] | None = None,
563564
) -> None:
564565
column_names = ", ".join(column_definitions.keys())
565566
definitions = ", ".join(column_definitions.values())
566567

568+
files_clause = ""
569+
if files:
570+
# Format files list properly for Snowflake FILES clause
571+
files_str = "', '".join(files)
572+
files_clause = f"FILES = ('{files_str}')"
573+
567574
query = f"""
568575
COPY INTO {self.fqn} ({column_names})
569576
FROM
570577
(select {definitions} from @{stage}/{path})
571578
FILE_FORMAT = ( FORMAT_NAME ='{{file_format}}')
579+
{files_clause}
572580
"""
573581
return self._copy(
574582
query,
@@ -589,6 +597,7 @@ def merge_custom(
589597
replication_keys: list[str] | None = None,
590598
storage_integration: str | None = None,
591599
qualify: bool = False,
600+
files: list[str] | None = None,
592601
) -> None:
593602
def copy_callable(table: Table, sync_tags: bool) -> None:
594603
return table.copy_custom(
@@ -598,6 +607,7 @@ def copy_callable(table: Table, sync_tags: bool) -> None:
598607
file_format=file_format,
599608
full_refresh=True,
600609
sync_tags=sync_tags,
610+
files=files,
601611
)
602612

603613
return self._merge(copy_callable, primary_keys, replication_keys, qualify)

tests/test_models.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,86 @@ def test_merge_custom(mock_connect, mock_merge, mock_copy, mock_drop):
566566
mock_drop.assert_called_once()
567567

568568

569+
@patch.object(Table, "_copy")
def test_copy_custom_with_files(mock_copy) -> None:
    """Check that ``copy_custom`` renders a FILES clause for the given files."""
    # Stub ``_copy`` so the test controls what ``copy_custom`` hands back.
    mock_copy.return_value = [("test_file.parquet", "LOADED")]

    requested_files = ["test_file.parquet", "another_file.parquet"]
    target = Table(name="TEST_TABLE", schema_name="TEST_SCHEMA")

    outcome = target.copy_custom(
        column_definitions={
            "id": "$1:id",
            "name": "$1:name",
            "last_name": "$1:last_name",
        },
        path="s3://test-bucket/path",
        file_format=parquet_file_format,
        files=requested_files,
    )

    # The stubbed rows come back unchanged: second field of the first row.
    assert outcome[0][1] == "LOADED"

    mock_copy.assert_called()
    # The generated SQL is the first positional argument passed to ``_copy``.
    generated_query = mock_copy.call_args.args[0]
    assert (
        "FILES = ('test_file.parquet', 'another_file.parquet')" in generated_query
    )
601+
602+
603+
@patch.object(Table, "_merge")
def test_merge_custom_with_files(mock_merge) -> None:
    """Check that ``merge_custom`` forwards ``files`` to its copy callable."""
    Table(name="TEST_TABLE", schema_name="TEST_SCHEMA").merge_custom(
        column_definitions={
            "id": "$1:id",
            "name": "$1:name",
            "last_name": "$1:last_name",
        },
        path="s3://test-bucket/path",
        file_format=parquet_file_format,
        primary_keys=["id"],
        files=["test_file.parquet"],
    )

    mock_merge.assert_called_once()

    # ``_merge`` is mocked, so the copy callable it received never ran.
    # Take it (first positional argument) and invoke it by hand below.
    copy_step = mock_merge.call_args.args[0]

    with patch.object(Table, "_copy") as mock_copy:
        mock_copy.return_value = [("test_file.parquet", "LOADED")]

        # Stand-in for the table that ``_merge`` would pass to the callable.
        scratch_table = Table(name="TEMP_TABLE", schema_name="TEST_SCHEMA")
        copy_step(scratch_table, sync_tags=False)

        mock_copy.assert_called()
        # The generated SQL is the first positional argument passed to ``_copy``.
        emitted_query = mock_copy.call_args.args[0]
        assert "FILES = ('test_file.parquet')" in emitted_query
647+
648+
569649
@patch.object(Table, "_copy")
570650
def test_copy_with_files(mock_copy) -> None:
571651
"""Test copy_into method with specific files parameter."""

0 commit comments

Comments
 (0)