Source file src/pkg/net/rpc/server.go
1
2
3
4
5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
122 package rpc
123
124 import (
125 "bufio"
126 "encoding/gob"
127 "errors"
128 "io"
129 "log"
130 "net"
131 "net/http"
132 "reflect"
133 "strings"
134 "sync"
135 "unicode"
136 "unicode/utf8"
137 )
138
// Default HTTP paths passed to Server.HandleHTTP by the package-level
// HandleHTTP function.
const (
	// DefaultRPCPath is the path on which the RPC handler is registered.
	DefaultRPCPath = "/_goRPC_"
	// DefaultDebugPath is the path on which the debugging handler is registered.
	DefaultDebugPath = "/debug/rpc"
)
144
145
146
// typeOfError is the reflect.Type of the error interface. It is obtained by
// taking the element type of a *error, because reflect.TypeOf on a nil
// interface value yields no type information.
var typeOfError = reflect.TypeOf(new(error)).Elem()
148
// methodType describes one callable method of a registered service.
type methodType struct {
	sync.Mutex // guards numCalls

	method    reflect.Method // the method itself; its Func is invoked for each call
	ArgType   reflect.Type   // type of the argument parameter (method input 1)
	ReplyType reflect.Type   // type of the reply parameter (method input 2); always a pointer
	numCalls  uint           // number of times the method has been invoked
}
156
157 type service struct {
158 name string
159 rcvr reflect.Value
160 typ reflect.Type
161 method map[string]*methodType
162 }
163
164
165
166
// Request is the header of an RPC call. A ServerCodec fills it in via
// ReadRequestHeader; the server then uses ServiceMethod to locate the method
// to run and echoes Seq back in the matching Response.
type Request struct {
	ServiceMethod string   // of the form "Service.Method"
	Seq           uint64   // sequence number, copied into the response
	next          *Request // link in the server's free list of Request structs
}
172
173
174
175
// Response is the header of an RPC reply. The server fills it in and hands it
// to ServerCodec.WriteResponse together with the reply body.
type Response struct {
	ServiceMethod string    // copied from the request
	Seq           uint64    // copied from the request
	Error         string    // error message, if any; empty on success
	next          *Response // link in the server's free list of Response structs
}
182
183
184 type Server struct {
185 mu sync.RWMutex
186 serviceMap map[string]*service
187 reqLock sync.Mutex
188 freeReq *Request
189 respLock sync.Mutex
190 freeResp *Response
191 }
192
193
194 func NewServer() *Server {
195 return &Server{serviceMap: make(map[string]*service)}
196 }
197
198
199 var DefaultServer = NewServer()
200
201
// isExported reports whether name starts with an upper-case letter.
// An empty or invalidly encoded name is reported as not exported.
func isExported(name string) bool {
	first, _ := utf8.DecodeRuneInString(name)
	return unicode.IsUpper(first)
}
206
207
208 func isExportedOrBuiltinType(t reflect.Type) bool {
209 for t.Kind() == reflect.Ptr {
210 t = t.Elem()
211 }
212
213
214 return isExported(t.Name()) || t.PkgPath() == ""
215 }
216
217
218
219
220
221
222
223
224
225
226 func (server *Server) Register(rcvr interface{}) error {
227 return server.register(rcvr, "", false)
228 }
229
230
231
232 func (server *Server) RegisterName(name string, rcvr interface{}) error {
233 return server.register(rcvr, name, true)
234 }
235
236 func (server *Server) register(rcvr interface{}, name string, useName bool) error {
237 server.mu.Lock()
238 defer server.mu.Unlock()
239 if server.serviceMap == nil {
240 server.serviceMap = make(map[string]*service)
241 }
242 s := new(service)
243 s.typ = reflect.TypeOf(rcvr)
244 s.rcvr = reflect.ValueOf(rcvr)
245 sname := reflect.Indirect(s.rcvr).Type().Name()
246 if useName {
247 sname = name
248 }
249 if sname == "" {
250 log.Fatal("rpc: no service name for type", s.typ.String())
251 }
252 if !isExported(sname) && !useName {
253 s := "rpc Register: type " + sname + " is not exported"
254 log.Print(s)
255 return errors.New(s)
256 }
257 if _, present := server.serviceMap[sname]; present {
258 return errors.New("rpc: service already defined: " + sname)
259 }
260 s.name = sname
261 s.method = make(map[string]*methodType)
262
263
264 s.method = suitableMethods(s.typ, true)
265
266 if len(s.method) == 0 {
267 str := ""
268
269 method := suitableMethods(reflect.PtrTo(s.typ), false)
270 if len(method) != 0 {
271 str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
272 } else {
273 str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
274 }
275 log.Print(str)
276 return errors.New(str)
277 }
278 server.serviceMap[s.name] = s
279 return nil
280 }
281
282
283
284 func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
285 methods := make(map[string]*methodType)
286 for m := 0; m < typ.NumMethod(); m++ {
287 method := typ.Method(m)
288 mtype := method.Type
289 mname := method.Name
290
291 if method.PkgPath != "" {
292 continue
293 }
294
295 if mtype.NumIn() != 3 {
296 if reportErr {
297 log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
298 }
299 continue
300 }
301
302 argType := mtype.In(1)
303 if !isExportedOrBuiltinType(argType) {
304 if reportErr {
305 log.Println(mname, "argument type not exported:", argType)
306 }
307 continue
308 }
309
310 replyType := mtype.In(2)
311 if replyType.Kind() != reflect.Ptr {
312 if reportErr {
313 log.Println("method", mname, "reply type not a pointer:", replyType)
314 }
315 continue
316 }
317
318 if !isExportedOrBuiltinType(replyType) {
319 if reportErr {
320 log.Println("method", mname, "reply type not exported:", replyType)
321 }
322 continue
323 }
324
325 if mtype.NumOut() != 1 {
326 if reportErr {
327 log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
328 }
329 continue
330 }
331
332 if returnType := mtype.Out(0); returnType != typeOfError {
333 if reportErr {
334 log.Println("method", mname, "returns", returnType.String(), "not error")
335 }
336 continue
337 }
338 methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
339 }
340 return methods
341 }
342
343
344
345
// invalidRequest is a placeholder value sent as the response body when a
// request fails; the response's Error field carries the actual message.
var invalidRequest = struct{}{}
347
348 func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
349 resp := server.getResponse()
350
351 resp.ServiceMethod = req.ServiceMethod
352 if errmsg != "" {
353 resp.Error = errmsg
354 reply = invalidRequest
355 }
356 resp.Seq = req.Seq
357 sending.Lock()
358 err := codec.WriteResponse(resp, reply)
359 if err != nil {
360 log.Println("rpc: writing response:", err)
361 }
362 sending.Unlock()
363 server.freeResponse(resp)
364 }
365
366 func (m *methodType) NumCalls() (n uint) {
367 m.Lock()
368 n = m.numCalls
369 m.Unlock()
370 return n
371 }
372
373 func (s *service) call(server *Server, sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
374 mtype.Lock()
375 mtype.numCalls++
376 mtype.Unlock()
377 function := mtype.method.Func
378
379 returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
380
381 errInter := returnValues[0].Interface()
382 errmsg := ""
383 if errInter != nil {
384 errmsg = errInter.(error).Error()
385 }
386 server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
387 server.freeRequest(req)
388 }
389
// gobServerCodec is the ServerCodec used by ServeConn: it reads and writes
// gob-encoded values on a single connection.
//
// The field order is significant: ServeConn builds this struct with an
// unkeyed composite literal.
type gobServerCodec struct {
	rwc    io.ReadWriteCloser // underlying connection, closed by Close
	dec    *gob.Decoder       // decodes request headers and bodies
	enc    *gob.Encoder       // encodes response headers and bodies
	encBuf *bufio.Writer      // flushed after each response is encoded
}
396
397 func (c *gobServerCodec) ReadRequestHeader(r *Request) error {
398 return c.dec.Decode(r)
399 }
400
401 func (c *gobServerCodec) ReadRequestBody(body interface{}) error {
402 return c.dec.Decode(body)
403 }
404
405 func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) (err error) {
406 if err = c.enc.Encode(r); err != nil {
407 return
408 }
409 if err = c.enc.Encode(body); err != nil {
410 return
411 }
412 return c.encBuf.Flush()
413 }
414
415 func (c *gobServerCodec) Close() error {
416 return c.rwc.Close()
417 }
418
419
420
421
422
423
424 func (server *Server) ServeConn(conn io.ReadWriteCloser) {
425 buf := bufio.NewWriter(conn)
426 srv := &gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(buf), buf}
427 server.ServeCodec(srv)
428 }
429
430
431
432 func (server *Server) ServeCodec(codec ServerCodec) {
433 sending := new(sync.Mutex)
434 for {
435 service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
436 if err != nil {
437 if err != io.EOF {
438 log.Println("rpc:", err)
439 }
440 if !keepReading {
441 break
442 }
443
444 if req != nil {
445 server.sendResponse(sending, req, invalidRequest, codec, err.Error())
446 server.freeRequest(req)
447 }
448 continue
449 }
450 go service.call(server, sending, mtype, req, argv, replyv, codec)
451 }
452 codec.Close()
453 }
454
455
456
457 func (server *Server) ServeRequest(codec ServerCodec) error {
458 sending := new(sync.Mutex)
459 service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
460 if err != nil {
461 if !keepReading {
462 return err
463 }
464
465 if req != nil {
466 server.sendResponse(sending, req, invalidRequest, codec, err.Error())
467 server.freeRequest(req)
468 }
469 return err
470 }
471 service.call(server, sending, mtype, req, argv, replyv, codec)
472 return nil
473 }
474
475 func (server *Server) getRequest() *Request {
476 server.reqLock.Lock()
477 req := server.freeReq
478 if req == nil {
479 req = new(Request)
480 } else {
481 server.freeReq = req.next
482 *req = Request{}
483 }
484 server.reqLock.Unlock()
485 return req
486 }
487
488 func (server *Server) freeRequest(req *Request) {
489 server.reqLock.Lock()
490 req.next = server.freeReq
491 server.freeReq = req
492 server.reqLock.Unlock()
493 }
494
495 func (server *Server) getResponse() *Response {
496 server.respLock.Lock()
497 resp := server.freeResp
498 if resp == nil {
499 resp = new(Response)
500 } else {
501 server.freeResp = resp.next
502 *resp = Response{}
503 }
504 server.respLock.Unlock()
505 return resp
506 }
507
508 func (server *Server) freeResponse(resp *Response) {
509 server.respLock.Lock()
510 resp.next = server.freeResp
511 server.freeResp = resp
512 server.respLock.Unlock()
513 }
514
515 func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
516 service, mtype, req, keepReading, err = server.readRequestHeader(codec)
517 if err != nil {
518 if !keepReading {
519 return
520 }
521
522 codec.ReadRequestBody(nil)
523 return
524 }
525
526
527 argIsValue := false
528 if mtype.ArgType.Kind() == reflect.Ptr {
529 argv = reflect.New(mtype.ArgType.Elem())
530 } else {
531 argv = reflect.New(mtype.ArgType)
532 argIsValue = true
533 }
534
535 if err = codec.ReadRequestBody(argv.Interface()); err != nil {
536 return
537 }
538 if argIsValue {
539 argv = argv.Elem()
540 }
541
542 replyv = reflect.New(mtype.ReplyType.Elem())
543 return
544 }
545
546 func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mtype *methodType, req *Request, keepReading bool, err error) {
547
548 req = server.getRequest()
549 err = codec.ReadRequestHeader(req)
550 if err != nil {
551 req = nil
552 if err == io.EOF || err == io.ErrUnexpectedEOF {
553 return
554 }
555 err = errors.New("rpc: server cannot decode request: " + err.Error())
556 return
557 }
558
559
560
561 keepReading = true
562
563 serviceMethod := strings.Split(req.ServiceMethod, ".")
564 if len(serviceMethod) != 2 {
565 err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
566 return
567 }
568
569 server.mu.RLock()
570 service = server.serviceMap[serviceMethod[0]]
571 server.mu.RUnlock()
572 if service == nil {
573 err = errors.New("rpc: can't find service " + req.ServiceMethod)
574 return
575 }
576 mtype = service.method[serviceMethod[1]]
577 if mtype == nil {
578 err = errors.New("rpc: can't find method " + req.ServiceMethod)
579 }
580 return
581 }
582
583
584
585
586 func (server *Server) Accept(lis net.Listener) {
587 for {
588 conn, err := lis.Accept()
589 if err != nil {
590 log.Fatal("rpc.Serve: accept:", err.Error())
591 }
592 go server.ServeConn(conn)
593 }
594 }
595
596
597 func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) }
598
599
600
601 func RegisterName(name string, rcvr interface{}) error {
602 return DefaultServer.RegisterName(name, rcvr)
603 }
604
605
606
607
608
609
610
611
612 type ServerCodec interface {
613 ReadRequestHeader(*Request) error
614 ReadRequestBody(interface{}) error
615 WriteResponse(*Response, interface{}) error
616
617 Close() error
618 }
619
620
621
622
623
624
625 func ServeConn(conn io.ReadWriteCloser) {
626 DefaultServer.ServeConn(conn)
627 }
628
629
630
631 func ServeCodec(codec ServerCodec) {
632 DefaultServer.ServeCodec(codec)
633 }
634
635
636
637 func ServeRequest(codec ServerCodec) error {
638 return DefaultServer.ServeRequest(codec)
639 }
640
641
642
643
644 func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
645
646
// connected is the status text written after "HTTP/1.0 " on a hijacked
// connection once a CONNECT request has been accepted by ServeHTTP.
var connected = "200 Connected to Go RPC"
648
649
650 func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
651 if req.Method != "CONNECT" {
652 w.Header().Set("Content-Type", "text/plain; charset=utf-8")
653 w.WriteHeader(http.StatusMethodNotAllowed)
654 io.WriteString(w, "405 must CONNECT\n")
655 return
656 }
657 conn, _, err := w.(http.Hijacker).Hijack()
658 if err != nil {
659 log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
660 return
661 }
662 io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
663 server.ServeConn(conn)
664 }
665
666
667
668
669 func (server *Server) HandleHTTP(rpcPath, debugPath string) {
670 http.Handle(rpcPath, server)
671 http.Handle(debugPath, debugHTTP{server})
672 }
673
674
675
676
677 func HandleHTTP() {
678 DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
679 }
View as plain text