1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 #include "go.h"
6
enum
{
	// switch kinds: the arg passed to mkcaselist and exprbsw
	Snorm = 0,	// ordinary expression switch
	Strue,		// tag is the constant true (walkswitch also uses this for a missing tag)
	Sfalse,		// tag is the constant false
	Stype,		// type switch

	// classifications stored in Case.type
	Tdefault,	// default case
	Texprconst,	// normal constant case
	Texprvar,	// normal variable case
	Ttypenil,	// case nil
	Ttypeconst,	// type hashes
	Ttypevar,	// interface type

	// exprbsw/typebsw test fewer than Ncase entries linearly and
	// binary-search larger runs
	Ncase = 4,	// count needed to split
};
23
/*
 * One Case is allocated by mkcaselist for every OCASE node of a switch.
 * Cases form a singly linked list (via link) that is reordered with csort.
 */
typedef struct Case Case;
struct Case
{
	Node*	node;		// points at case statement (the OCASE node)
	uint32	hash;		// hash of a type switch: typehash of the case type for Ttypeconst
	uint8	type;		// type of case: one of Tdefault .. Ttypevar
	uint8	diag;		// suppress multiple diagnostics (NOTE(review): not referenced in the visible code)
	uint16	ordinal;	// position in switch, 1-based (NOTE(review): assigned from an int counter; wraps past 65535 cases)
	Case*	link;		// linked list to link
};
#define	C	((Case*)nil)	// nil Case pointer
35
36 void
37 dumpcase(Case *c0)
38 {
39 Case *c;
40
41 for(c=c0; c!=C; c=c->link) {
42 switch(c->type) {
43 case Tdefault:
44 print("case-default\n");
45 print(" ord=%d\n", c->ordinal);
46 break;
47 case Texprconst:
48 print("case-exprconst\n");
49 print(" ord=%d\n", c->ordinal);
50 break;
51 case Texprvar:
52 print("case-exprvar\n");
53 print(" ord=%d\n", c->ordinal);
54 print(" op=%O\n", c->node->left->op);
55 break;
56 case Ttypenil:
57 print("case-typenil\n");
58 print(" ord=%d\n", c->ordinal);
59 break;
60 case Ttypeconst:
61 print("case-typeconst\n");
62 print(" ord=%d\n", c->ordinal);
63 print(" hash=%ux\n", c->hash);
64 break;
65 case Ttypevar:
66 print("case-typevar\n");
67 print(" ord=%d\n", c->ordinal);
68 break;
69 default:
70 print("case-???\n");
71 print(" ord=%d\n", c->ordinal);
72 print(" op=%O\n", c->node->left->op);
73 print(" hash=%ux\n", c->hash);
74 break;
75 }
76 }
77 print("\n");
78 }
79
80 static int
81 ordlcmp(Case *c1, Case *c2)
82 {
83 // sort default first
84 if(c1->type == Tdefault)
85 return -1;
86 if(c2->type == Tdefault)
87 return +1;
88
89 // sort nil second
90 if(c1->type == Ttypenil)
91 return -1;
92 if(c2->type == Ttypenil)
93 return +1;
94
95 // sort by ordinal
96 if(c1->ordinal > c2->ordinal)
97 return +1;
98 if(c1->ordinal < c2->ordinal)
99 return -1;
100 return 0;
101 }
102
103 static int
104 exprcmp(Case *c1, Case *c2)
105 {
106 int ct, n;
107 Node *n1, *n2;
108
109 // sort non-constants last
110 if(c1->type != Texprconst)
111 return +1;
112 if(c2->type != Texprconst)
113 return -1;
114
115 n1 = c1->node->left;
116 n2 = c2->node->left;
117
118 ct = n1->val.ctype;
119 if(ct != n2->val.ctype) {
120 // invalid program, but return a sort
121 // order so that we can give a better
122 // error later.
123 return ct - n2->val.ctype;
124 }
125
126 // sort by constant value
127 n = 0;
128 switch(ct) {
129 case CTFLT:
130 n = mpcmpfltflt(n1->val.u.fval, n2->val.u.fval);
131 break;
132 case CTINT:
133 n = mpcmpfixfix(n1->val.u.xval, n2->val.u.xval);
134 break;
135 case CTSTR:
136 n = cmpslit(n1, n2);
137 break;
138 }
139
140 return n;
141 }
142
143 static int
144 typecmp(Case *c1, Case *c2)
145 {
146
147 // sort non-constants last
148 if(c1->type != Ttypeconst)
149 return +1;
150 if(c2->type != Ttypeconst)
151 return -1;
152
153 // sort by hash code
154 if(c1->hash > c2->hash)
155 return +1;
156 if(c1->hash < c2->hash)
157 return -1;
158
159 // sort by ordinal so duplicate error
160 // happens on later case.
161 if(c1->ordinal > c2->ordinal)
162 return +1;
163 if(c1->ordinal < c2->ordinal)
164 return -1;
165 return 0;
166 }
167
168 static Case*
169 csort(Case *l, int(*f)(Case*, Case*))
170 {
171 Case *l1, *l2, *le;
172
173 if(l == C || l->link == C)
174 return l;
175
176 l1 = l;
177 l2 = l;
178 for(;;) {
179 l2 = l2->link;
180 if(l2 == C)
181 break;
182 l2 = l2->link;
183 if(l2 == C)
184 break;
185 l1 = l1->link;
186 }
187
188 l2 = l1->link;
189 l1->link = C;
190 l1 = csort(l, f);
191 l2 = csort(l2, f);
192
193 /* set up lead element */
194 if((*f)(l1, l2) < 0) {
195 l = l1;
196 l1 = l1->link;
197 } else {
198 l = l2;
199 l2 = l2->link;
200 }
201 le = l;
202
203 for(;;) {
204 if(l1 == C) {
205 while(l2) {
206 le->link = l2;
207 le = l2;
208 l2 = l2->link;
209 }
210 le->link = C;
211 break;
212 }
213 if(l2 == C) {
214 while(l1) {
215 le->link = l1;
216 le = l1;
217 l1 = l1->link;
218 }
219 break;
220 }
221 if((*f)(l1, l2) < 0) {
222 le->link = l1;
223 le = l1;
224 l1 = l1->link;
225 } else {
226 le->link = l2;
227 le = l2;
228 l2 = l2->link;
229 }
230 }
231 le->link = C;
232 return l;
233 }
234
235 static Node*
236 newlabel(void)
237 {
238 static int label;
239
240 label++;
241 snprint(namebuf, sizeof(namebuf), "%.6d", label);
242 return newname(lookup(namebuf));
243 }
244
245 /*
246 * build separate list of statements and cases
247 * make labels between cases and statements
248 * deal with fallthrough, break, unreachable statements
249 */
250 static void
251 casebody(Node *sw, Node *typeswvar)
252 {
253 Node *n, *c, *last;
254 Node *def;
255 NodeList *cas, *stat, *l, *lc;
256 Node *go, *br;
257 int32 lno, needvar;
258
259 lno = setlineno(sw);
260 if(sw->list == nil)
261 return;
262
263 cas = nil; // cases
264 stat = nil; // statements
265 def = N; // defaults
266 br = nod(OBREAK, N, N);
267
268 for(l=sw->list; l; l=l->next) {
269 n = l->n;
270 lno = setlineno(n);
271 if(n->op != OXCASE)
272 fatal("casebody %O", n->op);
273 n->op = OCASE;
274 needvar = count(n->list) != 1 || n->list->n->op == OLITERAL;
275
276 go = nod(OGOTO, newlabel(), N);
277 if(n->list == nil) {
278 if(def != N)
279 yyerror("more than one default case");
280 // reuse original default case
281 n->right = go;
282 def = n;
283 }
284
285 if(n->list != nil && n->list->next == nil) {
286 // one case - reuse OCASE node.
287 c = n->list->n;
288 n->left = c;
289 n->right = go;
290 n->list = nil;
291 cas = list(cas, n);
292 } else {
293 // expand multi-valued cases
294 for(lc=n->list; lc; lc=lc->next) {
295 c = lc->n;
296 cas = list(cas, nod(OCASE, c, go));
297 }
298 }
299
300 stat = list(stat, nod(OLABEL, go->left, N));
301 if(typeswvar && needvar && n->nname != N) {
302 NodeList *l;
303
304 l = list1(nod(ODCL, n->nname, N));
305 l = list(l, nod(OAS, n->nname, typeswvar));
306 typechecklist(l, Etop);
307 stat = concat(stat, l);
308 }
309 stat = concat(stat, n->nbody);
310
311 // botch - shouldnt fall thru declaration
312 last = stat->end->n;
313 if(last->op == OXFALL) {
314 if(typeswvar) {
315 setlineno(last);
316 yyerror("cannot fallthrough in type switch");
317 }
318 last->op = OFALL;
319 } else
320 stat = list(stat, br);
321 }
322
323 stat = list(stat, br);
324 if(def)
325 cas = list(cas, def);
326
327 sw->list = cas;
328 sw->nbody = stat;
329 lineno = lno;
330 }
331
/*
 * mkcaselist builds one Case for every OCASE node in sw->list (as
 * rewritten by casebody) and classifies it for switch kind arg
 * (Snorm/Strue/Sfalse or Stype).  It reports duplicate cases, then
 * returns the list sorted back into processing order with ordlcmp:
 * default first, then the nil case, then the rest by ordinal.
 * Returns C when sw->list is empty.
 */
