Skip to content

Commit 88ef41f

Browse files
committed
use Cow for FirewallCompareType::String
1 parent bcb36f1 commit 88ef41f

File tree

8 files changed

+112
-66
lines changed

8 files changed

+112
-66
lines changed

src/firewall/items/http_request.rs

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::helpers::get_header;
88
use crate::proto::appguard::{AppGuardHttpRequest, AppGuardTcpInfo};
99
use rpn_predicate_interpreter::PredicateEvaluator;
1010
use serde::{Deserialize, Serialize};
11+
use std::borrow::Cow;
1112

1213
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1314
#[allow(clippy::enum_variant_names)]
@@ -42,30 +43,30 @@ impl HttpRequestField {
4243
item: &'a AppGuardHttpRequest,
4344
) -> Option<FirewallCompareType<'a>> {
4445
match self {
45-
HttpRequestField::HttpRequestUrl(v) => {
46-
Some(FirewallCompareType::String((&item.original_url, v)))
47-
}
48-
HttpRequestField::HttpRequestMethod(v) => {
49-
Some(FirewallCompareType::String((&item.method, v)))
50-
}
51-
HttpRequestField::HttpRequestQuery(HeaderVal(k, v)) => {
52-
get_header(&item.query, k).map(|query| FirewallCompareType::String((query, v)))
53-
}
46+
HttpRequestField::HttpRequestUrl(v) => Some(FirewallCompareType::String((
47+
&item.original_url,
48+
Cow::Borrowed(v),
49+
))),
50+
HttpRequestField::HttpRequestMethod(v) => Some(FirewallCompareType::String((
51+
&item.method,
52+
Cow::Borrowed(v),
53+
))),
54+
HttpRequestField::HttpRequestQuery(HeaderVal(k, v)) => get_header(&item.query, k)
55+
.map(|query| FirewallCompareType::String((query, Cow::Borrowed(v)))),
5456
HttpRequestField::HttpRequestCookie(v) => get_header(&item.headers, "Cookie")
55-
.map(|cookie| FirewallCompareType::String((cookie, v))),
56-
HttpRequestField::HttpRequestHeader(HeaderVal(k, v)) => {
57-
get_header(&item.headers, k).map(|header| FirewallCompareType::String((header, v)))
58-
}
57+
.map(|cookie| FirewallCompareType::String((cookie, Cow::Borrowed(v)))),
58+
HttpRequestField::HttpRequestHeader(HeaderVal(k, v)) => get_header(&item.headers, k)
59+
.map(|header| FirewallCompareType::String((header, Cow::Borrowed(v)))),
5960
HttpRequestField::HttpRequestBody(v) => item
6061
.body
6162
.as_ref()
62-
.map(|body| FirewallCompareType::String((body, v))),
63+
.map(|body| FirewallCompareType::String((body, Cow::Borrowed(v)))),
6364
HttpRequestField::HttpRequestBodyLen(v) => item
6465
.body
6566
.as_ref()
6667
.map(|body| FirewallCompareType::Usize((body.len(), v))),
6768
HttpRequestField::HttpRequestUserAgent(v) => get_header(&item.headers, "User-Agent")
68-
.map(|user_agent| FirewallCompareType::String((user_agent, v))),
69+
.map(|user_agent| FirewallCompareType::String((user_agent, Cow::Borrowed(v)))),
6970
}
7071
}
7172
}
@@ -76,7 +77,11 @@ impl PredicateEvaluator for AppGuardHttpRequest {
7677
type Reason = String;
7778
type Context = AppContext;
7879

79-
async fn evaluate_predicate(&self, predicate: &Self::Predicate, context: &Self::Context) -> bool {
80+
async fn evaluate_predicate(
81+
&self,
82+
predicate: &Self::Predicate,
83+
context: &Self::Context,
84+
) -> bool {
8085
if predicate.direction == Some(FirewallRuleDirection::Out) {
8186
return false;
8287
}
@@ -93,7 +98,8 @@ impl PredicateEvaluator for AppGuardHttpRequest {
9398
direction: FirewallRuleDirection::In,
9499
},
95100
context,
96-
).await
101+
)
102+
.await
97103
}
98104
}
99105

src/firewall/items/http_response.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::helpers::get_header;
88
use crate::proto::appguard::{AppGuardHttpResponse, AppGuardTcpInfo};
99
use rpn_predicate_interpreter::PredicateEvaluator;
1010
use serde::{Deserialize, Serialize};
11+
use std::borrow::Cow;
1112

