diff --git a/README.md b/README.md index de6cb1d..b0fdffd 100644 --- a/README.md +++ b/README.md @@ -14,17 +14,19 @@ import ( ) func main() { - res, e := xmlrpc.Call( + res, e := xmlrpc.NewClient( "http://your-blog.example.com/xmlrpc.php", + ).Call( "metaWeblog.getRecentPosts", "blog-id", "user-id", "password", - 10) + 10, + ) if e != nil { log.Fatal(e) } - for _, p := range res.(xmlrpc.Array) { + for _, p := range res { for k, v := range p.(xmlrpc.Struct) { fmt.Printf("%s=%v\n", k, v) } diff --git a/xmlrpc.go b/xmlrpc.go index 02f6504..aa933ad 100644 --- a/xmlrpc.go +++ b/xmlrpc.go @@ -18,49 +18,23 @@ import ( type Array []interface{} type Struct map[string]interface{} -var xmlSpecial = map[byte]string{ - '<': "<", - '>': ">", - '"': """, - '\'': "'", - '&': "&", -} - -func xmlEscape(s string) string { - var b bytes.Buffer - for i := 0; i < len(s); i++ { - c := s[i] - if s, ok := xmlSpecial[c]; ok { - b.WriteString(s) - } else { - b.WriteByte(c) - } - } - return b.String() -} - -type valueNode struct { - Type string `xml:"attr"` - Body string `xml:"chardata"` -} - func next(p *xml.Decoder) (xml.Name, interface{}, error) { - se, e := nextStart(p) - if e != nil { - return xml.Name{}, nil, e + se, nextErr := nextStart(p) + if nextErr != nil { + return xml.Name{}, nil, nextErr } var nv interface{} switch se.Name.Local { case "string": var s string - if e = p.DecodeElement(&s, &se); e != nil { + if e := p.DecodeElement(&s, &se); e != nil { return xml.Name{}, nil, e } return xml.Name{}, s, nil case "boolean": var s string - if e = p.DecodeElement(&s, &se); e != nil { + if e := p.DecodeElement(&s, &se); e != nil { return xml.Name{}, nil, e } s = strings.TrimSpace(s) @@ -71,28 +45,28 @@ func next(p *xml.Decoder) (xml.Name, interface{}, error) { case "false", "0": b = false default: - e = errors.New("invalid boolean value") + return xml.Name{}, b, errors.New("invalid boolean value") } - return xml.Name{}, b, e + return xml.Name{}, b, nil case "int", "i1", "i2", "i4", "i8": var s string var i int - if e = p.DecodeElement(&s, &se); e != nil { + if e := p.DecodeElement(&s, &se); e != nil { return xml.Name{}, nil, e } - i, e = strconv.Atoi(strings.TrimSpace(s)) + i, e := strconv.Atoi(strings.TrimSpace(s)) return xml.Name{}, i, e case "double": var s string var f float64 - if e = p.DecodeElement(&s, &se); e != nil { + if e := p.DecodeElement(&s, &se); e != nil { return xml.Name{}, nil, e } - f, e = strconv.ParseFloat(strings.TrimSpace(s), 64) + f, e := strconv.ParseFloat(strings.TrimSpace(s), 64) return xml.Name{}, f, e case "dateTime.iso8601": var s string - if e = p.DecodeElement(&s, &se); e != nil { + if e := p.DecodeElement(&s, &se); e != nil { return xml.Name{}, nil, e } t, e := time.Parse("20060102T15:04:05", s) @@ -105,7 +79,7 @@ func next(p *xml.Decoder) (xml.Name, interface{}, error) { return xml.Name{}, t, e case "base64": var s string - if e = p.DecodeElement(&s, &se); e != nil { + if e := p.DecodeElement(&s, &se); e != nil { return xml.Name{}, nil, e } if b, e := base64.StdEncoding.DecodeString(s); e != nil { @@ -116,15 +90,15 @@ func next(p *xml.Decoder) (xml.Name, interface{}, error) { case "member": nextStart(p) return next(p) - case "value": - nextStart(p) - return next(p) - case "name": + + case "value", "name", "param": nextStart(p) return next(p) + case "struct": st := Struct{} + var e error se, e = nextStart(p) for e == nil && se.Name.Local == "member" { // name @@ -159,7 +133,8 @@ func next(p *xml.Decoder) (xml.Name, interface{}, error) { break } } - return xml.Name{}, st, nil + return xml.Name{}, st, e + case "array": var ar Array nextStart(p) // data @@ -172,14 +147,40 @@ func next(p *xml.Decoder) (xml.Name, interface{}, error) { ar = append(ar, value) } return xml.Name{}, ar, nil + case "nil": return xml.Name{}, nil, nil + + case "params": + var ar Array + for { + _, value, e := next(p) + if e != nil { + break + } + ar = append(ar, value) + } + return xml.Name{}, ar, nil + + case "fault": + _, value, _ := next(p) + fs, ok := value.(Struct) + if !ok { + return xml.Name{}, value, fmt.Errorf("fault: wanted Struct, got %#v", value) + } + var f Fault + if s, ok := fs["faultCode"].(string); ok { + f.Code, _ = strconv.Atoi(s) + } + f.Message, _ = fs["faultString"].(string) + return xml.Name{}, nil, &f + } - if e = p.DecodeElement(nv, &se); e != nil { + if e := p.DecodeElement(&nv, &se); e != nil { return xml.Name{}, nil, e } - return se.Name, nv, e + return se.Name, nv, nil } func nextStart(p *xml.Decoder) (xml.StartElement, error) { for { @@ -192,91 +193,116 @@ func nextStart(p *xml.Decoder) (xml.StartElement, error) { return t, nil } } - panic("unreachable") } -func toXml(v interface{}, typ bool) (s string) { +var UnsupportedType = errors.New("unsupported type") + +func writeXML(w io.Writer, v interface{}, typ bool) error { if v == nil { - return "" + _, err := io.WriteString(w, "") + return err } r := reflect.ValueOf(v) t := r.Type() k := t.Kind() if b, ok := v.([]byte); ok { - return "" + base64.StdEncoding.EncodeToString(b) + "" + io.WriteString(w, "") + _, err := base64.NewEncoder(base64.StdEncoding, w).Write(b) + io.WriteString(w, "") + return err } switch k { case reflect.Invalid: - panic("unsupported type") + return UnsupportedType case reflect.Bool: - return fmt.Sprintf("%v", v) + _, err := fmt.Fprintf(w, "%v", v) + return err case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: if typ { - return fmt.Sprintf("%v", v) + _, err := fmt.Fprintf(w, "%v", v) + return err } - return fmt.Sprintf("%v", v) + _, err := fmt.Fprintf(w, "%v", v) + return err case reflect.Uintptr: - panic("unsupported type") + return UnsupportedType case reflect.Float32, reflect.Float64: if typ { - return fmt.Sprintf("%v", v) + _, err := fmt.Fprintf(w, "%v", v) + return err } - return fmt.Sprintf("%v", v) + _, err := fmt.Fprintf(w, "%v", v) + return err case reflect.Complex64, reflect.Complex128: - panic("unsupported type") + return UnsupportedType case reflect.Array: - s = "" + io.WriteString(w, "") for n := 0; n < r.Len(); n++ { - s += "" - s += toXml(r.Index(n).Interface(), typ) - s += "" + io.WriteString(w, "") + err := writeXML(w, r.Index(n).Interface(), typ) + io.WriteString(w, "") + if err != nil { + return err + } } - s += "" - return s + _, err := io.WriteString(w, "") + return err case reflect.Chan: - panic("unsupported type") + return UnsupportedType case reflect.Func: - panic("unsupported type") + return UnsupportedType case reflect.Interface: - return toXml(r.Elem(), typ) + return writeXML(w, r.Elem(), typ) case reflect.Map: - s = "" + io.WriteString(w, "") for _, key := range r.MapKeys() { - s += "" - s += "" + xmlEscape(key.Interface().(string)) + "" - s += "" + toXml(r.MapIndex(key).Interface(), typ) + "" - s += "" + io.WriteString(w, "") + if err := xml.EscapeText(w, []byte(key.Interface().(string))); err != nil { + return err + } + io.WriteString(w, "") + if err := writeXML(w, r.MapIndex(key).Interface(), typ); err != nil { + return err + } + if _, err := io.WriteString(w, ""); err != nil { + return err + } } - s += "" - return s + _, err := io.WriteString(w, "") + return err case reflect.Ptr: - panic("unsupported type") + return UnsupportedType case reflect.Slice: - panic("unsupported type") + return UnsupportedType case reflect.String: if typ { - return fmt.Sprintf("%v", xmlEscape(v.(string))) + io.WriteString(w, "") } - return xmlEscape(v.(string)) + err := xml.EscapeText(w, []byte(v.(string))) + if typ { + io.WriteString(w, "") + } + return err case reflect.Struct: - s = "" + io.WriteString(w, "") for n := 0; n < r.NumField(); n++ { - s += "" - s += "" + t.Field(n).Name + "" - s += "" + toXml(r.FieldByIndex([]int{n}).Interface(), true) + "" - s += "" + fmt.Fprintf(w, "%s", t.Field(n).Name) + if err := writeXML(w, r.FieldByIndex([]int{n}).Interface(), true); err != nil { + return err + } + io.WriteString(w, "") } - s += "" - return s + _, err := io.WriteString(w, "") + return err case reflect.UnsafePointer: - return toXml(r.Elem(), typ) + return writeXML(w, r.Elem(), typ) } - return + return nil } // Client is client of XMLRPC @@ -293,22 +319,43 @@ func NewClient(url string) *Client { } } -func makeRequest(name string, args ...interface{}) *bytes.Buffer { - buf := new(bytes.Buffer) - buf.WriteString(``) - buf.WriteString("" + xmlEscape(name) + "") - buf.WriteString("") +func Marshal(w io.Writer, name string, args ...interface{}) error { + io.WriteString(w, ``) + var end string + if name == "" { + io.WriteString(w, "") + end = "" + } else { + io.WriteString(w, "") + if err := xml.EscapeText(w, []byte(name)); err != nil { + return err + } + io.WriteString(w, "") + end = "" + } + + io.WriteString(w, "") for _, arg := range args { - buf.WriteString("") - buf.WriteString(toXml(arg, true)) - buf.WriteString("") + io.WriteString(w, "") + if err := writeXML(w, arg, true); err != nil { + return err + } + io.WriteString(w, "") } - buf.WriteString("") - return buf + io.WriteString(w, "") + _, err := io.WriteString(w, end) + return err } -func call(client *http.Client, url, name string, args ...interface{}) (v interface{}, e error) { - r, e := httpClient.Post(url, "text/xml", makeRequest(name, args...)) +func makeRequest(name string, args ...interface{}) *bytes.Buffer { + var buf bytes.Buffer + if err := Marshal(&buf, name, args...); err != nil { + panic(err) + } + return &buf +} +func call(client *http.Client, url, name string, args ...interface{}) (v Array, e error) { + r, e := http.DefaultClient.Post(url, "text/xml", makeRequest(name, args...)) if e != nil { return nil, e } @@ -322,37 +369,53 @@ func call(client *http.Client, url, name string, args ...interface{}) (v interfa return nil, errors.New(http.StatusText(http.StatusBadRequest)) } - p := xml.NewDecoder(r.Body) + _, v, e = Unmarshal(r.Body) + return v, e +} + +func Unmarshal(r io.Reader) (string, Array, error) { + var name string + p := xml.NewDecoder(r) se, e := nextStart(p) // methodResponse - if se.Name.Local != "methodResponse" { - return nil, errors.New("invalid response: missing methodResponse") - } - se, e = nextStart(p) // params - if se.Name.Local != "params" { - return nil, errors.New("invalid response: missing params") + if e != nil { + return name, nil, e } - se, e = nextStart(p) // param - if se.Name.Local != "param" { - return nil, errors.New("invalid response: missing param") + if se.Name.Local != "methodResponse" { + if se.Name.Local != "methodCall" { + return name, nil, errors.New("invalid response: missing methodResponse") + } + if se, e = nextStart(p); e != nil { + return name, nil, e + } + if se.Name.Local != "methodName" { + return name, nil, errors.New("invalid response: missing methodName") + } + if e = p.DecodeElement(&name, &se); e != nil { + return name, nil, e + } } - se, e = nextStart(p) // value - if se.Name.Local != "value" { - return nil, errors.New("invalid response: missing value") + _, v, e := next(p) + if a, ok := v.(Array); ok { + return name, a, e + } else if e == nil { + e = fmt.Errorf("wanted Array, got %#v", v) } - _, v, e = next(p) - return v, e + return name, nil, e } +type Fault struct { + Code int + Message string +} + +func (f *Fault) Error() string { return fmt.Sprintf("%d: %s", f.Code, f.Message) } + // Call call remote procedures function name with args -func (c *Client) Call(name string, args ...interface{}) (v interface{}, e error) { +func (c *Client) Call(name string, args ...interface{}) (v Array, e error) { return call(c.HttpClient, c.url, name, args...) } -// Global httpClient allows us to pool/reuse connections and not wastefully -// re-create transports for each request. -var httpClient = &http.Client{Transport: http.DefaultTransport, Timeout: 10 * time.Second} - // Call call remote procedures function name with args -func Call(url, name string, args ...interface{}) (v interface{}, e error) { - return call(httpClient, url, name, args...) +func Call(url, name string, args ...interface{}) (v Array, e error) { + return call(http.DefaultClient, url, name, args...) } diff --git a/xmlrpc_test.go b/xmlrpc_test.go index c605ae8..30203a7 100644 --- a/xmlrpc_test.go +++ b/xmlrpc_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" "testing" ) @@ -105,12 +106,12 @@ func TestAddInt(t *testing.T) { if err != nil { t.Fatal(err) } - i, ok := v.(int) + i, ok := v[0].(int) if !ok { - t.Fatalf("want int but got %T: %v", v, v) + t.Fatalf("want int but got %#v", v) } if i != 3 { - t.Fatalf("want %v but got %v", 3, v) + t.Fatalf("want %v but got %#v", 3, v) } } @@ -138,11 +139,19 @@ func TestAddString(t *testing.T) { if err != nil { t.Fatal(err) } - s, ok := v.(string) + s, ok := v[0].(string) if !ok { - t.Fatalf("want string but got %T: %v", v, v) + t.Fatalf("want string but got %#v", v) } if s != "helloworld" { t.Fatalf("want %q but got %q", "helloworld", v) } } + +func toXml(v interface{}, typ bool) (s string) { + var buf strings.Builder + if err := writeXML(&buf, v, typ); err != nil { + panic(err) + } + return buf.String() +}