...
Run Format

Source file src/syscall/mksyscall_windows.go

Documentation: syscall

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build ignore
     6  
     7  /*
     8  mksyscall_windows generates windows system call bodies
     9  
    10  It parses all files specified on command line containing function
    11  prototypes (like syscall_windows.go) and prints system call bodies
    12  to standard output.
    13  
    14  The prototypes are marked by lines beginning with "//sys" and read
    15  like func declarations if //sys is replaced by func, but:
    16  
    17  * The parameter lists must give a name for each argument. This
    18    includes return parameters.
    19  
    20  * The parameter lists must give a type for each argument:
    21    the (x, y, z int) shorthand is not allowed.
    22  
    23  * If the return parameter is an error number, it must be named err.
    24  
    25  * If the Go function name needs to be different from its winapi DLL function name,
    26    the winapi name could be specified at the end, after "=" sign, like
    27    //sys LoadLibrary(libname string) (handle uint32, err error) = LoadLibraryA
    28  
    29  * Each function that returns err may supply a condition that the
    30    return value of winapi will be tested against to detect failure.
    31    On failure err is set to the windows "last-error" (or EINVAL if none is set); otherwise it is nil.
    32    The value can be provided at end of //sys declaration, like
    33    //sys LoadLibrary(libname string) (handle uint32, err error) [failretval==-1] = LoadLibraryA
    34    and is [failretval==0] by default.
    35  
    36  Usage:
    37  	mksyscall_windows [flags] [path ...]
    38  
    39  The flags are:
    40  	-output
    41  		Specify output file name (outputs to console if blank).
    42  	-trace
    43  		Generate print statement after every syscall.
        	-systemdll
        		Load all DLLs from the Windows system directory (default true).
    44  */
    45  package main
    46  
    47  import (
    48  	"bufio"
    49  	"bytes"
    50  	"errors"
    51  	"flag"
    52  	"fmt"
    53  	"go/format"
    54  	"go/parser"
    55  	"go/token"
    56  	"io"
    57  	"io/ioutil"
    58  	"log"
    59  	"os"
    60  	"path/filepath"
    61  	"runtime"
    62  	"sort"
    63  	"strconv"
    64  	"strings"
    65  	"text/template"
    66  )
    67  
    68  var (
    69  	filename       = flag.String("output", "", "output file name (standard output if omitted)")
    70  	printTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall")
    71  	systemDLL      = flag.Bool("systemdll", true, "whether all DLLs should be loaded from the Windows system directory")
    72  )
    73  
    74  func trim(s string) string {
    75  	return strings.Trim(s, " \t")
    76  }
    77  
    78  var packageName string
    79  
    80  func packagename() string {
    81  	return packageName
    82  }
    83  
    84  func syscalldot() string {
    85  	if packageName == "syscall" {
    86  		return ""
    87  	}
    88  	return "syscall."
    89  }
    90  
    91  // Param is function parameter
    92  type Param struct {
    93  	Name      string
    94  	Type      string
    95  	fn        *Fn
    96  	tmpVarIdx int
    97  }
    98  
    99  // tmpVar returns temp variable name that will be used to represent p during syscall.
   100  func (p *Param) tmpVar() string {
   101  	if p.tmpVarIdx < 0 {
   102  		p.tmpVarIdx = p.fn.curTmpVarIdx
   103  		p.fn.curTmpVarIdx++
   104  	}
   105  	return fmt.Sprintf("_p%d", p.tmpVarIdx)
   106  }
   107  
   108  // BoolTmpVarCode returns source code for bool temp variable.
   109  func (p *Param) BoolTmpVarCode() string {
   110  	const code = `var %s uint32
   111  	if %s {
   112  		%s = 1
   113  	} else {
   114  		%s = 0
   115  	}`
   116  	tmp := p.tmpVar()
   117  	return fmt.Sprintf(code, tmp, p.Name, tmp, tmp)
   118  }
   119  
   120  // SliceTmpVarCode returns source code for slice temp variable.
   121  func (p *Param) SliceTmpVarCode() string {
   122  	const code = `var %s *%s
   123  	if len(%s) > 0 {
   124  		%s = &%s[0]
   125  	}`
   126  	tmp := p.tmpVar()
   127  	return fmt.Sprintf(code, tmp, p.Type[2:], p.Name, tmp, p.Name)
   128  }
   129  
   130  // StringTmpVarCode returns source code for string temp variable.
   131  func (p *Param) StringTmpVarCode() string {
   132  	errvar := p.fn.Rets.ErrorVarName()
   133  	if errvar == "" {
   134  		errvar = "_"
   135  	}
   136  	tmp := p.tmpVar()
   137  	const code = `var %s %s
   138  	%s, %s = %s(%s)`
   139  	s := fmt.Sprintf(code, tmp, p.fn.StrconvType(), tmp, errvar, p.fn.StrconvFunc(), p.Name)
   140  	if errvar == "-" {
   141  		return s
   142  	}
   143  	const morecode = `
   144  	if %s != nil {
   145  		return
   146  	}`
   147  	return s + fmt.Sprintf(morecode, errvar)
   148  }
   149  
   150  // TmpVarCode returns source code for temp variable.
   151  func (p *Param) TmpVarCode() string {
   152  	switch {
   153  	case p.Type == "bool":
   154  		return p.BoolTmpVarCode()
   155  	case strings.HasPrefix(p.Type, "[]"):
   156  		return p.SliceTmpVarCode()
   157  	default:
   158  		return ""
   159  	}
   160  }
   161  
   162  // TmpVarHelperCode returns source code for helper's temp variable.
   163  func (p *Param) TmpVarHelperCode() string {
   164  	if p.Type != "string" {
   165  		return ""
   166  	}
   167  	return p.StringTmpVarCode()
   168  }
   169  
   170  // SyscallArgList returns source code fragments representing p parameter
   171  // in syscall. Slices are translated into 2 syscall parameters: pointer to
   172  // the first element and length.
   173  func (p *Param) SyscallArgList() []string {
   174  	t := p.HelperType()
   175  	var s string
   176  	switch {
   177  	case t[0] == '*':
   178  		s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name)
   179  	case t == "bool":
   180  		s = p.tmpVar()
   181  	case strings.HasPrefix(t, "[]"):
   182  		return []string{
   183  			fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()),
   184  			fmt.Sprintf("uintptr(len(%s))", p.Name),
   185  		}
   186  	default:
   187  		s = p.Name
   188  	}
   189  	return []string{fmt.Sprintf("uintptr(%s)", s)}
   190  }
   191  
   192  // IsError determines if p parameter is used to return error.
   193  func (p *Param) IsError() bool {
   194  	return p.Name == "err" && p.Type == "error"
   195  }
   196  
   197  // HelperType returns type of parameter p used in helper function.
   198  func (p *Param) HelperType() string {
   199  	if p.Type == "string" {
   200  		return p.fn.StrconvType()
   201  	}
   202  	return p.Type
   203  }
   204  
   205  // join concatenates parameters ps into a string with sep separator.
   206  // Each parameter is converted into string by applying fn to it
   207  // before conversion.
   208  func join(ps []*Param, fn func(*Param) string, sep string) string {
   209  	if len(ps) == 0 {
   210  		return ""
   211  	}
   212  	a := make([]string, 0)
   213  	for _, p := range ps {
   214  		a = append(a, fn(p))
   215  	}
   216  	return strings.Join(a, sep)
   217  }
   218  
   219  // Rets describes function return parameters.
   220  type Rets struct {
   221  	Name         string
   222  	Type         string
   223  	ReturnsError bool
   224  	FailCond     string
   225  }
   226  
   227  // ErrorVarName returns error variable name for r.
   228  func (r *Rets) ErrorVarName() string {
   229  	if r.ReturnsError {
   230  		return "err"
   231  	}
   232  	if r.Type == "error" {
   233  		return r.Name
   234  	}
   235  	return ""
   236  }
   237  
   238  // ToParams converts r into slice of *Param.
   239  func (r *Rets) ToParams() []*Param {
   240  	ps := make([]*Param, 0)
   241  	if len(r.Name) > 0 {
   242  		ps = append(ps, &Param{Name: r.Name, Type: r.Type})
   243  	}
   244  	if r.ReturnsError {
   245  		ps = append(ps, &Param{Name: "err", Type: "error"})
   246  	}
   247  	return ps
   248  }
   249  
   250  // List returns source code of syscall return parameters.
   251  func (r *Rets) List() string {
   252  	s := join(r.ToParams(), func(p *Param) string { return p.Name + " " + p.Type }, ", ")
   253  	if len(s) > 0 {
   254  		s = "(" + s + ")"
   255  	}
   256  	return s
   257  }
   258  
   259  // PrintList returns source code of trace printing part correspondent
   260  // to syscall return values.
   261  func (r *Rets) PrintList() string {
   262  	return join(r.ToParams(), func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
   263  }
   264  
   265  // SetReturnValuesCode returns source code that accepts syscall return values.
   266  func (r *Rets) SetReturnValuesCode() string {
   267  	if r.Name == "" && !r.ReturnsError {
   268  		return ""
   269  	}
   270  	retvar := "r0"
   271  	if r.Name == "" {
   272  		retvar = "r1"
   273  	}
   274  	errvar := "_"
   275  	if r.ReturnsError {
   276  		errvar = "e1"
   277  	}
   278  	return fmt.Sprintf("%s, _, %s := ", retvar, errvar)
   279  }
   280  
   281  func (r *Rets) useLongHandleErrorCode(retvar string) string {
   282  	const code = `if %s {
   283  		if e1 != 0 {
   284  			err = errnoErr(e1)
   285  		} else {
   286  			err = %sEINVAL
   287  		}
   288  	}`
   289  	cond := retvar + " == 0"
   290  	if r.FailCond != "" {
   291  		cond = strings.Replace(r.FailCond, "failretval", retvar, 1)
   292  	}
   293  	return fmt.Sprintf(code, cond, syscalldot())
   294  }
   295  
   296  // SetErrorCode returns source code that sets return parameters.
   297  func (r *Rets) SetErrorCode() string {
   298  	const code = `if r0 != 0 {
   299  		%s = %sErrno(r0)
   300  	}`
   301  	if r.Name == "" && !r.ReturnsError {
   302  		return ""
   303  	}
   304  	if r.Name == "" {
   305  		return r.useLongHandleErrorCode("r1")
   306  	}
   307  	if r.Type == "error" {
   308  		return fmt.Sprintf(code, r.Name, syscalldot())
   309  	}
   310  	s := ""
   311  	switch {
   312  	case r.Type[0] == '*':
   313  		s = fmt.Sprintf("%s = (%s)(unsafe.Pointer(r0))", r.Name, r.Type)
   314  	case r.Type == "bool":
   315  		s = fmt.Sprintf("%s = r0 != 0", r.Name)
   316  	default:
   317  		s = fmt.Sprintf("%s = %s(r0)", r.Name, r.Type)
   318  	}
   319  	if !r.ReturnsError {
   320  		return s
   321  	}
   322  	return s + "\n\t" + r.useLongHandleErrorCode(r.Name)
   323  }
   324  
   325  // Fn describes syscall function.
   326  type Fn struct {
   327  	Name        string
   328  	Params      []*Param
   329  	Rets        *Rets
   330  	PrintTrace  bool
   331  	dllname     string
   332  	dllfuncname string
   333  	src         string
   334  	// TODO: get rid of this field and just use parameter index instead
   335  	curTmpVarIdx int // insure tmp variables have uniq names
   336  }
   337  
   338  // extractParams parses s to extract function parameters.
   339  func extractParams(s string, f *Fn) ([]*Param, error) {
   340  	s = trim(s)
   341  	if s == "" {
   342  		return nil, nil
   343  	}
   344  	a := strings.Split(s, ",")
   345  	ps := make([]*Param, len(a))
   346  	for i := range ps {
   347  		s2 := trim(a[i])
   348  		b := strings.Split(s2, " ")
   349  		if len(b) != 2 {
   350  			b = strings.Split(s2, "\t")
   351  			if len(b) != 2 {
   352  				return nil, errors.New("Could not extract function parameter from \"" + s2 + "\"")
   353  			}
   354  		}
   355  		ps[i] = &Param{
   356  			Name:      trim(b[0]),
   357  			Type:      trim(b[1]),
   358  			fn:        f,
   359  			tmpVarIdx: -1,
   360  		}
   361  	}
   362  	return ps, nil
   363  }
   364  
   365  // extractSection extracts text out of string s starting after start
   366  // and ending just before end. found return value will indicate success,
   367  // and prefix, body and suffix will contain correspondent parts of string s.
   368  func extractSection(s string, start, end rune) (prefix, body, suffix string, found bool) {
   369  	s = trim(s)
   370  	if strings.HasPrefix(s, string(start)) {
   371  		// no prefix
   372  		body = s[1:]
   373  	} else {
   374  		a := strings.SplitN(s, string(start), 2)
   375  		if len(a) != 2 {
   376  			return "", "", s, false
   377  		}
   378  		prefix = a[0]
   379  		body = a[1]
   380  	}
   381  	a := strings.SplitN(body, string(end), 2)
   382  	if len(a) != 2 {
   383  		return "", "", "", false
   384  	}
   385  	return prefix, a[0], a[1], true
   386  }
   387  
   388  // newFn parses string s and return created function Fn.
   389  func newFn(s string) (*Fn, error) {
   390  	s = trim(s)
   391  	f := &Fn{
   392  		Rets:       &Rets{},
   393  		src:        s,
   394  		PrintTrace: *printTraceFlag,
   395  	}
   396  	// function name and args
   397  	prefix, body, s, found := extractSection(s, '(', ')')
   398  	if !found || prefix == "" {
   399  		return nil, errors.New("Could not extract function name and parameters from \"" + f.src + "\"")
   400  	}
   401  	f.Name = prefix
   402  	var err error
   403  	f.Params, err = extractParams(body, f)
   404  	if err != nil {
   405  		return nil, err
   406  	}
   407  	// return values
   408  	_, body, s, found = extractSection(s, '(', ')')
   409  	if found {
   410  		r, err := extractParams(body, f)
   411  		if err != nil {
   412  			return nil, err
   413  		}
   414  		switch len(r) {
   415  		case 0:
   416  		case 1:
   417  			if r[0].IsError() {
   418  				f.Rets.ReturnsError = true
   419  			} else {
   420  				f.Rets.Name = r[0].Name
   421  				f.Rets.Type = r[0].Type
   422  			}
   423  		case 2:
   424  			if !r[1].IsError() {
   425  				return nil, errors.New("Only last windows error is allowed as second return value in \"" + f.src + "\"")
   426  			}
   427  			f.Rets.ReturnsError = true
   428  			f.Rets.Name = r[0].Name
   429  			f.Rets.Type = r[0].Type
   430  		default:
   431  			return nil, errors.New("Too many return values in \"" + f.src + "\"")
   432  		}
   433  	}
   434  	// fail condition
   435  	_, body, s, found = extractSection(s, '[', ']')
   436  	if found {
   437  		f.Rets.FailCond = body
   438  	}
   439  	// dll and dll function names
   440  	s = trim(s)
   441  	if s == "" {
   442  		return f, nil
   443  	}
   444  	if !strings.HasPrefix(s, "=") {
   445  		return nil, errors.New("Could not extract dll name from \"" + f.src + "\"")
   446  	}
   447  	s = trim(s[1:])
   448  	a := strings.Split(s, ".")
   449  	switch len(a) {
   450  	case 1:
   451  		f.dllfuncname = a[0]
   452  	case 2:
   453  		f.dllname = a[0]
   454  		f.dllfuncname = a[1]
   455  	default:
   456  		return nil, errors.New("Could not extract dll name from \"" + f.src + "\"")
   457  	}
   458  	return f, nil
   459  }
   460  
   461  // DLLName returns DLL name for function f.
   462  func (f *Fn) DLLName() string {
   463  	if f.dllname == "" {
   464  		return "kernel32"
   465  	}
   466  	return f.dllname
   467  }
   468  
   469  // DLLName returns DLL function name for function f.
   470  func (f *Fn) DLLFuncName() string {
   471  	if f.dllfuncname == "" {
   472  		return f.Name
   473  	}
   474  	return f.dllfuncname
   475  }
   476  
   477  // ParamList returns source code for function f parameters.
   478  func (f *Fn) ParamList() string {
   479  	return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ")
   480  }
   481  
   482  // HelperParamList returns source code for helper function f parameters.
   483  func (f *Fn) HelperParamList() string {
   484  	return join(f.Params, func(p *Param) string { return p.Name + " " + p.HelperType() }, ", ")
   485  }
   486  
   487  // ParamPrintList returns source code of trace printing part correspondent
   488  // to syscall input parameters.
   489  func (f *Fn) ParamPrintList() string {
   490  	return join(f.Params, func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
   491  }
   492  
   493  // ParamCount return number of syscall parameters for function f.
   494  func (f *Fn) ParamCount() int {
   495  	n := 0
   496  	for _, p := range f.Params {
   497  		n += len(p.SyscallArgList())
   498  	}
   499  	return n
   500  }
   501  
   502  // SyscallParamCount determines which version of Syscall/Syscall6/Syscall9/...
   503  // to use. It returns parameter count for correspondent SyscallX function.
   504  func (f *Fn) SyscallParamCount() int {
   505  	n := f.ParamCount()
   506  	switch {
   507  	case n <= 3:
   508  		return 3
   509  	case n <= 6:
   510  		return 6
   511  	case n <= 9:
   512  		return 9
   513  	case n <= 12:
   514  		return 12
   515  	case n <= 15:
   516  		return 15
   517  	default:
   518  		panic("too many arguments to system call")
   519  	}
   520  }
   521  
   522  // Syscall determines which SyscallX function to use for function f.
   523  func (f *Fn) Syscall() string {
   524  	c := f.SyscallParamCount()
   525  	if c == 3 {
   526  		return syscalldot() + "Syscall"
   527  	}
   528  	return syscalldot() + "Syscall" + strconv.Itoa(c)
   529  }
   530  
   531  // SyscallParamList returns source code for SyscallX parameters for function f.
   532  func (f *Fn) SyscallParamList() string {
   533  	a := make([]string, 0)
   534  	for _, p := range f.Params {
   535  		a = append(a, p.SyscallArgList()...)
   536  	}
   537  	for len(a) < f.SyscallParamCount() {
   538  		a = append(a, "0")
   539  	}
   540  	return strings.Join(a, ", ")
   541  }
   542  
   543  // HelperCallParamList returns source code of call into function f helper.
   544  func (f *Fn) HelperCallParamList() string {
   545  	a := make([]string, 0, len(f.Params))
   546  	for _, p := range f.Params {
   547  		s := p.Name
   548  		if p.Type == "string" {
   549  			s = p.tmpVar()
   550  		}
   551  		a = append(a, s)
   552  	}
   553  	return strings.Join(a, ", ")
   554  }
   555  
   556  // IsUTF16 is true, if f is W (utf16) function. It is false
   557  // for all A (ascii) functions.
   558  func (f *Fn) IsUTF16() bool {
   559  	s := f.DLLFuncName()
   560  	return s[len(s)-1] == 'W'
   561  }
   562  
   563  // StrconvFunc returns name of Go string to OS string function for f.
   564  func (f *Fn) StrconvFunc() string {
   565  	if f.IsUTF16() {
   566  		return syscalldot() + "UTF16PtrFromString"
   567  	}
   568  	return syscalldot() + "BytePtrFromString"
   569  }
   570  
   571  // StrconvType returns Go type name used for OS string for f.
   572  func (f *Fn) StrconvType() string {
   573  	if f.IsUTF16() {
   574  		return "*uint16"
   575  	}
   576  	return "*byte"
   577  }
   578  
   579  // HasStringParam is true, if f has at least one string parameter.
   580  // Otherwise it is false.
   581  func (f *Fn) HasStringParam() bool {
   582  	for _, p := range f.Params {
   583  		if p.Type == "string" {
   584  			return true
   585  		}
   586  	}
   587  	return false
   588  }
   589  
   590  // HelperName returns name of function f helper.
   591  func (f *Fn) HelperName() string {
   592  	if !f.HasStringParam() {
   593  		return f.Name
   594  	}
   595  	return "_" + f.Name
   596  }
   597  
   598  // Source files and functions.
   599  type Source struct {
   600  	Funcs           []*Fn
   601  	Files           []string
   602  	StdLibImports   []string
   603  	ExternalImports []string
   604  }
   605  
   606  func (src *Source) Import(pkg string) {
   607  	src.StdLibImports = append(src.StdLibImports, pkg)
   608  	sort.Strings(src.StdLibImports)
   609  }
   610  
   611  func (src *Source) ExternalImport(pkg string) {
   612  	src.ExternalImports = append(src.ExternalImports, pkg)
   613  	sort.Strings(src.ExternalImports)
   614  }
   615  
   616  // ParseFiles parses files listed in fs and extracts all syscall
   617  // functions listed in sys comments. It returns source files
   618  // and functions collection *Source if successful.
   619  func ParseFiles(fs []string) (*Source, error) {
   620  	src := &Source{
   621  		Funcs: make([]*Fn, 0),
   622  		Files: make([]string, 0),
   623  		StdLibImports: []string{
   624  			"unsafe",
   625  		},
   626  		ExternalImports: make([]string, 0),
   627  	}
   628  	for _, file := range fs {
   629  		if err := src.ParseFile(file); err != nil {
   630  			return nil, err
   631  		}
   632  	}
   633  	return src, nil
   634  }
   635  
   636  // DLLs return dll names for a source set src.
   637  func (src *Source) DLLs() []string {
   638  	uniq := make(map[string]bool)
   639  	r := make([]string, 0)
   640  	for _, f := range src.Funcs {
   641  		name := f.DLLName()
   642  		if _, found := uniq[name]; !found {
   643  			uniq[name] = true
   644  			r = append(r, name)
   645  		}
   646  	}
   647  	return r
   648  }
   649  
   650  // ParseFile adds additional file path to a source set src.
   651  func (src *Source) ParseFile(path string) error {
   652  	file, err := os.Open(path)
   653  	if err != nil {
   654  		return err
   655  	}
   656  	defer file.Close()
   657  
   658  	s := bufio.NewScanner(file)
   659  	for s.Scan() {
   660  		t := trim(s.Text())
   661  		if len(t) < 7 {
   662  			continue
   663  		}
   664  		if !strings.HasPrefix(t, "//sys") {
   665  			continue
   666  		}
   667  		t = t[5:]
   668  		if !(t[0] == ' ' || t[0] == '\t') {
   669  			continue
   670  		}
   671  		f, err := newFn(t[1:])
   672  		if err != nil {
   673  			return err
   674  		}
   675  		src.Funcs = append(src.Funcs, f)
   676  	}
   677  	if err := s.Err(); err != nil {
   678  		return err
   679  	}
   680  	src.Files = append(src.Files, path)
   681  
   682  	// get package name
   683  	fset := token.NewFileSet()
   684  	_, err = file.Seek(0, 0)
   685  	if err != nil {
   686  		return err
   687  	}
   688  	pkg, err := parser.ParseFile(fset, "", file, parser.PackageClauseOnly)
   689  	if err != nil {
   690  		return err
   691  	}
   692  	packageName = pkg.Name.Name
   693  
   694  	return nil
   695  }
   696  
   697  // IsStdRepo returns true if src is part of standard library.
   698  func (src *Source) IsStdRepo() (bool, error) {
   699  	if len(src.Files) == 0 {
   700  		return false, errors.New("no input files provided")
   701  	}
   702  	abspath, err := filepath.Abs(src.Files[0])
   703  	if err != nil {
   704  		return false, err
   705  	}
   706  	goroot := runtime.GOROOT()
   707  	if runtime.GOOS == "windows" {
   708  		abspath = strings.ToLower(abspath)
   709  		goroot = strings.ToLower(goroot)
   710  	}
   711  	sep := string(os.PathSeparator)
   712  	if !strings.HasSuffix(goroot, sep) {
   713  		goroot += sep
   714  	}
   715  	return strings.HasPrefix(abspath, goroot), nil
   716  }
   717  
   718  // Generate output source file from a source set src.
   719  func (src *Source) Generate(w io.Writer) error {
   720  	const (
   721  		pkgStd         = iota // any package in std library
   722  		pkgXSysWindows        // x/sys/windows package
   723  		pkgOther
   724  	)
   725  	isStdRepo, err := src.IsStdRepo()
   726  	if err != nil {
   727  		return err
   728  	}
   729  	var pkgtype int
   730  	switch {
   731  	case isStdRepo:
   732  		pkgtype = pkgStd
   733  	case packageName == "windows":
   734  		// TODO: this needs better logic than just using package name
   735  		pkgtype = pkgXSysWindows
   736  	default:
   737  		pkgtype = pkgOther
   738  	}
   739  	if *systemDLL {
   740  		switch pkgtype {
   741  		case pkgStd:
   742  			src.Import("internal/syscall/windows/sysdll")
   743  		case pkgXSysWindows:
   744  		default:
   745  			src.ExternalImport("golang.org/x/sys/windows")
   746  		}
   747  	}
   748  	if packageName != "syscall" {
   749  		src.Import("syscall")
   750  	}
   751  	funcMap := template.FuncMap{
   752  		"packagename": packagename,
   753  		"syscalldot":  syscalldot,
   754  		"newlazydll": func(dll string) string {
   755  			arg := "\"" + dll + ".dll\""
   756  			if !*systemDLL {
   757  				return syscalldot() + "NewLazyDLL(" + arg + ")"
   758  			}
   759  			switch pkgtype {
   760  			case pkgStd:
   761  				return syscalldot() + "NewLazyDLL(sysdll.Add(" + arg + "))"
   762  			case pkgXSysWindows:
   763  				return "NewLazySystemDLL(" + arg + ")"
   764  			default:
   765  				return "windows.NewLazySystemDLL(" + arg + ")"
   766  			}
   767  		},
   768  	}
   769  	t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate))
   770  	err = t.Execute(w, src)
   771  	if err != nil {
   772  		return errors.New("Failed to execute template: " + err.Error())
   773  	}
   774  	return nil
   775  }
   776  
   777  func usage() {
   778  	fmt.Fprintf(os.Stderr, "usage: mksyscall_windows [flags] [path ...]\n")
   779  	flag.PrintDefaults()
   780  	os.Exit(1)
   781  }
   782  
   783  func main() {
   784  	flag.Usage = usage
   785  	flag.Parse()
   786  	if len(flag.Args()) <= 0 {
   787  		fmt.Fprintf(os.Stderr, "no files to parse provided\n")
   788  		usage()
   789  	}
   790  
   791  	src, err := ParseFiles(flag.Args())
   792  	if err != nil {
   793  		log.Fatal(err)
   794  	}
   795  
   796  	var buf bytes.Buffer
   797  	if err := src.Generate(&buf); err != nil {
   798  		log.Fatal(err)
   799  	}
   800  
   801  	data, err := format.Source(buf.Bytes())
   802  	if err != nil {
   803  		log.Fatal(err)
   804  	}
   805  	if *filename == "" {
   806  		_, err = os.Stdout.Write(data)
   807  	} else {
   808  		err = ioutil.WriteFile(*filename, data, 0644)
   809  	}
   810  	if err != nil {
   811  		log.Fatal(err)
   812  	}
   813  }
   814  
