Skip to content

Commit 7ee2d20

Browse files
committed
add correct HTTP query handler
1 parent 181b454 commit 7ee2d20

File tree

3 files changed

+99
-11
lines changed

3 files changed

+99
-11
lines changed

main_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,44 @@ func TestServe(t *testing.T) {
102102
},
103103
startTLS,
104104
},
105+
{
106+
"https cache with mix query source",
107+
"testdata/https.cache.yml",
108+
func(t *testing.T) {
109+
// do request which response must be cached
110+
queryURLParam := "SELECT * FROM system.numbers"
111+
queryBody := "LIMIT 10"
112+
expectedQuery := queryURLParam + "\n" + queryBody
113+
buf := bytes.NewBufferString(queryBody)
114+
req, err := http.NewRequest("GET", "https://127.0.0.1:8443?query="+url.QueryEscape(queryURLParam), buf)
115+
checkErr(t, err)
116+
req.SetBasicAuth("default", "qwerty")
117+
resp, err := tlsClient.Do(req)
118+
checkErr(t, err)
119+
if resp.StatusCode != http.StatusOK {
120+
t.Fatalf("unexpected status code: %d; expected: %d", resp.StatusCode, http.StatusOK)
121+
}
122+
resp.Body.Close()
123+
124+
// check cached response
125+
key := &cache.Key{
126+
Query: []byte(expectedQuery),
127+
AcceptEncoding: "gzip",
128+
}
129+
path := fmt.Sprintf("%s/cache/%s", testDir, key.String())
130+
if _, err := os.Stat(path); err != nil {
131+
t.Fatalf("err while getting file %q info: %s", path, err)
132+
}
133+
rw := httptest.NewRecorder()
134+
cc := proxy.caches["https_cache"]
135+
if err := cc.WriteTo(rw, key); err != nil {
136+
t.Fatalf("unexpected error while writing reposnse from cache: %s", err)
137+
}
138+
expected := "Ok.\n"
139+
checkResponse(t, rw.Body, expected)
140+
},
141+
startTLS,
142+
},
105143
{
106144
"bad https cache",
107145
"testdata/https.cache.yml",

utils.go

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@ func getAuth(req *http.Request) (string, string) {
4747
//
4848
// getQuerySnippet must be called only for error reporting.
4949
func getQuerySnippet(req *http.Request) string {
50-
if req.Method == http.MethodGet && req.URL.Query().Get("query") != "" {
51-
return req.URL.Query().Get("query")
50+
var query string
51+
52+
if req.URL.Query().Get("query") != "" {
53+
query = req.URL.Query().Get("query")
5254
}
5355

5456
crc, ok := req.Body.(*cachedReadCloser)
@@ -66,26 +68,44 @@ func getQuerySnippet(req *http.Request) string {
6668

6769
u := getDecompressor(req)
6870
if u == nil {
69-
return data
71+
if len(query) != 0 && len(data) != 0 {
72+
query += "\n"
73+
}
74+
75+
return query + data
7076
}
7177
bs := bytes.NewBufferString(data)
7278
b, err := u.decompress(bs)
7379
if err == nil {
74-
return string(b)
80+
if len(query) != 0 && len(b) != 0 {
81+
query += "\n"
82+
}
83+
84+
return query + string(b)
7585
}
7686
// It is better to return partially decompressed data instead of an empty string.
7787
if len(b) > 0 {
78-
return string(b)
88+
if len(query) != 0 && len(b) != 0 {
89+
query += "\n"
90+
}
91+
92+
return query + string(b)
7993
}
94+
8095
// The data failed to be decompressed. Return compressed data
8196
// instead of an empty string.
82-
return data
97+
if len(query) != 0 && len(data) != 0 {
98+
query += "\n"
99+
}
100+
101+
return query + data
83102
}
84103

85104
// getFullQuery returns full query from req.
86105
func getFullQuery(req *http.Request) ([]byte, error) {
87-
if req.Method == http.MethodGet && req.URL.Query().Get("query") != "" {
88-
return []byte(req.URL.Query().Get("query")), nil
106+
var result bytes.Buffer
107+
if req.URL.Query().Get("query") != "" {
108+
result.WriteString(req.URL.Query().Get("query"))
89109
}
90110
data, err := ioutil.ReadAll(req.Body)
91111
if err != nil {
@@ -95,14 +115,25 @@ func getFullQuery(req *http.Request) ([]byte, error) {
95115
req.Body = ioutil.NopCloser(bytes.NewBuffer(data))
96116
u := getDecompressor(req)
97117
if u == nil {
98-
return data, nil
118+
if result.Len() != 0 && len(data) != 0 {
119+
result.WriteByte('\n')
120+
}
121+
result.Write(data)
122+
123+
return result.Bytes(), nil
99124
}
100125
br := bytes.NewReader(data)
101126
b, err := u.decompress(br)
102127
if err != nil {
103128
return nil, fmt.Errorf("cannot uncompress query: %s", err)
104129
}
105-
return b, nil
130+
131+
if result.Len() != 0 && len(b) != 0 {
132+
result.WriteByte('\n')
133+
}
134+
result.Write(b)
135+
136+
return result.Bytes(), nil
106137
}
107138

108139
// canCacheQuery returns true if q can be cached.

utils_test.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func testCanCacheQuery(t *testing.T, q string, expected bool) {
6868
}
6969

7070
func TestGetQuerySnippetGET(t *testing.T) {
71-
req, err := http.NewRequest("GET", "", nil)
71+
req, err := http.NewRequest("GET", "", bytes.NewBuffer(nil))
7272
checkErr(t, err)
7373
params := make(url.Values)
7474
q := "SELECT column FROM table"
@@ -91,6 +91,25 @@ func TestGetQuerySnippetGETBody(t *testing.T) {
9191
}
9292
}
9393

94+
func TestGetQuerySnippetGETBothQueryAndBody(t *testing.T) {
95+
queryPart := "SELECT column"
96+
bodyPart := "FROM table"
97+
expectedQuery := "SELECT column\nFROM table"
98+
99+
body := bytes.NewBufferString(bodyPart)
100+
req, err := http.NewRequest("GET", "", body)
101+
checkErr(t, err)
102+
103+
params := make(url.Values)
104+
params.Set("query", queryPart)
105+
req.URL.RawQuery = params.Encode()
106+
107+
query := getQuerySnippet(req)
108+
if query != expectedQuery {
109+
t.Fatalf("got: %q; expected: %q", query, expectedQuery)
110+
}
111+
}
112+
94113
func TestGetQuerySnippetPOST(t *testing.T) {
95114
q := "SELECT column FROM table"
96115
body := bytes.NewBufferString(q)

0 commit comments

Comments
 (0)