package cookiejar
import (
"errors"
"fmt"
"net"
"net/http"
"net/url"
"sort"
"strings"
"sync"
"time"
)
// PublicSuffixList reports the public suffix of a domain name. A Jar
// consults it when choosing the storage key for a host (jarKey) and when
// deciding whether a Domain attribute names a public suffix
// (domainAndType).
type PublicSuffixList interface {
// PublicSuffix returns the public suffix of domain. jarKey accepts a
// result equal to its input and falls back to the full host when the
// result is not preceded by a '.' inside the host.
PublicSuffix(domain string) string
// String is not called anywhere in this file; presumably it describes
// the origin of the list -- confirm against the implementations.
String() string
}
// Options carries the settings New applies to a freshly created Jar.
type Options struct {
// PublicSuffixList is copied into the Jar by New. It may be nil, in
// which case jarKey keys hosts by their last two labels and
// domainAndType skips its public suffix check.
PublicSuffixList PublicSuffixList
}
// Jar is an in-memory cookie store. Cookies are added through SetCookies
// and retrieved through Cookies.
type Jar struct {
// psList is the optional public suffix list installed by New; nil means
// no list is in use.
psList PublicSuffixList
// mu is held by cookies and setCookies while they read or modify
// entries and nextSeqNum.
mu sync.Mutex
// entries maps a jar key (see jarKey) to the cookies stored under it,
// which are themselves keyed by entry.id().
entries map[string]map[string]entry
// nextSeqNum is the sequence number handed to the next newly created
// entry; setCookies increments it after each use.
nextSeqNum uint64
}
// New returns an empty Jar. When o is non-nil its PublicSuffixList is
// installed on the jar; a nil o leaves the jar without one. The error
// result is always nil here.
func New(o *Options) (*Jar, error) {
	var psl PublicSuffixList
	if o != nil {
		psl = o.PublicSuffixList
	}
	return &Jar{
		psList:  psl,
		entries: make(map[string]map[string]entry),
	}, nil
}
// entry is the jar's internal record for one stored cookie.
type entry struct {
// Name and Value are the pair handed back to callers of Cookies.
Name string
Value string
// Domain is the host or domain the cookie is bound to; see domainMatch.
Domain string
// Path is the path the cookie is bound to; see pathMatch.
Path string
// Secure restricts the cookie to https requests (see shouldSend).
Secure bool
// HttpOnly is copied from the incoming cookie; nothing else in this
// file reads it.
HttpOnly bool
// Persistent is true when the cookie carried a positive MaxAge or a
// non-zero Expires. Only persistent entries are purged on expiry.
Persistent bool
// HostOnly requires the request host to equal Domain exactly.
HostOnly bool
// Expires is the expiry instant; non-persistent entries get endOfTime.
Expires time.Time
// Creation is set when the entry is first stored and kept on updates.
Creation time.Time
// LastAccess is refreshed whenever the entry is stored or returned.
LastAccess time.Time
// seqNum breaks ties between entries with equal path length and equal
// Creation when cookies are sorted for a request.
seqNum uint64
}
// id returns the key under which e is stored inside its jar-key bucket:
// the Domain, Path and Name joined by semicolons.
func (e *entry) id() string {
	const layout = "%s;%s;%s"
	domain, path, name := e.Domain, e.Path, e.Name
	return fmt.Sprintf(layout, domain, path, name)
}
// shouldSend reports whether e belongs on a request to host and path.
// The entry must domain-match host, path-match path and, if it is marked
// Secure, the request must be https.
func (e *entry) shouldSend(https bool, host, path string) bool {
	if !e.domainMatch(host) {
		return false
	}
	if !e.pathMatch(path) {
		return false
	}
	return !e.Secure || https
}
// domainMatch reports whether host is covered by e.Domain. An exact match
// always qualifies; otherwise the entry must not be host-only and host
// must end in "." followed by e.Domain.
func (e *entry) domainMatch(host string) bool {
	switch {
	case host == e.Domain:
		return true
	case e.HostOnly:
		return false
	default:
		return hasDotSuffix(host, e.Domain)
	}
}
// pathMatch reports whether requestPath falls under e.Path. It matches
// when the two are equal, or when e.Path is a prefix of requestPath and
// the boundary between them is a '/' (either the last byte of e.Path or
// the next byte of requestPath).
func (e *entry) pathMatch(requestPath string) bool {
	cookiePath := e.Path
	switch {
	case requestPath == cookiePath:
		return true
	case !strings.HasPrefix(requestPath, cookiePath):
		return false
	}
	// requestPath is strictly longer than cookiePath from here on, so
	// indexing it at len(cookiePath) is in range.
	if cookiePath[len(cookiePath)-1] == '/' {
		return true // "/any/" covers "/any/path".
	}
	return requestPath[len(cookiePath)] == '/' // "/any" covers "/any/path".
}
// hasDotSuffix reports whether s ends in "." followed by suffix.
func hasDotSuffix(s, suffix string) bool {
	cut := len(s) - len(suffix)
	if cut <= 0 {
		// s is too short to hold a dot in front of suffix.
		return false
	}
	return s[cut-1] == '.' && s[cut:] == suffix
}
// Cookies returns the cookies to send in a request for u, evaluated
// against the current time.
func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) {
	now := time.Now()
	cookies = j.cookies(u, now)
	return cookies
}
// cookies is Cookies with the current time supplied by the caller.
//
// It returns nil for schemes other than http and https, for hosts that
// canonicalHost rejects, and when nothing is stored under the host's jar
// key. While scanning the bucket it deletes persistent entries whose
// expiry is not after now and refreshes LastAccess on every entry it
// selects. The result is ordered by longest path first, then earliest
// creation, then lowest sequence number, and carries only Name and Value.
func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
	var https bool
	switch u.Scheme {
	case "http":
	case "https":
		https = true
	default:
		return nil
	}

	host, err := canonicalHost(u.Host)
	if err != nil {
		return nil
	}
	key := jarKey(host, j.psList)

	reqPath := u.Path
	if reqPath == "" {
		reqPath = "/"
	}

	j.mu.Lock()
	defer j.mu.Unlock()

	bucket := j.entries[key]
	if bucket == nil {
		return nil
	}

	var (
		matched []entry
		dirty   bool
	)
	for id, e := range bucket {
		switch {
		case e.Persistent && !e.Expires.After(now):
			// Expired: drop it instead of sending it.
			delete(bucket, id)
		case e.shouldSend(https, host, reqPath):
			e.LastAccess = now
			bucket[id] = e
			matched = append(matched, e)
		default:
			continue
		}
		dirty = true
	}
	if dirty {
		if len(bucket) > 0 {
			j.entries[key] = bucket
		} else {
			delete(j.entries, key)
		}
	}

	sort.Slice(matched, func(a, b int) bool {
		x, y := &matched[a], &matched[b]
		if lx, ly := len(x.Path), len(y.Path); lx != ly {
			return lx > ly
		}
		if x.Creation.Equal(y.Creation) {
			return x.seqNum < y.seqNum
		}
		return x.Creation.Before(y.Creation)
	})

	// Append without pre-sizing so that "nothing matched" stays a nil slice.
	for i := range matched {
		cookies = append(cookies, &http.Cookie{Name: matched[i].Name, Value: matched[i].Value})
	}
	return cookies
}
// SetCookies stores the given cookies as received in a response from u,
// evaluated against the current time.
func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) {
	now := time.Now()
	j.setCookies(u, cookies, now)
}
// setCookies is SetCookies with the current time supplied by the caller.
//
// It does nothing for an empty cookie list, for schemes other than http
// and https, or for hosts that canonicalHost rejects. Cookies that
// newEntry rejects are skipped. A cookie flagged for removal deletes the
// stored entry with the same id, if any. Any other cookie is stored,
// keeping the Creation time and sequence number of an entry it replaces.
func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) {
	if len(cookies) == 0 {
		return
	}
	switch u.Scheme {
	case "http", "https":
	default:
		return
	}
	host, err := canonicalHost(u.Host)
	if err != nil {
		return
	}
	key, defPath := jarKey(host, j.psList), defaultPath(u.Path)

	j.mu.Lock()
	defer j.mu.Unlock()

	bucket := j.entries[key]
	dirty := false
	for _, c := range cookies {
		e, remove, err := j.newEntry(c, now, defPath, host)
		if err != nil {
			continue
		}
		id := e.id()

		if remove {
			// Looking up a nil bucket is safe and simply finds nothing.
			if _, found := bucket[id]; found {
				delete(bucket, id)
				dirty = true
			}
			continue
		}

		if bucket == nil {
			bucket = make(map[string]entry)
		}
		prev, found := bucket[id]
		if found {
			e.Creation, e.seqNum = prev.Creation, prev.seqNum
		} else {
			e.Creation, e.seqNum = now, j.nextSeqNum
			j.nextSeqNum++
		}
		e.LastAccess = now
		bucket[id] = e
		dirty = true
	}

	if !dirty {
		return
	}
	if len(bucket) == 0 {
		delete(j.entries, key)
		return
	}
	j.entries[key] = bucket
}
// canonicalHost lower-cases host, removes a port if hasPort detects one,
// strips a single trailing dot and passes the result through toASCII.
// It returns an error if the port cannot be split off or if toASCII fails.
func canonicalHost(host string) (string, error) {
	host = strings.ToLower(host)
	if hasPort(host) {
		bare, _, err := net.SplitHostPort(host)
		if err != nil {
			return "", err
		}
		host = bare
	}
	// A fully qualified name may end in a dot; store it without one.
	return toASCII(strings.TrimSuffix(host, "."))
}
// hasPort reports whether host carries a ":port" part. A single colon is
// taken as a port separator; with several colons host must be a
// bracketed literal followed by a colon, as in "[::1]:80".
func hasPort(host string) bool {
	switch strings.Count(host, ":") {
	case 0:
		return false
	case 1:
		return true
	default:
		return strings.HasPrefix(host, "[") && strings.Contains(host, "]:")
	}
}
// jarKey returns the key under which cookies for host are stored.
//
// IP addresses are their own key. Without a public suffix list the key is
// the last two labels of host (or host itself if it has fewer). With a
// list the key is the public suffix plus the one label in front of it;
// host itself is used when it equals the suffix or when the reported
// suffix is not preceded by a dot inside host.
func jarKey(host string, psl PublicSuffixList) string {
	if isIP(host) {
		return host
	}

	// cut is where the trailing part of host starts: the index of the last
	// dot when there is no list, or the first byte of the suffix when
	// there is one. Both cases are consumed below through host[:cut-1].
	var cut int
	if psl != nil {
		suffix := psl.PublicSuffix(host)
		if suffix == host {
			return host
		}
		cut = len(host) - len(suffix)
		if cut <= 0 || host[cut-1] != '.' {
			// The list returned something that is not a proper suffix
			// of host; fall back to the full host.
			return host
		}
	} else {
		cut = strings.LastIndex(host, ".")
		if cut <= 0 {
			return host
		}
	}

	labelStart := strings.LastIndex(host[:cut-1], ".") + 1
	return host[labelStart:]
}
// isIP reports whether host is an IP address.
//
// Anything containing ':' or '%' is treated as an IP address. Host names
// cannot contain those bytes, while IPv6 literals with a zone
// ("fe80::1%eth0") or still wrapped in brackets ("[::1]") do and are not
// accepted by net.ParseIP. Classifying them as IPs is the conservative
// choice: it prevents a malformed or zoned IPv6 literal from being handled
// as a domain name when cookies are keyed or Domain attributes validated.
func isIP(host string) bool {
	if strings.ContainsAny(host, ":%") {
		return true
	}
	return net.ParseIP(host) != nil
}
// defaultPath returns the directory part of a request path, used as the
// cookie path when a cookie does not supply a usable one. Paths that are
// empty, do not start with '/', or contain no further '/' yield "/";
// otherwise everything before the last '/' is returned.
func defaultPath(path string) string {
	if !strings.HasPrefix(path, "/") {
		return "/" // Empty or malformed path.
	}
	if slash := strings.LastIndex(path, "/"); slash > 0 {
		return path[:slash] // "/abc/xyz" or "/abc/xyz/".
	}
	return "/" // "/abc": the only slash is the leading one.
}
// newEntry builds the jar entry for cookie c received from host at time
// now. defPath is used when c.Path is empty or does not start with '/'.
//
// remove is true when the cookie is already expired (negative MaxAge, or
// zero MaxAge with an Expires that is not after now); in that case only
// the identifying fields of e are filled in. err is non-nil when
// domainAndType rejects the cookie's Domain attribute. MaxAge takes
// precedence over Expires, and a cookie with neither becomes a
// non-persistent entry expiring at endOfTime.
func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) {
	e.Name = c.Name

	e.Path = defPath
	if c.Path != "" && c.Path[0] == '/' {
		e.Path = c.Path
	}

	e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain)
	if err != nil {
		return e, false, err
	}

	switch {
	case c.MaxAge < 0:
		return e, true, nil
	case c.MaxAge > 0:
		e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
		e.Persistent = true
	case c.Expires.IsZero():
		e.Expires = endOfTime
		e.Persistent = false
	case !c.Expires.After(now):
		return e, true, nil
	default:
		e.Expires = c.Expires
		e.Persistent = true
	}

	e.Value, e.Secure, e.HttpOnly = c.Value, c.Secure, c.HttpOnly
	return e, false, nil
}
// Errors returned by domainAndType when a cookie's Domain attribute is
// rejected.
var (
// errIllegalDomain: the domain is a public suffix other than the host
// itself, or the host does not domain-match it.
errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute")
// errMalformedDomain: the domain is empty after dropping one leading
// dot, still starts with a dot, or ends with a dot.
errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute")
// errNoHostname: a Domain attribute was given but the request host is
// an IP address.
errNoHostname = errors.New("cookiejar: no host name available (IP only)")
)
// endOfTime is the Expires value newEntry gives to non-persistent entries,
// i.e. cookies that carry neither a MaxAge nor an Expires.
var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC)
// domainAndType decides which domain a cookie set by host with the given
// Domain attribute is stored under, and whether it is a host-only cookie.
//
// An empty attribute yields a host-only cookie for host. Otherwise host
// must not be an IP address (errNoHostname), and the attribute, after
// dropping one leading dot and lower-casing, must not be empty, start
// with a dot or end with one (errMalformedDomain). If a public suffix
// list is installed and the domain is itself a public suffix, the cookie
// is accepted only as a host-only cookie when host equals the domain.
// Finally host must equal the domain or be a subdomain of it; any other
// combination is rejected with errIllegalDomain.
func (j *Jar) domainAndType(host, domain string) (string, bool, error) {
	if domain == "" {
		// No Domain attribute: plain host cookie.
		return host, true, nil
	}
	if isIP(host) {
		return "", false, errNoHostname
	}

	domain = strings.TrimPrefix(domain, ".")
	if domain == "" || strings.HasPrefix(domain, ".") {
		// "Domain=." and "Domain=..some.thing" are both malformed.
		return "", false, errMalformedDomain
	}
	domain = strings.ToLower(domain)
	if strings.HasSuffix(domain, ".") {
		return "", false, errMalformedDomain
	}

	if psl := j.psList; psl != nil {
		ps := psl.PublicSuffix(domain)
		if ps != "" && !hasDotSuffix(domain, ps) {
			if host == domain {
				// The one case where a Domain attribute still produces
				// a host-only cookie.
				return host, true, nil
			}
			return "", false, errIllegalDomain
		}
	}

	if host == domain || hasDotSuffix(host, domain) {
		return domain, false, nil
	}
	return "", false, errIllegalDomain
}
View as plain text