// srcTemplate is the text/template source that (*Source).Generate parses
// under the name "main" and executes with the *Source as data. The "main"
// template emits the file header, the import block, the errnoErr helper and
// the lazy DLL/proc variables. It then emits one wrapper per //sys function,
// plus a string-converting outer wrapper when HasStringParam is true. The
// remaining named templates render the pieces of each wrapper.
//
// TODO: use println instead to print in the following template
const srcTemplate = `

{{define "main"}}// Code generated by 'go generate'; DO NOT EDIT.

package {{packagename}}

import (
{{range .StdLibImports}}"{{.}}"
{{end}}

{{range .ExternalImports}}"{{.}}"
{{end}}
)

var _ unsafe.Pointer

// Do the interface allocations only once for common
// Errno values.
const (
	errnoERROR_IO_PENDING = 997
)

var (
	errERROR_IO_PENDING error = {{syscalldot}}Errno(errnoERROR_IO_PENDING)
)

// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e {{syscalldot}}Errno) error {
	switch e {
	case 0:
		return nil
	case errnoERROR_IO_PENDING:
		return errERROR_IO_PENDING
	}
	// TODO: add more here, after collecting data on the common
	// error values see on Windows. (perhaps when running
	// all.bat?)
	return e
}

var (
{{template "dlls" .}}
{{template "funcnames" .}})
{{range .Funcs}}{{if .HasStringParam}}{{template "helperbody" .}}{{end}}{{template "funcbody" .}}{{end}}
{{end}}

{{/* help functions */}}

{{define "dlls"}}{{range .DLLs}}	mod{{.}} = {{newlazydll .}}
{{end}}{{end}}

{{define "funcnames"}}{{range .Funcs}}	proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}")
{{end}}{{end}}

{{define "helperbody"}}
func {{.Name}}({{.ParamList}}) {{template "results" .}}{
{{template "helpertmpvars" .}}	return {{.HelperName}}({{.HelperCallParamList}})
}
{{end}}

{{define "funcbody"}}
func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{
{{template "tmpvars" .}}	{{template "syscall" .}}
{{template "seterror" .}}{{template "printtrace" .}}	return
}
{{end}}

{{define "helpertmpvars"}}{{range .Params}}{{if .TmpVarHelperCode}}	{{.TmpVarHelperCode}}
{{end}}{{end}}{{end}}

{{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}}	{{.TmpVarCode}}
{{end}}{{end}}{{end}}

{{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}}

{{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}}

{{define "seterror"}}{{if .Rets.SetErrorCode}}	{{.Rets.SetErrorCode}}
{{end}}{{end}}

{{define "printtrace"}}{{if .PrintTrace}}	print("SYSCALL: {{.Name}}(", {{.ParamPrintList}}") (", {{.Rets.PrintList}}")\n")
{{end}}{{end}}

`
   901  

View as plain text