Source file src/net/rpc/jsonrpc/all_test.go

     1  // Copyright 2010 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package jsonrpc
     6  
     7  import (
     8  	"encoding/json"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"net"
    13  	"net/rpc"
    14  	"reflect"
    15  	"strings"
    16  	"testing"
    17  )
    18  
    19  type Args struct {
    20  	A, B int
    21  }
    22  
// Reply carries the integer result that the Arith service methods write.
type Reply struct {
	C int // result value set by Add, Mul and Div
}
    26  
// Arith is the receiver type of the arithmetic RPC methods used by the tests;
// a *Arith is registered with package rpc in init.
type Arith int
    28  
// ArithAddResp is the shape the server tests decode each raw JSON response
// into when they talk to the server without going through the client.
type ArithAddResp struct {
	Id     any   `json:"id"`     // id member of the response; compared against the request id in TestServer
	Result Reply `json:"result"` // result member of the response
	Error  any   `json:"error"`  // error member of the response; stays nil when it is JSON null or absent
}
    34  
    35  func (t *Arith) Add(args *Args, reply *Reply) error {
    36  	reply.C = args.A + args.B
    37  	return nil
    38  }
    39  
    40  func (t *Arith) Mul(args *Args, reply *Reply) error {
    41  	reply.C = args.A * args.B
    42  	return nil
    43  }
    44  
    45  func (t *Arith) Div(args *Args, reply *Reply) error {
    46  	if args.B == 0 {
    47  		return errors.New("divide by zero")
    48  	}
    49  	reply.C = args.A / args.B
    50  	return nil
    51  }
    52  