1213
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1314
#[allow(clippy::enum_variant_names)]
@@ -45,9 +46,8 @@ impl HttpResponseField {
4546
None
4647
}
4748
}
48-
HttpResponseField::HttpResponseHeader(HeaderVal(k, v)) => {
49-
get_header(&item.headers, k).map(|header| FirewallCompareType::String((header, v)))
50-
}
49+
HttpResponseField::HttpResponseHeader(HeaderVal(k, v)) => get_header(&item.headers, k)
50+
.map(|header| FirewallCompareType::String((header, Cow::Borrowed(v)))),
5151
}
5252
}
5353
}
@@ -58,7 +58,11 @@ impl PredicateEvaluator for AppGuardHttpResponse {
5858
type Reason = String;
5959
type Context = AppContext;
6060

61-
async fn evaluate_predicate(&self, predicate: &Self::Predicate, context: &Self::Context) -> bool {
61+
async fn evaluate_predicate(
62+
&self,
63+
predicate: &Self::Predicate,
64+
context: &Self::Context,
65+
) -> bool {
6266
if predicate.direction == Some(FirewallRuleDirection::In) {
6367
return false;
6468
}
@@ -75,7 +79,8 @@ impl PredicateEvaluator for AppGuardHttpResponse {
7579
direction: FirewallRuleDirection::Out,
7680
},
7781
context,
78-
).await
82+
)
83+
.await
7984
}
8085
}
8186