static Case*
mkcaselist(Node *sw, int arg)
{
	Node *n;
	Case *c, *c1, *c2;
	NodeList *l;
	int ord;

	c = C;
	ord = 0;

	for(l=sw->list; l; l=l->next) {
		n = l->n;
		// entries are prepended, so the list is built in reverse;
		// the final ordlcmp sort restores source order
		c1 = mal(sizeof(*c1));
		c1->link = c;
		c = c1;

		ord++;
		c->ordinal = ord;
		c->node = n;

		if(n->left == N) {
			c->type = Tdefault;
			continue;
		}

		switch(arg) {
		case Stype:
			c->hash = 0;
			// "case nil"
			if(n->left->op == OLITERAL) {
				c->type = Ttypenil;
				continue;
			}
			// interface types cannot be matched by hash alone
			if(istype(n->left->type, TINTER)) {
				c->type = Ttypevar;
				continue;
			}

			c->hash = typehash(n->left->type);
			c->type = Ttypeconst;
			continue;

		case Snorm:
		case Strue:
		case Sfalse:
			// only float, integer and string constants are binary-searchable
			c->type = Texprvar;
			switch(consttype(n->left)) {
			case CTFLT:
			case CTINT:
			case CTSTR:
				c->type = Texprconst;
			}
			continue;
		}
	}

	if(c == C)
		return C;

	// sort by value and diagnose duplicate cases
	switch(arg) {
	case Stype:
		// compare each entry only with the following entries that have
		// the same hash; equal types are reported at the later case
		c = csort(c, typecmp);
		for(c1=c; c1!=C; c1=c1->link) {
			for(c2=c1->link; c2!=C && c2->hash==c1->hash; c2=c2->link) {
				if(c1->type == Ttypenil || c1->type == Tdefault)
					break;
				if(c2->type == Ttypenil || c2->type == Tdefault)
					break;
				if(!eqtype(c1->node->left->type, c2->node->left->type))
					continue;
				yyerrorl(c2->node->lineno, "duplicate case in switch\n\tprevious case at %L", c1->node->lineno);
			}
		}
		break;
	case Snorm:
	case Strue:
	case Sfalse:
		// after sorting, equal constants are adjacent
		// NOTE(review): setlineno here is not undone; lineno is left at
		// the last duplicate's line.
		c = csort(c, exprcmp);
		for(c1=c; c1->link!=C; c1=c1->link) {
			if(exprcmp(c1, c1->link) != 0)
				continue;
			setlineno(c1->link->node);
			yyerror("duplicate case in switch\n\tprevious case at %L", c1->node->lineno);
		}
		break;
	}

	// put list back in processing order
	c = csort(c, ordlcmp);
	return c;
}
424
425 static Node* exprname;
426
427 static Node*
428 exprbsw(Case *c0, int ncase, int arg)
429 {
430 NodeList *cas;
431 Node *a, *n;
432 Case *c;
433 int i, half, lno;
434
435 cas = nil;
436 if(ncase < Ncase) {
437 for(i=0; i<ncase; i++) {
438 n = c0->node;
439 lno = setlineno(n);
440
441 switch(arg) {
442 case Strue:
443 a = nod(OIF, N, N);
444 a->ntest = n->left; // if val
445 a->nbody = list1(n->right); // then goto l
446 break;
447
448 case Sfalse:
449 a = nod(OIF, N, N);
450 a->ntest = nod(ONOT, n->left, N); // if !val
451 typecheck(&a->ntest, Erv);
452 a->nbody = list1(n->right); // then goto l
453 break;
454
455 default:
456 a = nod(OIF, N, N);
457 a->ntest = nod(OEQ, exprname, n->left); // if name == val
458 typecheck(&a->ntest, Erv);
459 a->nbody = list1(n->right); // then goto l
460 break;
461 }
462
463 cas = list(cas, a);
464 c0 = c0->link;
465 lineno = lno;
466 }
467 return liststmt(cas);
468 }
469
470 // find the middle and recur
471 c = c0;
472 half = ncase>>1;
473 for(i=1; i<half; i++)
474 c = c->link;
475 a = nod(OIF, N, N);
476 a->ntest = nod(OLE, exprname, c->node->left);
477 typecheck(&a->ntest, Erv);
478 a->nbody = list1(exprbsw(c0, half, arg));
479 a->nelse = list1(exprbsw(c->link, ncase-half, arg));
480 return a;
481 }
482
483 /*
484 * normal (expression) switch.
485 * rebulid case statements into if .. goto
486 */
487 static void
488 exprswitch(Node *sw)
489 {
490 Node *def;
491 NodeList *cas;
492 Node *a;
493 Case *c0, *c, *c1;
494 Type *t;
495 int arg, ncase;
496
497 casebody(sw, N);
498
499 arg = Snorm;
500 if(isconst(sw->ntest, CTBOOL)) {
501 arg = Strue;
502 if(sw->ntest->val.u.bval == 0)
503 arg = Sfalse;
504 }
505 walkexpr(&sw->ntest, &sw->ninit);
506 t = sw->type;
507 if(t == T)
508 return;
509
510 /*
511 * convert the switch into OIF statements
512 */
513 exprname = N;
514 cas = nil;
515 if(arg != Strue && arg != Sfalse) {
516 exprname = nod(OXXX, N, N);
517 tempname(exprname, sw->ntest->type);
518 cas = list1(nod(OAS, exprname, sw->ntest));
519 typechecklist(cas, Etop);
520 }
521
522 c0 = mkcaselist(sw, arg);
523 if(c0 != C && c0->type == Tdefault) {
524 def = c0->node->right;
525 c0 = c0->link;
526 } else {
527 def = nod(OBREAK, N, N);
528 }
529
530 loop:
531 if(c0 == C) {
532 cas = list(cas, def);
533 sw->nbody = concat(cas, sw->nbody);
534 sw->list = nil;
535 walkstmtlist(sw->nbody);
536 return;
537 }
538
539 // deal with the variables one-at-a-time
540 if(c0->type != Texprconst) {
541 a = exprbsw(c0, 1, arg);
542 cas = list(cas, a);
543 c0 = c0->link;
544 goto loop;
545 }
546
547 // do binary search on run of constants
548 ncase = 1;
549 for(c=c0; c->link!=C; c=c->link) {
550 if(c->link->type != Texprconst)
551 break;
552 ncase++;
553 }
554
555 // break the chain at the count
556 c1 = c->link;
557 c->link = C;
558
559 // sort and compile constants
560 c0 = csort(c0, exprcmp);
561 a = exprbsw(c0, ncase, arg);
562 cas = list(cas, a);
563
564 c0 = c1;
565 goto loop;
566
567 }
568
569 static Node* hashname;
570 static Node* facename;
571 static Node* boolname;
572
573 static Node*
574 typeone(Node *t)
575 {
576 NodeList *init;
577 Node *a, *b, *var;
578
579 var = t->nname;
580 init = nil;
581 if(var == N) {
582 typecheck(&nblank, Erv | Easgn);
583 var = nblank;
584 } else
585 init = list1(nod(ODCL, var, N));
586
587 a = nod(OAS2, N, N);
588 a->list = list(list1(var), boolname); // var,bool =
589 b = nod(ODOTTYPE, facename, N);
590 b->type = t->left->type; // interface.(type)
591 a->rlist = list1(b);
592 typecheck(&a, Etop);
593 init = list(init, a);
594
595 b = nod(OIF, N, N);
596 b->ntest = boolname;
597 b->nbody = list1(t->right); // if bool { goto l }
598 a = liststmt(list(init, b));
599 return a;
600 }
601
602 static Node*
603 typebsw(Case *c0, int ncase)
604 {
605 NodeList *cas;
606 Node *a, *n;
607 Case *c;
608 int i, half;
609
610 cas = nil;
611
612 if(ncase < Ncase) {
613 for(i=0; i<ncase; i++) {
614 n = c0->node;
615 if(c0->type != Ttypeconst)
616 fatal("typebsw");
617 a = nod(OIF, N, N);
618 a->ntest = nod(OEQ, hashname, nodintconst(c0->hash));
619 typecheck(&a->ntest, Erv);
620 a->nbody = list1(n->right);
621 cas = list(cas, a);
622 c0 = c0->link;
623 }
624 return liststmt(cas);
625 }
626
627 // find the middle and recur
628 c = c0;
629 half = ncase>>1;
630 for(i=1; i<half; i++)
631 c = c->link;
632 a = nod(OIF, N, N);
633 a->ntest = nod(OLE, hashname, nodintconst(c->hash));
634 typecheck(&a->ntest, Erv);
635 a->nbody = list1(typebsw(c0, half));
636 a->nelse = list1(typebsw(c->link, ncase-half));
637 return a;
638 }
639
/*
 * convert switch of the form
 * switch v := i.(type) { case t1: ..; case t2: ..; }
 * into if statements
 *
 * Generated code, roughly:
 *	facename = i
 *	hashname = ifacethash(facename)	(efacethash when i's type is a nil interface)
 *	tests for each case in processing order; runs of consecutive
 *	concrete-type cases are sorted and binary-searched on hashname
 *	default goto, or break when there is no default
 * The result is prepended to sw->nbody only when nerrors == 0.
 */