// Error always panics with the string "ERROR"; it never touches args or
// reply and never returns.
func (t *Arith) Error(args *Args, reply *Reply) error {
	panic("ERROR")
}
    56  
    57  type BuiltinTypes struct{}
    58  
    59  func (BuiltinTypes) Map(i int, reply *map[int]int) error {
    60  	(*reply)[i] = i
    61  	return nil
    62  }
    63  
    64  func (BuiltinTypes) Slice(i int, reply *[]int) error {
    65  	*reply = append(*reply, i)
    66  	return nil
    67  }
    68  
    69  func (BuiltinTypes) Array(i int, reply *[1]int) error {
    70  	(*reply)[0] = i
    71  	return nil
    72  }
    73  
    74  func init() {
    75  	rpc.Register(new(Arith))
    76  	rpc.Register(BuiltinTypes{})
    77  }
    78  
    79  func TestServerNoParams(t *testing.T) {
    80  	cli, srv := net.Pipe()
    81  	defer cli.Close()
    82  	go ServeConn(srv)
    83  	dec := json.NewDecoder(cli)
    84  
    85  	fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "123"}`)
    86  	var resp ArithAddResp
    87  	if err := dec.Decode(&resp); err != nil {
    88  		t.Fatalf("Decode after no params: %s", err)
    89  	}
    90  	if resp.Error == nil {
    91  		t.Fatalf("Expected error, got nil")
    92  	}
    93  }
    94  
    95  func TestServerEmptyMessage(t *testing.T) {
    96  	cli, srv := net.Pipe()
    97  	defer cli.Close()
    98  	go ServeConn(srv)
    99  	dec := json.NewDecoder(cli)
   100  
   101  	fmt.Fprintf(cli, "{}")
   102  	var resp ArithAddResp
   103  	if err := dec.Decode(&resp); err != nil {
   104  		t.Fatalf("Decode after empty: %s", err)
   105  	}
   106  	if resp.Error == nil {
   107  		t.Fatalf("Expected error, got nil")
   108  	}
   109  }
   110  
   111  func TestServer(t *testing.T) {
   112  	cli, srv := net.Pipe()
   113  	defer cli.Close()
   114  	go ServeConn(srv)
   115  	dec := json.NewDecoder(cli)
   116  
   117  	// Send hand-coded requests to server, parse responses.
   118  	for i := 0; i < 10; i++ {
   119  		fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "\u%04d", "params": [{"A": %d, "B": %d}]}`, i, i, i+1)
   120  		var resp ArithAddResp
   121  		err := dec.Decode(&resp)
   122  		if err != nil {
   123  			t.Fatalf("Decode: %s", err)
   124  		}
   125  		if resp.Error != nil {
   126  			t.Fatalf("resp.Error: %s", resp.Error)
   127  		}
   128  		if resp.Id.(string) != string(rune(i)) {
   129  			t.Fatalf("resp: bad id %q want %q", resp.Id.(string), string(rune(i)))
   130  		}
   131  		if resp.Result.C != 2*i+1 {
   132  			t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C)
   133  		}
   134  	}
   135  }
   136  
   137  func TestClient(t *testing.T) {
   138  	// Assume server is okay (TestServer is above).
   139  	// Test client against server.
   140  	cli, srv := net.Pipe()
   141  	go ServeConn(srv)
   142  
   143  	client := NewClient(cli)
   144  	defer client.Close()
   145  
   146  	// Synchronous calls
   147  	args := &Args{7, 8}
   148  	reply := new(Reply)
   149  	err := client.Call("Arith.Add", args, reply)
   150  	if err != nil {
   151  		t.Errorf("Add: expected no error but got string %q", err.Error())
   152  	}
   153  	if reply.C != args.A+args.B {
   154  		t.Errorf("Add: got %d expected %d", reply.C, args.A+args.B)
   155  	}
   156  
   157  	args = &Args{7, 8}
   158  	reply = new(Reply)
   159  	err = client.Call("Arith.Mul", args, reply)
   160  	if err != nil {
   161  		t.Errorf("Mul: expected no error but got string %q", err.Error())
   162  	}
   163  	if reply.C != args.A*args.B {
   164  		t.Errorf("Mul: got %d expected %d", reply.C, args.A*args.B)
   165  	}
   166  
   167  	// Out of order.
   168  	args = &Args{7, 8}
   169  	mulReply := new(Reply)
   170  	mulCall := client.Go("Arith.Mul", args, mulReply, nil)
   171  	addReply := new(Reply)
   172  	addCall := client.Go("Arith.Add", args, addReply, nil)
   173  
   174  	addCall = <-addCall.Done
   175  	if addCall.Error != nil {
   176  		t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
   177  	}
   178  	if addReply.C != args.A+args.B {
   179  		t.Errorf("Add: got %d expected %d", addReply.C, args.A+args.B)
   180  	}
   181  
   182  	mulCall = <-mulCall.Done
   183  	if mulCall.Error != nil {
   184  		t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
   185  	}
   186  	if mulReply.C != args.A*args.B {
   187  		t.Errorf("Mul: got %d expected %d", mulReply.C, args.A*args.B)
   188  	}
   189  
   190  	// Error test
   191  	args = &Args{7, 0}
   192  	reply = new(Reply)
   193  	err = client.Call("Arith.Div", args, reply)
   194  	// expect an error: zero divide
   195  	if err == nil {
   196  		t.Error("Div: expected error")
   197  	} else if err.Error() != "divide by zero" {
   198  		t.Error("Div: expected divide by zero error; got", err)
   199  	}
   200  }
   201  
   202  func TestBuiltinTypes(t *testing.T) {
   203  	cli, srv := net.Pipe()
   204  	go ServeConn(srv)
   205  
   206  	client := NewClient(cli)
   207  	defer client.Close()
   208  
   209  	// Map
   210  	arg := 7
   211  	replyMap := map[int]int{}
   212  	err := client.Call("BuiltinTypes.Map", arg, &replyMap)
   213  	if err != nil {
   214  		t.Errorf("Map: expected no error but got string %q", err.Error())
   215  	}
   216  	if replyMap[arg] != arg {
   217  		t.Errorf("Map: expected %d got %d", arg, replyMap[arg])
   218  	}
   219  
   220  	// Slice
   221  	replySlice := []int{}
   222  	err = client.Call("BuiltinTypes.Slice", arg, &replySlice)
   223  	if err != nil {
   224  		t.Errorf("Slice: expected no error but got string %q", err.Error())
   225  	}
   226  	if e := []int{arg}; !reflect.DeepEqual(replySlice, e) {
   227  		t.Errorf("Slice: expected %v got %v", e, replySlice)
   228  	}
   229  
   230  	// Array
   231  	replyArray := [1]int{}
   232  	err = client.Call("BuiltinTypes.Array", arg, &replyArray)
   233  	if err != nil {
   234  		t.Errorf("Array: expected no error but got string %q", err.Error())
   235  	}
   236  	if e := [1]int{arg}; !reflect.DeepEqual(replyArray, e) {
   237  		t.Errorf("Array: expected %v got %v", e, replyArray)
   238  	}
   239  }
   240  
   241  func TestMalformedInput(t *testing.T) {
   242  	cli, srv := net.Pipe()
   243  	go cli.Write([]byte(`{id:1}`)) // invalid json
   244  	ServeConn(srv)                 // must return, not loop
   245  }
   246  
   247  func TestMalformedOutput(t *testing.T) {
   248  	cli, srv := net.Pipe()
   249  	go srv.Write([]byte(`{"id":0,"result":null,"error":null}`))
   250  	go io.ReadAll(srv)
   251  
   252  	client := NewClient(cli)
   253  	defer client.Close()
   254  
   255  	args := &Args{7, 8}
   256  	reply := new(Reply)
   257  	err := client.Call("Arith.Add", args, reply)
   258  	if err == nil {
   259  		t.Error("expected error")
   260  	}
   261  }
   262  
   263  func TestServerErrorHasNullResult(t *testing.T) {
   264  	var out strings.Builder
   265  	sc := NewServerCodec(struct {
   266  		io.Reader
   267  		io.Writer
   268  		io.Closer
   269  	}{
   270  		Reader: strings.NewReader(`{"method": "Arith.Add", "id": "123", "params": []}`),
   271  		Writer: &out,
   272  		Closer: io.NopCloser(nil),
   273  	})
   274  	r := new(rpc.Request)
   275  	if err := sc.ReadRequestHeader(r); err != nil {
   276  		t.Fatal(err)
   277  	}
   278  	const valueText = "the value we don't want to see"
   279  	const errorText = "some error"
   280  	err := sc.WriteResponse(&rpc.Response{
   281  		ServiceMethod: "Method",
   282  		Seq:           1,
   283  		Error:         errorText,
   284  	}, valueText)
   285  	if err != nil {
   286  		t.Fatal(err)
   287  	}
   288  	if !strings.Contains(out.String(), errorText) {
   289  		t.Fatalf("Response didn't contain expected error %q: %s", errorText, &out)
   290  	}
   291  	if strings.Contains(out.String(), valueText) {
   292  		t.Errorf("Response contains both an error and value: %s", &out)
   293  	}
   294  }
   295  
   296  func TestUnexpectedError(t *testing.T) {
   297  	cli, srv := myPipe()
   298  	go cli.PipeWriter.CloseWithError(errors.New("unexpected error!")) // reader will get this error
   299  	ServeConn(srv)                                                    // must return, not loop
   300  }
   301  
   302  // Copied from package net.
   303  func myPipe() (*pipe, *pipe) {
   304  	r1, w1 := io.Pipe()
   305  	r2, w2 := io.Pipe()
   306  
   307  	return &pipe{r1, w2}, &pipe{r2, w1}
   308  }
   309  