src/firewall/items/ip_info.rs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::firewall::rules::{FirewallCompareType, FirewallRuleField, FirewallRul
33
use crate::proto::appguard::AppGuardIpInfo;
44
use rpn_predicate_interpreter::PredicateEvaluator;
55
use serde::{Deserialize, Serialize};
6+
use std::borrow::Cow;
67

78
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
89
#[serde(rename_all = "snake_case")]
@@ -39,35 +40,35 @@ impl IpInfoField {
3940
IpInfoField::Country(v) => item
4041
.country
4142
.as_ref()
42-
.map(|country| FirewallCompareType::String((country, v))),
43+
.map(|country| FirewallCompareType::String((country, Cow::Borrowed(v)))),
4344
IpInfoField::Asn(v) => item
4445
.asn
4546
.as_ref()
46-
.map(|asn| FirewallCompareType::String((asn, v))),
47+
.map(|asn| FirewallCompareType::String((asn, Cow::Borrowed(v)))),
4748
IpInfoField::Org(v) => item
4849
.org
4950
.as_ref()
50-
.map(|org| FirewallCompareType::String((org, v))),
51+
.map(|org| FirewallCompareType::String((org, Cow::Borrowed(v)))),
5152
IpInfoField::Continent(v) => item
5253
.continent_code
5354
.as_ref()
54-
.map(|continent| FirewallCompareType::String((continent, v))),
55+
.map(|continent| FirewallCompareType::String((continent, Cow::Borrowed(v)))),
5556
IpInfoField::City(v) => item
5657
.city
5758
.as_ref()
58-
.map(|city| FirewallCompareType::String((city, v))),
59+
.map(|city| FirewallCompareType::String((city, Cow::Borrowed(v)))),
5960
IpInfoField::Region(v) => item
6061
.region
6162
.as_ref()
62-
.map(|region| FirewallCompareType::String((region, v))),
63+
.map(|region| FirewallCompareType::String((region, Cow::Borrowed(v)))),
6364
IpInfoField::Postal(v) => item
6465
.postal
6566
.as_ref()
66-
.map(|postal| FirewallCompareType::String((postal, v))),
67+
.map(|postal| FirewallCompareType::String((postal, Cow::Borrowed(v)))),
6768
IpInfoField::Timezone(v) => item
6869
.timezone
6970
.as_ref()
70-
.map(|timezone| FirewallCompareType::String((timezone, v))),
71+
.map(|timezone| FirewallCompareType::String((timezone, Cow::Borrowed(v)))),
7172
}
7273
}
7374
}
@@ -78,7 +79,11 @@ impl<'a> PredicateEvaluator for &'a AppGuardIpInfo {
7879
type Reason = String;
7980
type Context = AppContext;
8081

81-
async fn evaluate_predicate(&self, predicate: &Self::Predicate, _context: &Self::Context) -> bool {
82+
async fn evaluate_predicate(
83+
&self,
84+
predicate: &Self::Predicate,
85+
_context: &Self::Context,
86+
) -> bool {
8287
if let FirewallRuleField::IpInfo(f) = &predicate.rule.field {
8388
return predicate.rule.condition.compare(f.get_compare_fields(self));
8489
}

src/firewall/items/smtp_request.rs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::helpers::get_header;
88
use crate::proto::appguard::{AppGuardSmtpRequest, AppGuardTcpInfo};
99
use rpn_predicate_interpreter::PredicateEvaluator;
1010
use serde::{Deserialize, Serialize};
11+
use std::borrow::Cow;
1112

1213
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1314
#[allow(clippy::enum_variant_names)]
@@ -34,15 +35,14 @@ impl SmtpRequestField {
3435
item: &'a AppGuardSmtpRequest,
3536
) -> Option<FirewallCompareType<'a>> {
3637
match self {
37-
SmtpRequestField::SmtpRequestHeader(HeaderVal(k, v)) => {
38-
get_header(&item.headers, k).map(|header| FirewallCompareType::String((header, v)))
39-
}
38+
SmtpRequestField::SmtpRequestHeader(HeaderVal(k, v)) => get_header(&item.headers, k)
39+
.map(|header| FirewallCompareType::String((header, Cow::Borrowed(v)))),
4040
SmtpRequestField::SmtpRequestUserAgent(v) => get_header(&item.headers, "User-Agent")
41-
.map(|user_agent| FirewallCompareType::String((user_agent, v))),
41+
.map(|user_agent| FirewallCompareType::String((user_agent, Cow::Borrowed(v)))),
4242
SmtpRequestField::SmtpRequestBody(v) => item
4343
.body
4444
.as_ref()
45-
.map(|body| FirewallCompareType::String((body, v))),
45+
.map(|body| FirewallCompareType::String((body, Cow::Borrowed(v)))),
4646
SmtpRequestField::SmtpRequestBodyLen(l) => item
4747
.body
4848
.as_ref()
@@ -57,7 +57,11 @@ impl PredicateEvaluator for AppGuardSmtpRequest {
5757
type Reason = String;
5858
type Context = AppContext;
5959

60-
async fn evaluate_predicate(&self, predicate: &Self::Predicate, context: &Self::Context) -> bool {
60+
async fn evaluate_predicate(
61+
&self,
62+
predicate: &Self::Predicate,
63+
context: &Self::Context,
64+
) -> bool {
6165
if predicate.direction == Some(FirewallRuleDirection::Out) {
6266
return false;
6367
}
@@ -74,7 +78,8 @@ impl PredicateEvaluator for AppGuardSmtpRequest {
7478
direction: FirewallRuleDirection::In,
7579
},
7680
context,
77-
).await
81+
)
82+
.await
7883
}
7984
}
8085

src/firewall/items/smtp_response.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@ impl PredicateEvaluator for AppGuardSmtpResponse {
3939
type Reason = String;
4040
type Context = AppContext;
4141

42-
async fn evaluate_predicate(&self, predicate: &Self::Predicate, context: &Self::Context) -> bool {
42+
async fn evaluate_predicate(
43+
&self,
44+
predicate: &Self::Predicate,
45+
context: &Self::Context,
46+
) -> bool {
4347
if predicate.direction == Some(FirewallRuleDirection::In) {
4448
return false;
4549
}
@@ -56,7 +60,8 @@ impl PredicateEvaluator for AppGuardSmtpResponse {
5660
direction: FirewallRuleDirection::Out,
5761
},
5862
context,
59-
).await
63+
)
64+
.await
6065
}
6166
}
6267

src/firewall/items/tcp_connection.rs

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::firewall::rules::{
55
use crate::proto::appguard::AppGuardTcpConnection;
66
use rpn_predicate_interpreter::PredicateEvaluator;
77
use serde::{Deserialize, Serialize};
8+
use std::borrow::Cow;
89

910
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1011
#[serde(rename_all = "snake_case")]
@@ -66,9 +67,10 @@ impl TcpConnectionField {
6667
FirewallRuleDirection::Out => item.source_port,
6768
}
6869
.map(|p| FirewallCompareType::U32((p, v))),
69-
TcpConnectionField::Protocol(v) => {
70-
Some(FirewallCompareType::String((&item.protocol, v)))
71-
}
70+
TcpConnectionField::Protocol(v) => Some(FirewallCompareType::String((
71+
&item.protocol,
72+
Cow::Borrowed(v),
73+
))),
7274
}
7375
}
7476
}
@@ -79,13 +81,16 @@ impl<'a> PredicateEvaluator for &'a AppGuardTcpConnection {
7981
type Reason = String;
8082
type Context = AppContext;
8183

82-
async fn evaluate_predicate(&self, predicate: &Self::Predicate, context: &Self::Context) -> bool {
84+
async fn evaluate_predicate(
85+
&self,
86+
predicate: &Self::Predicate,
87+
context: &Self::Context,
88+
) -> bool {
8389
if let FirewallRuleField::TcpConnection(f) = &predicate.rule.field {
84-
return predicate.rule.condition.compare(f.get_compare_fields(
85-
self,
86-
&predicate.direction,
87-
context,
88-
).await);
90+
return predicate.rule.condition.compare(
91+
f.get_compare_fields(self, &predicate.direction, context)
92+
.await,
93+
);
8994
}
9095
false
9196
}
@@ -101,19 +106,25 @@ impl<'a> PredicateEvaluator for &'a AppGuardTcpConnection {
101106

102107
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
103108
#[serde(untagged)]
104-
enum IpAlias {
109+
pub enum IpAlias {
105110
Name(String),
106111
Addresses(Vec<String>),
107112
}
108113

109114
impl IpAlias {
110-
async fn to_ips(&self, context: &AppContext) -> Option<&Vec<String>> {
115+
async fn to_ips(&self, context: &AppContext) -> Option<Cow<'_, Vec<String>>> {
111116
match self {
112117
IpAlias::Name(name) => {
113118
let token = context.root_token_provider.get().await.ok()?.jwt.clone();
114-
context.datastore.clone().get_ip_alias(token, name).await.ok()
119+
context
120+
.datastore
121+
.clone()
122+
.get_ip_alias(token, name)
123+
.await
124+
.ok()
125+
.map(|a| Cow::Owned(a))
115126
}
116-
IpAlias::Addresses(addresses) => Some(addresses),
127+
IpAlias::Addresses(addresses) => Some(Cow::Borrowed(addresses)),
117128
}
118129
}
119130
}

src/firewall/items/tcp_info.rs

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,26 @@ impl<'a> PredicateEvaluator for &'a AppGuardTcpInfo {
99
type Reason = String;
1010
type Context = AppContext;
1111

12-
async fn evaluate_predicate(&self, predicate: &Self::Predicate, context: &Self::Context) -> bool {
12+
async fn evaluate_predicate(
13+
&self,
14+
predicate: &Self::Predicate,
15+
context: &Self::Context,
16+
) -> bool {
1317
match &predicate.rule.field {
14-
FirewallRuleField::TcpConnection(_) => self
15-
.connection
16-
.as_ref()
17-
.unwrap_or(&AppGuardTcpConnection::default())
18-
.evaluate_predicate(predicate, context).await,
19-
FirewallRuleField::IpInfo(_) => self
20-
.ip_info
21-
.as_ref()
22-
.unwrap_or(&AppGuardIpInfo::default())
23-
.evaluate_predicate(predicate, context).await,
18+
FirewallRuleField::TcpConnection(_) => {
19+
self.connection
20+
.as_ref()
21+
.unwrap_or(&AppGuardTcpConnection::default())
22+
.evaluate_predicate(predicate, context)
23+
.await
24+
}
25+
FirewallRuleField::IpInfo(_) => {
26+
self.ip_info
27+
.as_ref()
28+
.unwrap_or(&AppGuardIpInfo::default())
29+
.evaluate_predicate(predicate, context)
30+
.await
31+
}
2432
_ => false,
2533
}
2634
}

src/firewall/rules.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::firewall::items::smtp_request::SmtpRequestField;
88
use crate::firewall::items::smtp_response::SmtpResponseField;
99
use crate::firewall::items::tcp_connection::TcpConnectionField;
1010
use crate::proto::appguard_commands::FirewallPolicy;
11+
use std::borrow::Cow;
1112

1213
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1314
#[serde(rename_all = "snake_case")]
@@ -85,7 +86,7 @@ impl FirewallRuleCondition {
8586
pub fn compare(&self, firewall_compare_type: Option<FirewallCompareType>) -> bool {
8687
if let Some(fields) = firewall_compare_type {
8788
return match fields {
88-
FirewallCompareType::String((l, r)) => self.compare_vec(l, r),
89+
FirewallCompareType::String((l, r)) => self.compare_vec(l, &r),
8990
FirewallCompareType::Usize((l, r)) => self.compare_vec(&l, r),
9091
FirewallCompareType::U32((l, r)) => self.compare_vec(&l, r),
9192
FirewallCompareType::U64((l, r)) => self.compare_vec(&l, r),
@@ -130,7 +131,7 @@ impl FirewallRuleCondition {
130131

131132
#[derive(Debug, PartialEq)]
132133
pub enum FirewallCompareType<'a> {
133-
String((&'a String, &'a Vec<String>)),
134+
String((&'a String, Cow<'a, Vec<String>>)),
134135
Usize((usize, &'a Vec<usize>)),
135136
U32((u32, &'a Vec<u32>)),
136137
U64((u64, &'a Vec<u64>)),

0 commit comments

Comments
 (0)