static void
typeswitch(Node *sw)
{
	Node *def;
	NodeList *cas, *hash;
	Node *a, *n;
	Case *c, *c0, *c1;
	int ncase;
	Type *t;
	Val v;

	if(sw->ntest == nil)
		return;
	if(sw->ntest->right == nil) {
		setlineno(sw);
		yyerror("type switch must have an assignment");
		return;
	}
	walkexpr(&sw->ntest->right, &sw->ninit);
	if(!istype(sw->ntest->right->type, TINTER)) {
		yyerror("type switch must be on an interface");
		return;
	}
	cas = nil;

	/*
	 * predeclare temporary variables
	 * and the boolean var
	 */
	facename = nod(OXXX, N, N);
	tempname(facename, sw->ntest->right->type);
	a = nod(OAS, facename, sw->ntest->right);
	typecheck(&a, Etop);
	cas = list(cas, a);

	casebody(sw, facename);

	boolname = nod(OXXX, N, N);
	tempname(boolname, types[TBOOL]);
	typecheck(&boolname, Erv);

	hashname = nod(OXXX, N, N);
	tempname(hashname, types[TUINT32]);
	typecheck(&hashname, Erv);

	// hashname = [ie]facethash(facename)
	t = sw->ntest->right->type;
	if(isnilinter(t))
		a = syslook("efacethash", 1);
	else
		a = syslook("ifacethash", 1);
	argtype(a, t);
	a = nod(OCALL, a, N);
	a->list = list1(facename);
	a = nod(OAS, hashname, a);
	typecheck(&a, Etop);
	cas = list(cas, a);

	// the default case, if any, sorts first; peel it off
	c0 = mkcaselist(sw, Stype);
	if(c0 != C && c0->type == Tdefault) {
		def = c0->node->right;
		c0 = c0->link;
	} else {
		def = nod(OBREAK, N, N);
	}

	/*
	 * insert if statement into each case block
	 */
	for(c=c0; c!=C; c=c->link) {
		n = c->node;
		switch(c->type) {

		case Ttypenil:
			// NOTE(review): only v.ctype is set; the rest of v is left uninitialized
			v.ctype = CTNIL;
			a = nod(OIF, N, N);
			a->ntest = nod(OEQ, facename, nodlit(v));
			typecheck(&a->ntest, Erv);
			a->nbody = list1(n->right); // if i==nil { goto l }
			n->right = a;
			break;

		case Ttypevar:
		case Ttypeconst:
			n->right = typeone(n);
			break;
		}
	}

	/*
	 * generate list of if statements, binary search for constant sequences
	 */
	while(c0 != C) {
		if(c0->type != Ttypeconst) {
			n = c0->node;
			cas = list(cas, n->right);
			c0=c0->link;
			continue;
		}

		// identify run of constants
		c1 = c = c0;
		while(c->link!=C && c->link->type==Ttypeconst)
			c = c->link;
		c0 = c->link;
		c->link = nil;

		// sort by hash
		c1 = csort(c1, typecmp);

		// for debugging: linear search
		if(0) {
			for(c=c1; c!=C; c=c->link) {
				n = c->node;
				cas = list(cas, n->right);
			}
			continue;
		}

		// combine adjacent cases with the same hash; each combined
		// block runs its typeone tests in turn, so the type assertion
		// decides among types whose hashes collide
		ncase = 0;
		for(c=c1; c!=C; c=c->link) {
			ncase++;
			hash = list1(c->node->right);
			while(c->link != C && c->link->hash == c->hash) {
				hash = list(hash, c->link->node->right);
				c->link = c->link->link;
			}
			c->node->right = liststmt(hash);
		}

		// binary search among cases to narrow by hash
		cas = list(cas, typebsw(c1, ncase));
	}
	if(nerrors == 0) {
		cas = list(cas, def);
		sw->nbody = concat(cas, sw->nbody);
		sw->list = nil;
		walkstmtlist(sw->nbody);
	}
}
785
786 void
787 walkswitch(Node *sw)
788 {
789
790 /*
791 * reorder the body into (OLIST, cases, statements)
792 * cases have OGOTO into statements.
793 * both have inserted OBREAK statements
794 */
795 walkstmtlist(sw->ninit);
796 if(sw->ntest == N) {
797 sw->ntest = nodbool(1);
798 typecheck(&sw->ntest, Erv);
799 }
800
801 if(sw->ntest->op == OTYPESW) {
802 typeswitch(sw);
803 //dump("sw", sw);
804 return;
805 }
806 exprswitch(sw);
807 }
808
/*
 * type check switch statement
 *
 * Type checks n->ninit and the tag, then stores the tag type in n->type.
 * For a type switch this is the type of the operand of .(type); for a
 * tagless switch it is bool.  Each case value is checked against it and
 * multiple defaults are reported.  A type-switch clause variable (nname)
 * gets its declared type, and each clause body is type checked.  lineno
 * is restored before returning.
 */