// pipe is one endpoint of the connection built by myPipe: it embeds the
// read end of one io.Pipe and the write end of another, so it gets their
// Read and Write methods by promotion.
type pipe struct {
	*io.PipeReader
	*io.PipeWriter
}
   314  
// pipeAddr is the placeholder address type returned by pipe's LocalAddr and
// RemoteAddr; its Network and String methods both report "pipe".
type pipeAddr int

// Network returns the fixed network name "pipe".
func (pipeAddr) Network() string {
	return "pipe"
}

// String returns the fixed address text "pipe".
func (pipeAddr) String() string {
	return "pipe"
}
   324  
   325  func (p *pipe) Close() error {
   326  	err := p.PipeReader.Close()
   327  	err1 := p.PipeWriter.Close()
   328  	if err == nil {
   329  		err = err1
   330  	}
   331  	return err
   332  }
   333  
// LocalAddr returns the placeholder address pipeAddr(0).
func (p *pipe) LocalAddr() net.Addr {
	return pipeAddr(0)
}
   337  
// RemoteAddr returns the placeholder address pipeAddr(0), the same value
// that LocalAddr returns.
func (p *pipe) RemoteAddr() net.Addr {
	return pipeAddr(0)
}
   341  
// SetTimeout ignores nsec and always returns an error saying that timeouts
// are not supported.
func (p *pipe) SetTimeout(nsec int64) error {
	return errors.New("net.Pipe does not support timeouts")
}
   345  
// SetReadTimeout ignores nsec and always returns an error saying that
// timeouts are not supported.
func (p *pipe) SetReadTimeout(nsec int64) error {
	return errors.New("net.Pipe does not support timeouts")
}
   349  
// SetWriteTimeout ignores nsec and always returns an error saying that
// timeouts are not supported.
func (p *pipe) SetWriteTimeout(nsec int64) error {
	return errors.New("net.Pipe does not support timeouts")
}
   353  

View as plain text