Skip to content
This repository was archived by the owner on Jul 21, 2021. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions zk/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"fmt"
"io"
"net"
"reflect"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -73,6 +74,7 @@ type Conn struct {
xid uint32
sessionTimeoutMs int32 // session timeout in milliseconds
passwd []byte
chroot string

dialer Dialer
hostProvider HostProvider
Expand Down Expand Up @@ -196,6 +198,7 @@ func Connect(servers []string, sessionTimeout time.Duration, options ...connOpti
requests: make(map[int32]*request),
watchers: make(map[watchPathType][]chan Event),
passwd: emptyPassword,
chroot: "",
logger: DefaultLogger,
buf: make([]byte, bufferSize),

Expand Down Expand Up @@ -611,6 +614,19 @@ func (c *Conn) sendData(req *request) error {
return nil
}

if req != nil && req.pkt != nil && c.chroot != "" {
v := reflect.ValueOf(req.pkt)
for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
v = v.Elem()
}
if v.Kind() == reflect.Struct {
field := v.FieldByName("Path")
if field.Kind() == reflect.String {
field.SetString(c.chroot + field.String())
}
}
}

n2, err := encodePacket(c.buf[4+n:], req.pkt)
if err != nil {
req.recvChan <- response{-1, err}
Expand Down Expand Up @@ -708,6 +724,9 @@ func (c *Conn) recvLoop(conn net.Conn) error {
if err != nil {
return err
}
if c.chroot != "" {
res.Path = strings.TrimPrefix(res.Path, c.chroot)
}
ev := Event{
Type: res.Type,
State: res.State,
Expand Down Expand Up @@ -760,6 +779,19 @@ func (c *Conn) recvLoop(conn net.Conn) error {
} else {
_, err = decodePacket(buf[16:blen], req.recvStruct)
}

if req != nil && req.recvStruct != nil && c.chroot != "" {
v := reflect.ValueOf(req.recvStruct)
for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
v = v.Elem()
}
if v.Kind() == reflect.Struct {
field := v.FieldByName("Path")
if field.Kind() == reflect.String {
field.SetString(strings.TrimPrefix(field.String(), c.chroot))
}
}
}
if req.recvFunc != nil {
req.recvFunc(req, &res, err)
}
Expand Down Expand Up @@ -804,6 +836,16 @@ func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc
return r.zxid, r.err
}

func (c *Conn) Chroot(path string) error {
res := &existsResponse{}
_, err := c.request(opExists, &existsRequest{Path: path, Watch: false}, res, nil)
if err != nil {
return err
}
c.chroot = path
return nil
}

func (c *Conn) AddAuth(scheme string, auth []byte) error {
_, err := c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil)

Expand Down
43 changes: 43 additions & 0 deletions zk/zk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,49 @@ func TestExpiringWatch(t *testing.T) {
}
}

func TestChroot(t *testing.T) {
ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "})
if err != nil {
t.Fatal(err)
}
defer ts.Stop()
zk, _, err := ts.ConnectAll()
if err != nil {
t.Fatalf("Connect returned error: %+v", err)
}
defer zk.Close()

path := "/gozk-test-chroot"
err = zk.Chroot(path)
if err == nil {
t.Fatal("Chroot expect error when path is not exist")
}
_, err = zk.Create(path, []byte{1, 2, 3, 4}, 0, WorldACL(PermAll))
if err != nil {
t.Fatalf("Chroot create path error, %s", err.Error())
}
err = zk.Chroot(path)
if err != nil {
t.Fatalf("Chroot error, %s", err.Error())
}
subPath := "/abc"
actualPath, err := zk.Create(subPath, []byte{1, 2, 3, 4}, 0, WorldACL(PermAll))
if err != nil {
t.Fatalf("Chroot create path error, %s", err.Error())
}
if actualPath != subPath {
t.Fatalf("path expect %s, got %s", path+subPath, actualPath)
}
err = zk.Chroot("")
if err != nil {
t.Fatalf("Chroot error, %s", err.Error())
}
exists, _, err := zk.Exists(path + subPath)
if !exists || err != nil {
t.Fatal("Chroot error as node exists")
}
}

func TestRequestFail(t *testing.T) {
// If connecting fails to all servers in the list then pending requests
// should be errored out so they don't hang forever.
Expand Down