void
typecheckswitch(Node *n)
{
	int top, lno;
	Type *t;
	NodeList *l, *ll;
	Node *ncase, *nvar;
	Node *def;

	lno = lineno;
	typechecklist(n->ninit, Etop);

	if(n->ntest != N && n->ntest->op == OTYPESW) {
		// type switch
		top = Etype;
		typecheck(&n->ntest->right, Erv);
		t = n->ntest->right->type;
		if(t != T && t->etype != TINTER)
			yyerror("cannot type switch on non-interface value %+N", n->ntest->right);
	} else {
		// value switch
		top = Erv;
		if(n->ntest) {
			typecheck(&n->ntest, Erv);
			defaultlit(&n->ntest, T);
			t = n->ntest->type;
		} else
			t = types[TBOOL];
	}
	n->type = t;

	def = N;
	for(l=n->list; l; l=l->next) {
		ncase = l->n;
		// NOTE(review): this uses the switch's line, so the
		// "multiple defaults" error is reported at the switch,
		// not at the offending clause
		setlineno(n);
		if(ncase->list == nil) {
			// default
			if(def != N)
				yyerror("multiple defaults in switch (first at %L)", def->lineno);
			else
				def = ncase;
		} else {
			for(ll=ncase->list; ll; ll=ll->next) {
				setlineno(ll->n);
				typecheck(&ll->n, Erv | Etype);
				// skip values whose type (or the tag's type) is unknown
				if(ll->n->type == T || t == T)
					continue;
				switch(top) {
				case Erv: // expression switch
					defaultlit(&ll->n, t);
					if(ll->n->op == OTYPE)
						yyerror("type %T is not an expression", ll->n->type);
					else if(ll->n->type != T && !eqtype(ll->n->type, t))
						yyerror("case %+N in %T switch", ll->n, t);
					break;
				case Etype: // type switch
					// "case nil" is allowed
					if(ll->n->op == OLITERAL && istype(ll->n->type, TNIL))
						;
					else if(ll->n->op != OTYPE && ll->n->type != T) {
						yyerror("%#N is not a type", ll->n);
						// reset to original type
						ll->n = n->ntest->right;
					}
					break;
				}
			}
		}
		if(top == Etype && n->type != T) {
			ll = ncase->list;
			nvar = ncase->nname;
			if(nvar != N) {
				if(ll && ll->next == nil && ll->n->type != T && !istype(ll->n->type, TNIL)) {
					// single entry type switch
					nvar->ntype = typenod(ll->n->type);
				} else {
					// multiple entry type switch or default
					nvar->ntype = typenod(n->type);
				}
			}
		}
		typechecklist(ncase->nbody, Etop);
	}

	lineno = lno;
}