Skip to content

Commit 19eb3da

Browse files
committed
feat: add ipcalc function
1 parent 3bce2b1 commit 19eb3da

File tree

3 files changed

+408
-0
lines changed

3 files changed

+408
-0
lines changed

src/functions/ipcalc.cpp

Lines changed: 341 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,341 @@
1+
#include "ipcalc.hpp"
2+
3+
#include <cmath>
4+
#include <iostream>
5+
#include <regex>
6+
#include <sstream>
7+
#include <stdexcept>
8+
#include <vector>
9+
10+
#include "../utils/utils.hpp"
11+
12+
namespace duckdb
13+
{
14+
namespace netquack
15+
{
16+
IPInfo IPCalculator::calculate (const std::string &ipWithMask)
17+
{
18+
// Validate input format
19+
if (!isValidInput (ipWithMask))
20+
{
21+
throw std::invalid_argument ("Invalid input format. Expected format: x.x.x.x[/x]");
22+
}
23+
24+
// Parse IP and subnet mask
25+
size_t slashPos = ipWithMask.find ('/');
26+
std::string ip = ipWithMask.substr (0, slashPos);
27+
28+
// Default to /32 if no subnet mask is provided
29+
int maskBits = 32;
30+
if (slashPos != std::string::npos)
31+
{
32+
std::string maskStr = ipWithMask.substr (slashPos + 1);
33+
try
34+
{
35+
maskBits = std::stoi (maskStr);
36+
}
37+
catch (const std::exception &)
38+
{
39+
throw std::invalid_argument ("Invalid subnet mask. Must be a number between 0 and 32.");
40+
}
41+
}
42+
43+
// Validate subnet mask
44+
if (maskBits < 0 || maskBits > 32)
45+
{
46+
throw std::invalid_argument ("Subnet mask must be between 0 and 32");
47+
}
48+
49+
// Validate IP address
50+
if (!isValidIP (ip))
51+
{
52+
throw std::invalid_argument ("Invalid IP address.");
53+
}
54+
55+
// Calculate network properties
56+
std::string subnetMask = getSubnetMask (maskBits);
57+
std::string wildcardMask = getWildcardMask (subnetMask);
58+
59+
IPInfo info;
60+
info.address = ip;
61+
info.netmask = subnetMask;
62+
info.wildcard = wildcardMask;
63+
info.hostsPerNet = getHostsPerNet (maskBits);
64+
info.ipClass = getIPClass (ip);
65+
66+
if (maskBits != 32)
67+
{
68+
info.network = getNetworkAddress (ip, subnetMask);
69+
info.broadcast = getBroadcastAddress (info.network, wildcardMask);
70+
info.hostMin = getHostMin (info.network);
71+
info.hostMax = getHostMax (info.broadcast);
72+
}
73+
else
74+
{
75+
// For /32, the IP itself is the only host
76+
info.network = ip;
77+
info.broadcast = ip;
78+
info.hostMin = ip;
79+
info.hostMax = ip;
80+
}
81+
82+
return info;
83+
}
84+
85+
bool IPCalculator::isValidInput (const std::string &input)
86+
{
87+
std::regex pattern (R"((\d{1,3}\.){3}\d{1,3}(/\d{1,2})?)");
88+
return std::regex_match (input, pattern);
89+
}
90+
91+
bool IPCalculator::isValidIP (const std::string &ip)
92+
{
93+
auto octets = parseIP (ip);
94+
for (int octet : octets)
95+
{
96+
if (octet < 0 || octet > 255)
97+
{
98+
return false;
99+
}
100+
}
101+
return true;
102+
}
103+
104+
std::vector<int> IPCalculator::parseIP (const std::string &ip)
105+
{
106+
std::vector<int> octets;
107+
std::stringstream ss (ip);
108+
std::string octet;
109+
while (std::getline (ss, octet, '.'))
110+
{
111+
octets.push_back (std::stoi (octet));
112+
}
113+
return octets;
114+
}
115+
116+
std::string IPCalculator::getSubnetMask (int maskBits)
117+
{
118+
uint32_t mask = 0xFFFFFFFF << (32 - maskBits);
119+
return intToIP (mask);
120+
}
121+
122+
std::string IPCalculator::getWildcardMask (const std::string &subnetMask)
123+
{
124+
auto mask = parseIP (subnetMask);
125+
std::string wildcard;
126+
for (int i = 0; i < 4; ++i)
127+
{
128+
wildcard += std::to_string (255 - mask[i]) + (i < 3 ? "." : "");
129+
}
130+
return wildcard;
131+
}
132+
133+
std::string IPCalculator::getNetworkAddress (const std::string &ip, const std::string &subnetMask)
134+
{
135+
auto ipOctets = parseIP (ip);
136+
auto maskOctets = parseIP (subnetMask);
137+
std::string network;
138+
for (int i = 0; i < 4; ++i)
139+
{
140+
network += std::to_string (ipOctets[i] & maskOctets[i]) + (i < 3 ? "." : "");
141+
}
142+
return network;
143+
}
144+
145+
std::string IPCalculator::getBroadcastAddress (const std::string &networkAddress, const std::string &wildcardMask)
146+
{
147+
auto networkOctets = parseIP (networkAddress);
148+
auto wildcardOctets = parseIP (wildcardMask);
149+
std::string broadcast;
150+
for (int i = 0; i < 4; ++i)
151+
{
152+
broadcast += std::to_string (networkOctets[i] | wildcardOctets[i]) + (i < 3 ? "." : "");
153+
}
154+
return broadcast;
155+
}
156+
157+
std::string IPCalculator::getHostMin (const std::string &networkAddress)
158+
{
159+
auto octets = parseIP (networkAddress);
160+
octets[3] += 1;
161+
return intToIP (octets);
162+
}
163+
164+
std::string IPCalculator::getHostMax (const std::string &broadcastAddress)
165+
{
166+
auto octets = parseIP (broadcastAddress);
167+
octets[3] -= 1;
168+
return intToIP (octets);
169+
}
170+
171+
int IPCalculator::getHostsPerNet (int maskBits)
172+
{
173+
if (maskBits == 32)
174+
return 1; // Special case for /32
175+
return (1 << (32 - maskBits)) - 2;
176+
}
177+
178+
std::string IPCalculator::getIPClass (const std::string &ip)
179+
{
180+
auto octets = parseIP (ip);
181+
int firstOctet = octets[0];
182+
183+
if (firstOctet >= 1 && firstOctet <= 126)
184+
return "A";
185+
if (firstOctet == 127)
186+
return "A, Loopback";
187+
if (firstOctet >= 128 && firstOctet <= 191)
188+
return "B";
189+
if (firstOctet >= 192 && firstOctet <= 223)
190+
return "C";
191+
if (firstOctet >= 224 && firstOctet <= 239)
192+
return "D";
193+
return "E";
194+
}
195+
196+
std::string IPCalculator::intToIP (uint32_t ip)
197+
{
198+
std::stringstream ss;
199+
ss << ((ip >> 24) & 0xFF) << "." << ((ip >> 16) & 0xFF) << "." << ((ip >> 8) & 0xFF) << "." << (ip & 0xFF);
200+
return ss.str ();
201+
}
202+
203+
std::string IPCalculator::intToIP (const std::vector<int> &octets)
204+
{
205+
std::stringstream ss;
206+
for (int i = 0; i < 4; ++i)
207+
{
208+
ss << octets[i] << (i < 3 ? "." : "");
209+
}
210+
return ss.str ();
211+
}
212+
213+
struct IPCalcData : public TableFunctionData
214+
{
215+
string ip;
216+
};
217+
218+
struct IPCalcLocalState : public LocalTableFunctionState
219+
{
220+
std::atomic_bool done{ false };
221+
};
222+
223+
unique_ptr<FunctionData> IPCalcFunc::Bind (ClientContext &context, TableFunctionBindInput &input, vector<LogicalType> &return_types, vector<string> &names)
224+
{
225+
auto bind_data = make_uniq<IPCalcData> ();
226+
bind_data->ip = StringValue::Get (input.inputs[0]);
227+
228+
// 0. address
229+
return_types.emplace_back (LogicalType::VARCHAR);
230+
names.emplace_back ("address");
231+
232+
// 1. netmask
233+
return_types.emplace_back (LogicalType::VARCHAR);
234+
names.emplace_back ("netmask");
235+
236+
// 2. wildcard
237+
return_types.emplace_back (LogicalType::VARCHAR);
238+
names.emplace_back ("wildcard");
239+
240+
// 3. network
241+
return_types.emplace_back (LogicalType::VARCHAR);
242+
names.emplace_back ("network");
243+
244+
// 4. hostMin
245+
return_types.emplace_back (LogicalType::VARCHAR);
246+
names.emplace_back ("hostMin");
247+
248+
// 5. hostMax
249+
return_types.emplace_back (LogicalType::VARCHAR);
250+
names.emplace_back ("hostMax");
251+
252+
// 6. broadcast
253+
return_types.emplace_back (LogicalType::VARCHAR);
254+
names.emplace_back ("broadcast");
255+
256+
// 7. hostsPerNet
257+
return_types.emplace_back (LogicalType::VARCHAR);
258+
names.emplace_back ("hostsPerNet");
259+
260+
// 8. ipClass
261+
return_types.emplace_back (LogicalType::VARCHAR);
262+
names.emplace_back ("ipClass");
263+
264+
return std::move (bind_data);
265+
}
266+
267+
unique_ptr<LocalTableFunctionState> IPCalcFunc::InitLocal (ExecutionContext &context, TableFunctionInitInput &input, GlobalTableFunctionState *global_state_p)
268+
{
269+
return make_uniq<IPCalcLocalState> ();
270+
}
271+
272+
unique_ptr<GlobalTableFunctionState> IPCalcFunc::InitGlobal (ClientContext &context, TableFunctionInitInput &input)
273+
{
274+
return nullptr;
275+
}
276+
277+
void IPCalcFunc::Scan (ClientContext &context, TableFunctionInput &data_p, DataChunk &output)
278+
{
279+
// Check done
280+
if (((IPCalcLocalState &)*data_p.local_state).done)
281+
{
282+
return;
283+
}
284+
285+
auto &data = data_p.bind_data->Cast<IPCalcData> ();
286+
287+
IPInfo info = IPCalculator::calculate (data.ip);
288+
289+
output.SetCardinality (1);
290+
output.data[0].SetValue (0, info.address);
291+
output.data[1].SetValue (0, info.netmask);
292+
output.data[2].SetValue (0, info.wildcard);
293+
output.data[3].SetValue (0, info.network);
294+
output.data[4].SetValue (0, info.hostMin);
295+
output.data[5].SetValue (0, info.hostMax);
296+
output.data[6].SetValue (0, info.broadcast);
297+
output.data[7].SetValue (0, info.hostsPerNet);
298+
output.data[8].SetValue (0, info.ipClass);
299+
// Set done
300+
auto &local_state = (IPCalcLocalState &)*data_p.local_state;
301+
local_state.done = true;
302+
}
303+
304+
// Function to extract the information from an IP address
305+
// void IPCalcFunction (DataChunk &args, ExpressionState &state, Vector &result)
306+
// {
307+
// // Extract the input from the arguments
308+
// auto &input_vector = args.data[0];
309+
// auto result_data = FlatVector::GetData<string_t> (result);
310+
311+
// for (idx_t i = 0; i < args.size (); i++)
312+
// {
313+
// auto input = input_vector.GetValue (i).ToString ();
314+
315+
// try
316+
// {
317+
// // Get IP information
318+
// IPInfo info = IPCalculator::calculate (input);
319+
320+
// // Format the result as a string
321+
// std::stringstream ss;
322+
// ss << "Address: " << info.address << ", "
323+
// << "Netmask: " << info.netmask << ", "
324+
// << "Wildcard: " << info.wildcard << ", "
325+
// << "Network: " << info.network << ", "
326+
// << "HostMin: " << info.hostMin << ", "
327+
// << "HostMax: " << info.hostMax << ", "
328+
// << "Broadcast: " << info.broadcast << ", "
329+
// << "Hosts/Net: " << info.hostsPerNet << ", "
330+
// << "Class: " << info.ipClass;
331+
332+
// result_data[i] = ss.str ();
333+
// }
334+
// catch (const std::exception &e)
335+
// {
336+
// result_data[i] = "Error ipcalc: " + std::string (e.what ());
337+
// }
338+
// }
339+
// }
340+
} // namespace netquack
341+
} // namespace duckdb

