diff --git a/.gitignore b/.gitignore index f4d432a..439bf5c 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ # Dependency directories (remove the comment below to include it) # vendor/ +.idea/ diff --git a/addr.go b/addr.go new file mode 100644 index 0000000..9326d39 --- /dev/null +++ b/addr.go @@ -0,0 +1,20 @@ +package zocket + +// Addr represents a (the) WebSocket end point address. +// Since WebSocket connections are upgraded HTTP Connections, +//the Listener always has a dummy address +type Addr struct { + network string + address string +} + +// Implement the net.Addr interface + +func (a Addr) Network() string { return a.network } +func (a Addr) String() string { return a.address } + +// Some default Addr for internal use +var addr = Addr{ + network: "ws", + address: "[WebSocket]", +} diff --git a/client_conn.go b/client_conn.go new file mode 100644 index 0000000..fe63db8 --- /dev/null +++ b/client_conn.go @@ -0,0 +1,167 @@ +// +build js, wasm + +package zocket // import udico.de/dreem/zocket + +import ( + "context" + "errors" + "fmt" + "net" + "runtime/debug" + "sync" + "syscall/js" + "time" +) + +type handler func(this js.Value, args []js.Value) interface{} + +// ClientConnection is a TCP like Connection between the wasm and a server. +type ClientConnection struct { + M sync.Mutex + ws js.Value + conn chan struct{} + in chan []byte + onMessage handler + buffered []byte +} + +func makeHandler(c chan []byte) handler { + return func(this js.Value, args []js.Value) interface{} { + e := args[0] + array := e.Get("data") // is an arraybuffer + u8arr := js.Global().Get("Uint8Array").New(array) + //fmt.Printf("arr: %v, %+v\n", u8arr.Type(), u8arr) + bytes := make([]byte, u8arr.Get("byteLength").Int()) + /* ln := */ js.CopyBytesToGo(bytes, u8arr) + //fmt.Printf("got %v bytes [ID: %v]\n", ln, e.Get("lastEventId").String()) + //fmt.Printf("message: %+v\n", bytes) + c <- bytes + return nil + } +} + +func Dial(ctx context.Context, target string) (net.Conn, error) { + fmt.Printf("dialing %v\n", target) + ret := ClientConnection{ + ws: js.Global().Get("WebSocket").New(target), + in: make(chan []byte), + conn: make(chan struct{}), + } + ret.ws.Set("binaryType", "arraybuffer") + ret.onMessage = makeHandler(ret.in) + + ret.ws.Call("addEventListener", "open", js.FuncOf( + func(this js.Value, args []js.Value) interface{} { + fmt.Println("Opened") + close(ret.conn) + return nil + })) + + ret.ws.Call("addEventListener", "error", js.FuncOf( + func(this js.Value, args []js.Value) interface{} { + fmt.Println("error") + return nil + })) + + ret.ws.Call("addEventListener", "close", js.FuncOf( + func(this js.Value, args []js.Value) interface{} { + fmt.Println("close") + return nil + })) + + // MessageEvent: [all ro] + // .data - the data sent + // .origin + // .lastEventId + // .source - + ret.ws.Call("addEventListener", "message", js.FuncOf(ret.onMessage)) + + <-ret.conn // block until the connection is really open + return ret, nil +} + + +// Read gets some bytes from a ws frame. Blocks. +func (c ClientConnection) Read(b []byte) (int, error) { + fmt.Println("client_read") + c.M.Lock() + defer c.M.Unlock() + if len(c.buffered) > 0 { + for i, _ := range b { + if i >= len(c.buffered) { + c.buffered = nil + return i, nil + } + b[i] = c.buffered[i] + } + if len(c.buffered) > len(b) { + c.buffered = c.buffered[len(b):] + } + return len(b), nil + } + var err error = nil + bytes, ok := <-c.in + if !ok { + err = errors.New("read from closed connection") + } + for i, _ := range b { + if i >= len(bytes) { + fmt.Printf("read %v bytes\n", i) + return i, nil + } + b[i] = bytes[i] + } + if len(bytes) > len(b) { + c.buffered = bytes[len(b):] + } + fmt.Printf("read %v bytes\n", len(b)) + return len(b), err +} + +// Write puts some bytes packed into websocket frames on the wire. +func (c ClientConnection) Write(b []byte) (int, error) { + array := js.Global().Get("Uint8Array").New(len(b)) + n := js.CopyBytesToJS(array, b) + c.ws.Call("send", array) + return n, nil +} + +// Close terminates the connection nicely. +func (c ClientConnection) Close() error { + fmt.Println("client_close") + debug.PrintStack() + c.ws.Call("close") + return nil +} + +// LocalAddr returns the adress of the local endpoint. Since we operate within a sandbox we don't have any. return a +// dummy and make underlying layers think they're using a tcp conection. +func (c ClientConnection) LocalAddr() net.Addr { + fmt.Println("DBG: call LocalAddr() on ClientConnection") + return Addr { + network: "tcp", + address: "0.0.0.0", + } +} + +// RemoteAddr returns the adress of the remote endpoint. +func (c ClientConnection) RemoteAddr() net.Addr { + fmt.Println("DBG: call RemoteAddr() on ClientConnection") + return Addr { + network: "tcp", + // TODO: Use the real servers adress somehow (or none at all? :think:) + address: "127.0.0.1", + } +} + +func (c ClientConnection) SetDeadline(t time.Time) error { + return errors.New("SetDeadline not implemented") +} + +func (c ClientConnection) SetReadDeadline(t time.Time) error { + return errors.New("SetReadDeadline not implemented") +} + +func (c ClientConnection) SetWriteDeadline(t time.Time) error { + return errors.New("SetWriteDeadline not implemented") +} diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..1dfe4b6 --- /dev/null +++ b/doc.go @@ -0,0 +1,3 @@ +// Zocket transparently tunnels a TCP connection through a WebSocket layer from within WebAssembly. + +package zocket diff --git a/frame.go b/frame.go new file mode 100644 index 0000000..3909d85 --- /dev/null +++ b/frame.go @@ -0,0 +1,161 @@ +package zocket + +import ( + "encoding/binary" + "errors" + "net" + "sync" +) + +type FrameType uint8 + +const ( + FrameType_Cont FrameType = 0x00 + FrameType_Text FrameType = 0x01 + FrameType_Binary FrameType = 0x02 + FrameType_Close FrameType = 0x08 + FrameType_Ping FrameType = 0x09 + FrameType_Pong FrameType = 0x0a +) + +/* + * + */ +type Frame struct { + Fin bool + Opcode FrameType + Masked bool + Mask [4]byte + Len uint64 + Payload []byte +} + +// Read a Frame from the wire. +// Blocks until a complete frame has been retrieved. +func ReadFrame(conn net.Conn) (*Frame, error) { + tFrame := &Frame{} + head := make([]byte, 2, 2) + n, err := conn.Read(head) + if err != nil { + return nil, err + } + if n != len(head) { + return nil, errors.New("incomplete header") + } + + // header sanity checks + tFrame.Fin = (head[0] & 128) != 0 + rsv1 := (head[0] & 64) != 0 + rsv2 := (head[0] & 32) != 0 + rsv3 := (head[0] & 16) != 0 + if rsv1 || rsv2 || rsv3 { + return nil, errors.New("invalid frame header") + } + tFrame.Opcode = FrameType(head[0] & 15) + + tFrame.Masked = (head[1] & 128) != 0 + tFrame.Len = uint64(head[1] & 127) + + // read the extended lenght fields, if required + var tLenBuffer [8]byte + var tLenBufPtr []byte + switch tFrame.Len { + case 126: + tLenBufPtr = tLenBuffer[6:] + case 127: + tLenBufPtr = tLenBuffer[:] + default: + tLenBufPtr = nil + } + if tLenBufPtr != nil { + n, err = conn.Read(tLenBufPtr) + if err != nil { + return nil, err + } + if n != len(tLenBufPtr) { + return nil, errors.New("incomplete length field") + } + tFrame.Len = binary.BigEndian.Uint64(tLenBuffer[:]) + } + + if tFrame.Masked { + n, err = conn.Read(tFrame.Mask[:]) + if err != nil { + return nil, err + } + if n != len(tFrame.Mask) { + return nil, errors.New("incomplete mask") + } + } + + // read the payload + tFrame.Payload = make([]byte, tFrame.Len, tFrame.Len) + len := uint64(0) + waiter := sync.WaitGroup{} + for len != tFrame.Len { + n, err := conn.Read(tFrame.Payload[len:]) + if err != nil { + return nil, err + } + if tFrame.Masked { + waiter.Add(1) + go func(m, n uint64) { + for i := m; i < m+n; i++ { + tFrame.Payload[i] = tFrame.Payload[i] ^ tFrame.Mask[i%4] + } + waiter.Done() + }(len, uint64(n)) + } + len += uint64(n) + } + waiter.Wait() + return tFrame, nil +} + +// Write this Frame to the wire. +func (f *Frame) WriteTo(conn net.Conn) error { + var head [2]byte + if f.Fin { + head[0] |= 1 << 7 + } + head[0] |= uint8(f.Opcode) & 15 + if f.Masked { + head[1] |= 1 << 7 + } + var lbuf []byte = nil + if len(f.Payload) < 126 { + head[1] |= byte(len(f.Payload) & 127) + } else if len(f.Payload) < 65536 { + head[1] |= 126 + lbuf = make([]byte, 2, 2) + binary.BigEndian.PutUint16(lbuf, uint16(len(f.Payload))) + } else { + head[1] |= 127 + lbuf = make([]byte, 8, 8) + // only 63 of 64 bits can be used - no problem for us: + // maximum slice size in go is max(int) so we never ever will reach bit 63 and we do not have to reset that bit. + binary.BigEndian.PutUint64(lbuf, uint64(len(f.Payload))) + } + n, err := conn.Write(head[:]) + if err != nil { + return err + } + if n != len(head) { + return errors.New("partial head write") + } + if lbuf != nil { + conn.Write(lbuf) + } + if f.Masked { + conn.Write(f.Mask[:]) + } + l := 0 + for l < len(f.Payload) { + n, err = conn.Write(f.Payload[l:]) + if err != nil { + return err + } + l += n + } + return nil +} \ No newline at end of file diff --git a/listener.go b/listener.go new file mode 100644 index 0000000..6bf01d0 --- /dev/null +++ b/listener.go @@ -0,0 +1,172 @@ +package zocket + +import ( + "bufio" + "bytes" + "crypto/sha1" + "encoding/base64" + "errors" + "fmt" + log "github.com/sirupsen/logrus" + "net" + "net/http" +) + +// A Listener as defined in net.Listener for accepting WebSocket connections. +type Listener struct { + cConn chan net.Conn +} + +// NewListener creates a new listener for WebSocket connections. +// You need to register the Listener as a HTTP handler in order for it to become operative. +func NewListener() Listener { + return Listener{ + cConn: make(chan net.Conn), + } +} + +// Implement the net.Listener interface + +// Accept waits for and returns the next connection to the listener. +// +// The Connection may be treated like any other streaming network connection, +// but wraps the traffic internally to WebSocket frames. +// +// This function blocks. +func (l Listener) Accept() (net.Conn, error) { + tConn, ok := <-l.cConn + if !ok { + return nil, errors.New("socket closed") + } + ret := ServerConnection{ + conn: tConn, + } + return ret, nil +} + +// Close closes the listener. +// Any blocked Accept operations will be unblocked and return errors. +func (l Listener) Close() error { + select { + case _, ok := <-l.cConn: + // either the channel closed or we just read a connection from it. + // we discard the connection, since we're about to close the channel anyways + if !ok { + return errors.New("channel already closed") + } + default: + // something blocks. the channel is alive + } + close(l.cConn) + return nil +} + +// Addr returns the listener's network address. +// By their nature, WebSocket connections are upgraded HTTP connections and don't have +// a real end point address. So a default address is always returned. +func (l Listener) Addr() net.Addr { + return addr +} + +// Implement http handler interface + +// ServeHTTP handles a WebSocket upgrade request. +// You should make sure, a call to this function is really done on an update request, +// since it would respond with a http error otherwise. +func (l Listener) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + // Sanity checks: + // - method is HTTP GET + // - has request URI + // - has Host header -> no! + // - has Connection: Upgrade + // - has Upgrade: websocket + // - has Sec-WebSocket-Version: 13 + // - has Sec-WebSocket-Key + // - Origin: is optional + // - Sec-WebSocket-Protocol is optional + // - Sec-WebSocket-Extensions is optional + + if req.Method != "GET" { + log.Error("Invalid method!") + resp.WriteHeader(http.StatusMethodNotAllowed) + return + } + + if req.RequestURI == "" { + log.Errorf("Invalid Request URI: '%v'", req.RequestURI) + resp.WriteHeader(http.StatusNotAcceptable) // ??? + return + } + + //if req.Header.Get("Host") == "" { + // log.Error("No Host") + // resp.WriteHeader(http.StatusExpectationFailed) // ??? + // return + //} + + if req.Header.Get("Connection") != "Upgrade" { + log.Error("Connection != Upgrade") + resp.WriteHeader(http.StatusUpgradeRequired) + return + } + + if req.Header.Get("Upgrade") != "websocket" { + log.Error("Upgrade != websocket") + resp.WriteHeader(http.StatusUpgradeRequired) + return + } + + if req.Header.Get("Sec-WebSocket-Version") != "13" { + log.Error("Invalid WebSocket Version") + resp.WriteHeader(http.StatusUpgradeRequired) + return + } + + // Todo: remove all that debug spam + log.Warn("In the websocket handler - yay!") + for k, v := range req.Header { + log.Infof(" %v: %v", k, v[0]) + } + + wsKey := req.Header.Get("Sec-Websocket-Key") + wsKey += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + shakey := sha1.Sum([]byte(wsKey)) + wsKey = base64.StdEncoding.EncodeToString(shakey[:]) + + log.Info("Generated WebSocket-Accept: ", wsKey) + // NOPE: we need to hijack the connection and send it manually! + //resp.Header().Set("Upgrade", "websocket") + //resp.Header().Set("Connection", "Upgrade") + //resp.Header().Set("Sec-WebSocket-Accept", wsKey) + //resp.WriteHeader(101) + + h, ok := resp.(http.Hijacker) + if !ok { + log.Error("Missing Hijacker extension on responsewriter") + return + } + var brw *bufio.ReadWriter // only used to check for data sent by the client (which is disallowed) + conn, brw, err := h.Hijack() // returns the connection and the readwriter to operate on + if err != nil { + log.WithError(err).Error("cannot hijack") + return + } + + buf := &bytes.Buffer{} + buf.WriteString(fmt.Sprintf("%v 101 Switching Protocols\r\n", req.Proto)) // per RFC: HTTP/1.1 or higher + buf.WriteString("Upgrade: websocket\r\n") + buf.WriteString("Connection: Upgrade\r\n") + buf.WriteString(fmt.Sprintf("Sec-WebSocket-Accept: %v\r\n\r\n", wsKey)) + + // the client must not've sent any data before the handshake is complete + if brw.Reader.Buffered() > 0 { + conn.Close() + log.Error("client send before handshake!") + return + } + + // fire! + conn.Write(buf.Bytes()) + + l.cConn <- conn +} diff --git a/server_conn.go b/server_conn.go new file mode 100644 index 0000000..4b95e09 --- /dev/null +++ b/server_conn.go @@ -0,0 +1,121 @@ +package zocket + +import ( + "bytes" + "errors" + "fmt" + "github.com/sirupsen/logrus" + "net" + "time" +) + +// ServerConnection is the serverside part of the ws abstraction +type ServerConnection struct { + conn net.Conn // The underlying REAL connection. + buffered []byte +} + +// Read gets some bytes from a ws frame. Blocks. +func (c ServerConnection) Read(b []byte) (int, error) { + if len(c.buffered) > 0 { + for i, _ := range b { + if i >= len(c.buffered) { + c.buffered = nil + return i, nil + } + b[i] = c.buffered[i] + } + if len(c.buffered) > len(b) { + c.buffered = c.buffered[len(b):] + } + return len(b), nil + } + buf := &bytes.Buffer{}// + datloop: + for { + tFrame, err := ReadFrame(c.conn) + if err != nil { + return 0, err + } + switch tFrame.Opcode { + case FrameType_Ping: + // send pong + tFrame.Opcode = FrameType_Pong + tFrame.WriteTo(c.conn) + case FrameType_Pong: + // decide + + case FrameType_Close: + c.Close() + return 0, errors.New(fmt.Sprintf("connection closed: %v", tFrame.Payload[:])) + case FrameType_Binary: + buf.Write(tFrame.Payload) + if tFrame.Fin { + break datloop + } + default: + logrus.Errorf("uhandled frame type: %v", tFrame.Opcode) + } + } + bytes := buf.Bytes() + for i, _ := range b { + if i >= len(bytes) { + return i, nil + } + b[i] = bytes[i] + } + if len(bytes) > len(b) { + c.buffered = bytes[len(b):] + } + return len(b), nil +} + +// Write puts some bytes packed into websocket frames on the wire. +func (c ServerConnection) Write(b []byte) (int, error) { + tFrame := &Frame{ + Fin: true, + Opcode: FrameType_Binary, + Masked: false, + Mask: [4]byte{}, + Len: uint64(len(b)), + Payload: b, + } + err := tFrame.WriteTo(c.conn) + return len(b), err +} + +// Close terminates the connection nicely. +func (c ServerConnection) Close() error { + logrus.Debugf("serverconn_close %v", c) + //debug.PrintStack() + tFrame := &Frame{ + Fin: true, + Opcode: FrameType_Close, + Masked: false, + Mask: [4]byte{}, + Len: 0, + Payload: nil, + } + _ = tFrame.WriteTo(c.conn) + return c.conn.Close() +} + +func (c ServerConnection) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c ServerConnection) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c ServerConnection) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c ServerConnection) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c ServerConnection) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +}