Source file
src/net/rpc/server.go
Documentation: net/rpc
1
2
3
4
5
127 package rpc
128
129 import (
130 "bufio"
131 "encoding/gob"
132 "errors"
133 "io"
134 "log"
135 "net"
136 "net/http"
137 "reflect"
138 "strings"
139 "sync"
140 "unicode"
141 "unicode/utf8"
142 )
143
// Paths used by the package-level HandleHTTP when it installs the
// DefaultServer's handlers.
const (
	// DefaultRPCPath is the path at which the RPC handler is registered.
	DefaultRPCPath = "/_goRPC_"
	// DefaultDebugPath is the path at which the debugging handler is registered.
	DefaultDebugPath = "/debug/rpc"
)
149
150
151
// typeOfError is the reflect.Type of the error interface. It is obtained by
// taking the element type of *error because reflect.TypeOf cannot be applied
// to an interface type directly. suitableMethods compares each candidate
// method's return type against it.
var typeOfError = reflect.TypeOf((*error)(nil)).Elem()
153
// methodType records one method of a registered service that passed the
// checks in suitableMethods.
//
// method is the reflected method and is invoked through method.Func.
// ArgType and ReplyType are the types of its second and third input
// parameters (the first input is the receiver). numCalls counts invocations
// and is only read or written while the embedded mutex is held.
type methodType struct {
	sync.Mutex // guards numCalls
	method     reflect.Method
	ArgType    reflect.Type
	ReplyType  reflect.Type
	numCalls   uint
}
161
162 type service struct {
163 name string
164 rcvr reflect.Value
165 typ reflect.Type
166 method map[string]*methodType
167 }
168
169
170
171
// Request is the header read from a ServerCodec before each request body.
// The server recycles Request values through a free list, so a value must
// not be retained after it has been handed to freeRequest.
type Request struct {
	ServiceMethod string   // "Service.Method"; split at the last '.' by readRequestHeader
	Seq           uint64   // sequence number; copied into the matching Response
	next          *Request // link in the Server's free list
}
177
178
179
180
// Response is the header written through a ServerCodec ahead of each reply
// body. Like Request, Response values are recycled through a free list kept
// by the Server.
type Response struct {
	ServiceMethod string    // copied from the Request being answered
	Seq           uint64    // copied from the Request being answered
	Error         string    // error text; empty when the call succeeded
	next          *Response // link in the Server's free list
}
187
188
189 type Server struct {
190 serviceMap sync.Map
191 reqLock sync.Mutex
192 freeReq *Request
193 respLock sync.Mutex
194 freeResp *Response
195 }
196
197
198 func NewServer() *Server {
199 return &Server{}
200 }
201
202
203 var DefaultServer = NewServer()
204
205
// isExported reports whether name begins with an upper-case letter.
// An empty name is not exported.
func isExported(name string) bool {
	first, _ := utf8.DecodeRuneInString(name)
	return unicode.IsUpper(first)
}
210
211
212 func isExportedOrBuiltinType(t reflect.Type) bool {
213 for t.Kind() == reflect.Ptr {
214 t = t.Elem()
215 }
216
217
218 return isExported(t.Name()) || t.PkgPath() == ""
219 }
220
221
222
223
224
225
226
227
228
229
230
231 func (server *Server) Register(rcvr interface{}) error {
232 return server.register(rcvr, "", false)
233 }
234
235
236
237 func (server *Server) RegisterName(name string, rcvr interface{}) error {
238 return server.register(rcvr, name, true)
239 }
240
241 func (server *Server) register(rcvr interface{}, name string, useName bool) error {
242 s := new(service)
243 s.typ = reflect.TypeOf(rcvr)
244 s.rcvr = reflect.ValueOf(rcvr)
245 sname := reflect.Indirect(s.rcvr).Type().Name()
246 if useName {
247 sname = name
248 }
249 if sname == "" {
250 s := "rpc.Register: no service name for type " + s.typ.String()
251 log.Print(s)
252 return errors.New(s)
253 }
254 if !isExported(sname) && !useName {
255 s := "rpc.Register: type " + sname + " is not exported"
256 log.Print(s)
257 return errors.New(s)
258 }
259 s.name = sname
260
261
262 s.method = suitableMethods(s.typ, true)
263
264 if len(s.method) == 0 {
265 str := ""
266
267
268 method := suitableMethods(reflect.PtrTo(s.typ), false)
269 if len(method) != 0 {
270 str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
271 } else {
272 str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
273 }
274 log.Print(str)
275 return errors.New(str)
276 }
277
278 if _, dup := server.serviceMap.LoadOrStore(sname, s); dup {
279 return errors.New("rpc: service already defined: " + sname)
280 }
281 return nil
282 }
283
284
285
286 func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
287 methods := make(map[string]*methodType)
288 for m := 0; m < typ.NumMethod(); m++ {
289 method := typ.Method(m)
290 mtype := method.Type
291 mname := method.Name
292
293 if method.PkgPath != "" {
294 continue
295 }
296
297 if mtype.NumIn() != 3 {
298 if reportErr {
299 log.Printf("rpc.Register: method %q has %d input parameters; needs exactly three\n", mname, mtype.NumIn())
300 }
301 continue
302 }
303
304 argType := mtype.In(1)
305 if !isExportedOrBuiltinType(argType) {
306 if reportErr {
307 log.Printf("rpc.Register: argument type of method %q is not exported: %q\n", mname, argType)
308 }
309 continue
310 }
311
312 replyType := mtype.In(2)
313 if replyType.Kind() != reflect.Ptr {
314 if reportErr {
315 log.Printf("rpc.Register: reply type of method %q is not a pointer: %q\n", mname, replyType)
316 }
317 continue
318 }
319
320 if !isExportedOrBuiltinType(replyType) {
321 if reportErr {
322 log.Printf("rpc.Register: reply type of method %q is not exported: %q\n", mname, replyType)
323 }
324 continue
325 }
326
327 if mtype.NumOut() != 1 {
328 if reportErr {
329 log.Printf("rpc.Register: method %q has %d output parameters; needs exactly one\n", mname, mtype.NumOut())
330 }
331 continue
332 }
333
334 if returnType := mtype.Out(0); returnType != typeOfError {
335 if reportErr {
336 log.Printf("rpc.Register: return type of method %q is %q, must be error\n", mname, returnType)
337 }
338 continue
339 }
340 methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
341 }
342 return methods
343 }
344
345
346
347
// invalidRequest is the placeholder body written with a response that
// carries an error, so that a body is always encoded after the header.
var invalidRequest = struct{}{}
349
350 func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
351 resp := server.getResponse()
352
353 resp.ServiceMethod = req.ServiceMethod
354 if errmsg != "" {
355 resp.Error = errmsg
356 reply = invalidRequest
357 }
358 resp.Seq = req.Seq
359 sending.Lock()
360 err := codec.WriteResponse(resp, reply)
361 if debugLog && err != nil {
362 log.Println("rpc: writing response:", err)
363 }
364 sending.Unlock()
365 server.freeResponse(resp)
366 }
367
368 func (m *methodType) NumCalls() (n uint) {
369 m.Lock()
370 n = m.numCalls
371 m.Unlock()
372 return n
373 }
374
375 func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
376 if wg != nil {
377 defer wg.Done()
378 }
379 mtype.Lock()
380 mtype.numCalls++
381 mtype.Unlock()
382 function := mtype.method.Func
383
384 returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
385
386 errInter := returnValues[0].Interface()
387 errmsg := ""
388 if errInter != nil {
389 errmsg = errInter.(error).Error()
390 }
391 server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
392 server.freeRequest(req)
393 }
394
// gobServerCodec is the ServerCodec that ServeConn builds around a
// connection. Requests are gob-decoded straight from rwc; responses are
// gob-encoded into encBuf, which is flushed to rwc by WriteResponse.
type gobServerCodec struct {
	rwc    io.ReadWriteCloser // underlying connection; closed by Close
	dec    *gob.Decoder       // decodes request headers and bodies
	enc    *gob.Encoder       // encodes response headers and bodies into encBuf
	encBuf *bufio.Writer      // buffers encoded output until flushed
	closed bool               // set once Close has closed rwc
}
402
403 func (c *gobServerCodec) ReadRequestHeader(r *Request) error {
404 return c.dec.Decode(r)
405 }
406
407 func (c *gobServerCodec) ReadRequestBody(body interface{}) error {
408 return c.dec.Decode(body)
409 }
410
411 func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) (err error) {
412 if err = c.enc.Encode(r); err != nil {
413 if c.encBuf.Flush() == nil {
414
415
416 log.Println("rpc: gob error encoding response:", err)
417 c.Close()
418 }
419 return
420 }
421 if err = c.enc.Encode(body); err != nil {
422 if c.encBuf.Flush() == nil {
423
424
425 log.Println("rpc: gob error encoding body:", err)
426 c.Close()
427 }
428 return
429 }
430 return c.encBuf.Flush()
431 }
432
433 func (c *gobServerCodec) Close() error {
434 if c.closed {
435
436 return nil
437 }
438 c.closed = true
439 return c.rwc.Close()
440 }
441
442
443
444
445
446
447
448 func (server *Server) ServeConn(conn io.ReadWriteCloser) {
449 buf := bufio.NewWriter(conn)
450 srv := &gobServerCodec{
451 rwc: conn,
452 dec: gob.NewDecoder(conn),
453 enc: gob.NewEncoder(buf),
454 encBuf: buf,
455 }
456 server.ServeCodec(srv)
457 }
458
459
460
461 func (server *Server) ServeCodec(codec ServerCodec) {
462 sending := new(sync.Mutex)
463 wg := new(sync.WaitGroup)
464 for {
465 service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
466 if err != nil {
467 if debugLog && err != io.EOF {
468 log.Println("rpc:", err)
469 }
470 if !keepReading {
471 break
472 }
473
474 if req != nil {
475 server.sendResponse(sending, req, invalidRequest, codec, err.Error())
476 server.freeRequest(req)
477 }
478 continue
479 }
480 wg.Add(1)
481 go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
482 }
483
484
485 wg.Wait()
486 codec.Close()
487 }
488
489
490
491 func (server *Server) ServeRequest(codec ServerCodec) error {
492 sending := new(sync.Mutex)
493 service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
494 if err != nil {
495 if !keepReading {
496 return err
497 }
498
499 if req != nil {
500 server.sendResponse(sending, req, invalidRequest, codec, err.Error())
501 server.freeRequest(req)
502 }
503 return err
504 }
505 service.call(server, sending, nil, mtype, req, argv, replyv, codec)
506 return nil
507 }
508
509 func (server *Server) getRequest() *Request {
510 server.reqLock.Lock()
511 req := server.freeReq
512 if req == nil {
513 req = new(Request)
514 } else {
515 server.freeReq = req.next
516 *req = Request{}
517 }
518 server.reqLock.Unlock()
519 return req
520 }
521
522 func (server *Server) freeRequest(req *Request) {
523 server.reqLock.Lock()
524 req.next = server.freeReq
525 server.freeReq = req
526 server.reqLock.Unlock()
527 }
528
529 func (server *Server) getResponse() *Response {
530 server.respLock.Lock()
531 resp := server.freeResp
532 if resp == nil {
533 resp = new(Response)
534 } else {
535 server.freeResp = resp.next
536 *resp = Response{}
537 }
538 server.respLock.Unlock()
539 return resp
540 }
541
542 func (server *Server) freeResponse(resp *Response) {
543 server.respLock.Lock()
544 resp.next = server.freeResp
545 server.freeResp = resp
546 server.respLock.Unlock()
547 }
548
549 func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
550 service, mtype, req, keepReading, err = server.readRequestHeader(codec)
551 if err != nil {
552 if !keepReading {
553 return
554 }
555
556 codec.ReadRequestBody(nil)
557 return
558 }
559
560
561 argIsValue := false
562 if mtype.ArgType.Kind() == reflect.Ptr {
563 argv = reflect.New(mtype.ArgType.Elem())
564 } else {
565 argv = reflect.New(mtype.ArgType)
566 argIsValue = true
567 }
568
569 if err = codec.ReadRequestBody(argv.Interface()); err != nil {
570 return
571 }
572 if argIsValue {
573 argv = argv.Elem()
574 }
575
576 replyv = reflect.New(mtype.ReplyType.Elem())
577
578 switch mtype.ReplyType.Elem().Kind() {
579 case reflect.Map:
580 replyv.Elem().Set(reflect.MakeMap(mtype.ReplyType.Elem()))
581 case reflect.Slice:
582 replyv.Elem().Set(reflect.MakeSlice(mtype.ReplyType.Elem(), 0, 0))
583 }
584 return
585 }
586
587 func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) {
588
589 req = server.getRequest()
590 err = codec.ReadRequestHeader(req)
591 if err != nil {
592 req = nil
593 if err == io.EOF || err == io.ErrUnexpectedEOF {
594 return
595 }
596 err = errors.New("rpc: server cannot decode request: " + err.Error())
597 return
598 }
599
600
601
602 keepReading = true
603
604 dot := strings.LastIndex(req.ServiceMethod, ".")
605 if dot < 0 {
606 err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
607 return
608 }
609 serviceName := req.ServiceMethod[:dot]
610 methodName := req.ServiceMethod[dot+1:]
611
612
613 svci, ok := server.serviceMap.Load(serviceName)
614 if !ok {
615 err = errors.New("rpc: can't find service " + req.ServiceMethod)
616 return
617 }
618 svc = svci.(*service)
619 mtype = svc.method[methodName]
620 if mtype == nil {
621 err = errors.New("rpc: can't find method " + req.ServiceMethod)
622 }
623 return
624 }
625
626
627
628
629
630 func (server *Server) Accept(lis net.Listener) {
631 for {
632 conn, err := lis.Accept()
633 if err != nil {
634 log.Print("rpc.Serve: accept:", err.Error())
635 return
636 }
637 go server.ServeConn(conn)
638 }
639 }
640
641
642 func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) }
643
644
645
646 func RegisterName(name string, rcvr interface{}) error {
647 return DefaultServer.RegisterName(name, rcvr)
648 }
649
650
651
652
653
654
655
656
657
658 type ServerCodec interface {
659 ReadRequestHeader(*Request) error
660 ReadRequestBody(interface{}) error
661 WriteResponse(*Response, interface{}) error
662
663
664 Close() error
665 }
666
667
668
669
670
671
672
673 func ServeConn(conn io.ReadWriteCloser) {
674 DefaultServer.ServeConn(conn)
675 }
676
677
678
679 func ServeCodec(codec ServerCodec) {
680 DefaultServer.ServeCodec(codec)
681 }
682
683
684
685 func ServeRequest(codec ServerCodec) error {
686 return DefaultServer.ServeRequest(codec)
687 }
688
689
690
691
692 func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
693
694
// connected is the status text ServeHTTP writes on the hijacked connection,
// after the "HTTP/1.0 " prefix, to acknowledge a CONNECT request.
var connected = "200 Connected to Go RPC"
696
697
698 func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
699 if req.Method != "CONNECT" {
700 w.Header().Set("Content-Type", "text/plain; charset=utf-8")
701 w.WriteHeader(http.StatusMethodNotAllowed)
702 io.WriteString(w, "405 must CONNECT\n")
703 return
704 }
705 conn, _, err := w.(http.Hijacker).Hijack()
706 if err != nil {
707 log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
708 return
709 }
710 io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
711 server.ServeConn(conn)
712 }
713
714
715
716
717 func (server *Server) HandleHTTP(rpcPath, debugPath string) {
718 http.Handle(rpcPath, server)
719 http.Handle(debugPath, debugHTTP{server})
720 }
721
722
723
724
725 func HandleHTTP() {
726 DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
727 }
728
View as plain text