src/functions/ipcalc.hpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#pragma once
2+
3+
#include <regex>
4+
#include <stdexcept>
5+
#include <string>
6+
#include <vector>
7+
8+
#include "duckdb.hpp"
9+
10+
namespace duckdb
11+
{
12+
namespace netquack
13+
{
14+
struct IPInfo
15+
{
16+
std::string address;
17+
std::string netmask;
18+
std::string wildcard;
19+
std::string network;
20+
std::string hostMin;
21+
std::string hostMax;
22+
std::string broadcast;
23+
int hostsPerNet;
24+
std::string ipClass;
25+
};
26+
27+
class IPCalculator
28+
{
29+
public:
30+
static IPInfo calculate (const std::string &ipWithMask);
31+
32+
private:
33+
static bool isValidInput (const std::string &input);
34+
static bool isValidIP (const std::string &ip);
35+
static std::vector<int> parseIP (const std::string &ip);
36+
static std::string getSubnetMask (int maskBits);
37+
static std::string getWildcardMask (const std::string &subnetMask);
38+
static std::string getNetworkAddress (const std::string &ip, const std::string &subnetMask);
39+
static std::string getBroadcastAddress (const std::string &networkAddress, const std::string &wildcardMask);
40+
static std::string getHostMin (const std::string &networkAddress);
41+
static std::string getHostMax (const std::string &broadcastAddress);
42+
static int getHostsPerNet (int maskBits);
43+
static std::string getIPClass (const std::string &ip);
44+
static std::string intToIP (uint32_t ip);
45+
static std::string intToIP (const std::vector<int> &octets);
46+
};
47+
48+
struct IPCalcFunc
49+
{
50+
static unique_ptr<FunctionData> Bind (ClientContext &context, TableFunctionBindInput &input, vector<LogicalType> &return_types, vector<string> &names);
51+
static void Scan (ClientContext &context, TableFunctionInput &data_p, DataChunk &output);
52+
static unique_ptr<LocalTableFunctionState> InitLocal (ExecutionContext &context, TableFunctionInitInput &input, GlobalTableFunctionState *global_state_p);
53+
static unique_ptr<GlobalTableFunctionState> InitGlobal (ClientContext &context, TableFunctionInitInput &input);
54+
};
55+
} // namespace netquack
56+
} // namespace duckdb

0 commit comments

Comments
 (0)