Skip to content

Commit 1be96c1

Browse files
committed
fix: handle input array for get_tranco_rank
1 parent 6d0c367 commit 1be96c1

File tree

6 files changed

+49
-20
lines changed

6 files changed

+49
-20
lines changed

.gitignore

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,4 @@ test/python/__pycache__/
88
.Rhistory
99
*.log
1010
*.csv
11-
!test/data/tranco.csv
12-
!test/data/examples.csv
11+
!test/data/*.csv

README.md

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -239,20 +239,20 @@ You can use this function to get the ranking of a domain:
239239

240240
```sql
241241
D SELECT get_tranco_rank('microsoft.com') as rank;
242-
┌───────┐
243-
│ rank  │
244-
│ int32 │
245-
├───────┤
246-
│     2 │
247-
└───────┘
242+
┌─────────┐
243+
│  rank   │
244+
│ varchar │
245+
├─────────┤
246+
│ 2       │
247+
└─────────┘
248248

249249
D SELECT get_tranco_rank('cloudflare.com') as rank;
250-
┌───────┐
251-
│ rank  │
252-
│ int32 │
253-
├───────┤
254-
│    13 │
255-
└───────┘
250+
┌─────────┐
251+
│  rank   │
252+
│ varchar │
253+
├─────────┤
254+
│ 13      │
255+
└─────────┘
256256
```
257257

258258
### Get Extension Version

src/functions/get_tranco.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,28 @@ namespace duckdb
148148
throw std::runtime_error ("Tranco table not found. Download it first using `SELECT update_tranco(true);`");
149149
}
150150

151-
auto &domain_vector = args.data[0];
152-
auto domain = domain_vector.GetValue (0).ToString ();
151+
// Extract the input from the arguments
152+
auto &input_vector = args.data[0];
153+
auto result_data = FlatVector::GetData<string_t> (result);
153154

154-
auto query = "SELECT rank FROM tranco_list WHERE domain = '" + domain + "'";
155-
auto query_result = con.Query (query);
155+
for (idx_t i = 0; i < args.size (); i++)
156+
{
157+
auto input = input_vector.GetValue (i).ToString ();
158+
159+
try
160+
{
161+
auto query = "SELECT rank FROM tranco_list WHERE domain = '" + input + "'";
162+
163+
auto query_result = con.Query (query);
164+
auto rank = query_result->RowCount () > 0 ? query_result->GetValue (0, 0) : Value ();
156165

157-
result.SetValue (0, query_result->RowCount () > 0 ? query_result->GetValue (0, 0) : Value ());
166+
result_data[i] = StringVector::AddString (result, rank.ToString ());
167+
}
168+
catch (const std::exception &e)
169+
{
170+
result_data[i] = "Error extracting tranco rank: " + std::string (e.what ());
171+
}
172+
}
158173
}
159174
} // namespace netquack
160175
} // namespace duckdb

src/netquack_extension.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ namespace duckdb
9595
auto get_tranco_rank_function = ScalarFunction (
9696
"get_tranco_rank",
9797
{ LogicalType::VARCHAR },
98-
LogicalType::INTEGER,
98+
LogicalType::VARCHAR,
9999
netquack::GetTrancoRankFunction);
100100
ExtensionUtil::RegisterFunction (instance, get_tranco_rank_function);
101101

test/data/examples_tranco.csv

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
microsoft.com
2+
googleapis.com
3+
gstatic.com
4+
apple.com

test/sql/get_tranco_rank.test

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ require netquack
77
statement ok
88
CREATE TABLE tranco_list AS SELECT * FROM read_csv('test/data/tranco.csv', header=false, columns={'rank': 'INTEGER', 'domain': 'VARCHAR'});
99

10+
statement ok
11+
CREATE TABLE uri_list AS SELECT * FROM read_csv('test/data/examples_tranco.csv', header=false, columns={'uri': 'VARCHAR'});
12+
1013
query I
1114
SELECT COUNT(*) FROM tranco_list;
1215
----
@@ -26,3 +29,11 @@ query I
2629
SELECT get_tranco_rank('notfound.com');
2730
----
2831
NULL
32+
33+
query I
34+
SELECT get_tranco_rank(uri) from uri_list;
35+
----
36+
2
37+
10
38+
19
39+
7

0 commit comments

Comments
 (0)