Go Home Page
The Go Programming Language

Source file src/pkg/netchan/common.go

// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package netchan

import (
    "gob"
    "net"
    "os"
    "reflect"
    "sync"
    "time"
)

// Dir is the direction of a connection from the client's perspective.
type Dir int

// Values for Dir.  Recv is the zero value (iota starts at 0), Send is 1.
const (
    Recv Dir = iota
    Send
)

// Payload types.  Each value says what, if anything, follows a header in a
// transmission.  The values are assigned by iota starting at 0, so the order
// of the entries below determines the numeric values.
const (
    payRequest = iota // request structure follows
    payError          // error structure follows
    payData           // user payload follows
    payAck            // acknowledgement; no payload
)

// A header is sent as a prefix to every transmission.  It will be followed by
// a request structure, an error structure, or an arbitrary user payload structure
// (or by nothing at all when the payload type is payAck).
type header struct {
    // NOTE(review): presumably the name of the channel the transmission
    // refers to; confirm against the importer/exporter code.
    name        string
    // Says what follows the header (see the pay* constants).  encDec.encode
    // overwrites this field on every call.
    payloadType int
    // Sequence number carried with the transmission.  How it is assigned is
    // not visible in this part of the file.
    seqNum      int64
}

// A request is sent with a header once per channel from importer to exporter
// to report that it wants to bind to a channel with the specified direction
// for count messages.  If count is zero, it means unlimited.
type request struct {
    // Number of messages to bind for; zero means unlimited.
    count int64
    // Direction of the binding (Recv or Send).
    dir   Dir
}

// An error is sent with a header to report an error.
// NOTE(review): in Go 1 and later, a package-level type named error shadows
// the predeclared error interface throughout the package; this file predates
// that and uses os.Error instead.
type error struct {
    // NOTE(review): presumably the text describing the error; confirm
    // against the code that fills it in.
    error string
}

// unackedCounter is used to unify management of acknowledgements for import
// and export.  Values of this type are the keys of clientSet.clients.
type unackedCounter interface {
    // unackedCount is read by clientSet.drain, which treats any value
    // greater than zero as "messages still unacknowledged".
    unackedCount() int64
    // ack is compared by clientSet.sync against the value seq returned
    // when sync started; sync waits while ack() is still below it.
    ack() int64
    // seq is sampled once per client by clientSet.sync at entry.
    seq() int64
}

// A chanDir is a channel and its direction.
type chanDir struct {
    // Reflection handle for the channel.  clientSet.drain calls ch.Len()
    // to see whether items are still queued in it.
    ch  *reflect.ChanValue
    // Direction associated with ch.
    dir Dir
}

// clientSet contains the objects and methods needed for tracking
// clients of an exporter and draining outstanding messages.
type clientSet struct {
    mu      sync.Mutex // protects access to channel and client maps
    // Channels keyed by a string.
    // NOTE(review): presumably the channel name; confirm against the code
    // that populates the map.
    chans   map[string]*chanDir
    // Set of current clients.  drain and sync only range over the keys or
    // test for key presence; the bool value is not inspected by them.
    clients map[unackedCounter]bool
}

// An encDec is a mutex-protected encoder and decoder pair.  The two
// directions have separate locks, so a decode and an encode may proceed at
// the same time while calls in the same direction are serialized.
type encDec struct {
    decLock sync.Mutex   // held by decode for the duration of each dec.DecodeValue call
    dec     *gob.Decoder
    encLock sync.Mutex   // held by encode while it writes a header and its payload
    enc     *gob.Encoder
}

// newEncDec returns an encDec whose gob decoder and gob encoder are both
// layered on conn.  The two mutexes are left at their zero (unlocked) value.
func newEncDec(conn net.Conn) *encDec {
    pair := new(encDec)
    pair.dec = gob.NewDecoder(conn)
    pair.enc = gob.NewEncoder(conn)
    return pair
}

