diff options
Diffstat (limited to 'xmpp')
-rw-r--r-- | xmpp/session.go | 9 | ||||
-rw-r--r-- | xmpp/xml.go | 69 |
2 files changed, 34 insertions, 44 deletions
diff --git a/xmpp/session.go b/xmpp/session.go index b4a8fab..6a5e646 100644 --- a/xmpp/session.go +++ b/xmpp/session.go @@ -1,6 +1,7 @@ package xmpp import ( + "context" "crypto/tls" "crypto/x509" "encoding/xml" @@ -18,7 +19,6 @@ type session struct { in chan any out chan<- any transport *tls.Conn - ed encoderDecoder tx *xml.Encoder rx chan xml.Token resourceReq string @@ -47,9 +47,10 @@ func (s *session) run() { } defer s.transport.Close() - s.ed = newEncoderDecoder(s) - go s.ed.run() - defer func() { s.ed.terminator <- true }() + ctx, cancel := context.WithCancel(context.Background()) + cpy := s.rx + go runRx(ctx, cpy, s.transport) + defer cancel() lw := logger{"[TX] "} w := io.MultiWriter(s.transport, lw) diff --git a/xmpp/xml.go b/xmpp/xml.go index f547210..14c6637 100644 --- a/xmpp/xml.go +++ b/xmpp/xml.go @@ -1,57 +1,26 @@ package xmpp import ( + "context" + "crypto/tls" "encoding/xml" "errors" "io" "log" ) -type encoderDecoder struct { - session *session - rx *xml.Decoder - terminator chan bool -} - -func newEncoderDecoder(s *session) encoderDecoder { - ed := encoderDecoder{} - - ed.session = s +func runRx(ctx context.Context, chn chan xml.Token, conn *tls.Conn) { - lr := logger{"[RX] "} - r := io.TeeReader(s.transport, lr) - ed.rx = xml.NewDecoder(r) + l := logger{"[RX] "} + r := io.TeeReader(conn, l) + d := xml.NewDecoder(r) - return ed -} - -func (s *session) encodeToken(t xml.Token) error { - var err error - defer func() { - if err != nil { - log.Println(err) - } - }() - - err = s.tx.EncodeToken(t) - if err != nil { - return err - } - err = s.tx.Flush() - if err != nil { - return err - } - - return nil -} - -func (ed *encoderDecoder) run() { for { select { - case <-ed.terminator: + case <-ctx.Done(): return default: - t, err := ed.rx.Token() + t, err := d.Token() if t != nil && err == nil { switch t.(type) { case xml.ProcInst: @@ -59,7 +28,7 @@ func (ed *encoderDecoder) run() { case xml.Comment: default: c := xml.CopyToken(t) - ed.session.rx <- c + chn <- c } } if err != nil { @@ -72,3 +41,23 @@ func (ed *encoderDecoder) run() { } } } + +func (s *session) encodeToken(t xml.Token) error { + var err error + defer func() { + if err != nil { + log.Println(err) + } + }() + + err = s.tx.EncodeToken(t) + if err != nil { + return err + } + err = s.tx.Flush() + if err != nil { + return err + } + + return nil +} |