// decode reads the next item from the connection into value.  Calls are
// serialized by decLock.  It returns whatever error the gob decoder reports,
// or nil on success.
func (ed *encDec) decode(value reflect.Value) os.Error {
    ed.decLock.Lock()
    // Deferred so the lock is released on every exit path, including a
    // panic inside the decoder.
    defer ed.decLock.Unlock()
    // TODO: tear down connection on error?
    return ed.dec.DecodeValue(value)
}

// encode writes a header and, if payload is non-nil, the payload onto the
// connection.
//
// hdr.payloadType is overwritten with payloadType before the header is sent,
// so the caller's header is modified.  encLock is held across both writes so
// that a header and its payload are not interleaved with another encode call.
// It returns the first error reported by the gob encoder; if encoding the
// header fails the payload is not sent.
func (ed *encDec) encode(hdr *header, payloadType int, payload interface{}) os.Error {
    ed.encLock.Lock()
    // Deferred so the lock is released on every exit path, including a
    // panic inside the encoder.
    defer ed.encLock.Unlock()
    hdr.payloadType = payloadType
    // TODO: tear down connection if there is an error?
    if err := ed.enc.Encode(hdr); err != nil {
        return err
    }
    if payload == nil {
        return nil
    }
    return ed.enc.Encode(payload)
}

// drain implements Exporter.Drain; see the comment there.
//
// It polls, sleeping 100 milliseconds between passes, until no channel in the
// set has queued items and no client reports unacknowledged messages, and
// then returns nil.  timeout is in nanoseconds; if it is positive and that
// much time has elapsed while work is still pending, drain returns the error
// "timeout".  A timeout of zero or less means wait indefinitely.
func (cs *clientSet) drain(timeout int64) os.Error {
    const pollInterval = 100 * 1e6 // 100 milliseconds
    began := time.Nanoseconds()
    for {
        busy := false
        cs.mu.Lock()
        // Is any channel still holding messages for a client?
        for _, cd := range cs.chans {
            if queued := cd.ch.Len(); queued > 0 {
                busy = true
            }
        }
        // Is any client still owed an acknowledgement?
        for c := range cs.clients {
            // Compare with > rather than != to tolerate a negative count.
            if c.unackedCount() > 0 {
                busy = true
                break
            }
        }
        cs.mu.Unlock()
        if !busy {
            break
        }
        // The deadline is only consulted when a pass found pending work.
        if timeout > 0 && time.Nanoseconds()-began >= timeout {
            return os.ErrorString("timeout")
        }
        time.Sleep(pollInterval)
    }
    return nil
}

// sync implements Exporter.Sync; see the comment there.
//
// At entry it records seq() for every client currently in the set.  It then
// polls, sleeping 100 milliseconds between passes, until every one of those
// clients that is still in the set reports ack() at or beyond its recorded
// value, and returns nil.  Clients added after entry are ignored, as are
// clients that have since been removed.  timeout is in nanoseconds; if it is
// positive and that much time has elapsed while a client is still behind,
// sync returns the error "timeout".  A timeout of zero or less means wait
// indefinitely.
func (cs *clientSet) sync(timeout int64) os.Error {
    startTime := time.Nanoseconds()
    // seq remembers the clients and their seqNum at point of entry.
    // cs.clients is protected by cs.mu, so the snapshot must be taken with
    // the lock held; ranging over the map unlocked races with any
    // concurrent change to the client set.
    seq := make(map[unackedCounter]int64)
    cs.mu.Lock()
    for client := range cs.clients {
        seq[client] = client.seq()
    }
    cs.mu.Unlock()
    for {
        pending := false
        cs.mu.Lock()
        // Any unacknowledged messages?  Look only at clients that existed
        // when we started and are still in this client set.
        for client, want := range seq {
            if _, ok := cs.clients[client]; ok {
                if client.ack() < want {
                    pending = true
                    break
                }
            }
        }
        cs.mu.Unlock()
        if !pending {
            break
        }
        // The deadline is only consulted when a pass found a client behind.
        if timeout > 0 && time.Nanoseconds()-startTime >= timeout {
            return os.ErrorString("timeout")
        }
        time.Sleep(100 * 1e6) // 100 milliseconds
    }
    return nil
}