diff --git a/README.md b/README.md
index de6cb1d..b0fdffd 100644
--- a/README.md
+++ b/README.md
@@ -14,17 +14,19 @@ import (
)
func main() {
- res, e := xmlrpc.Call(
+ res, e := xmlrpc.NewClient(
"http://your-blog.example.com/xmlrpc.php",
+ ).Call(
"metaWeblog.getRecentPosts",
"blog-id",
"user-id",
"password",
- 10)
+ 10,
+ )
if e != nil {
log.Fatal(e)
}
- for _, p := range res.(xmlrpc.Array) {
+ for _, p := range res {
for k, v := range p.(xmlrpc.Struct) {
fmt.Printf("%s=%v\n", k, v)
}
diff --git a/xmlrpc.go b/xmlrpc.go
index 02f6504..aa933ad 100644
--- a/xmlrpc.go
+++ b/xmlrpc.go
@@ -18,49 +18,23 @@ import (
type Array []interface{}
type Struct map[string]interface{}
-var xmlSpecial = map[byte]string{
- '<': "<",
- '>': ">",
- '"': """,
- '\'': "'",
- '&': "&",
-}
-
-func xmlEscape(s string) string {
- var b bytes.Buffer
- for i := 0; i < len(s); i++ {
- c := s[i]
- if s, ok := xmlSpecial[c]; ok {
- b.WriteString(s)
- } else {
- b.WriteByte(c)
- }
- }
- return b.String()
-}
-
-type valueNode struct {
- Type string `xml:"attr"`
- Body string `xml:"chardata"`
-}
-
func next(p *xml.Decoder) (xml.Name, interface{}, error) {
- se, e := nextStart(p)
- if e != nil {
- return xml.Name{}, nil, e
+ se, nextErr := nextStart(p)
+ if nextErr != nil {
+ return xml.Name{}, nil, nextErr
}
var nv interface{}
switch se.Name.Local {
case "string":
var s string
- if e = p.DecodeElement(&s, &se); e != nil {
+ if e := p.DecodeElement(&s, &se); e != nil {
return xml.Name{}, nil, e
}
return xml.Name{}, s, nil
case "boolean":
var s string
- if e = p.DecodeElement(&s, &se); e != nil {
+ if e := p.DecodeElement(&s, &se); e != nil {
return xml.Name{}, nil, e
}
s = strings.TrimSpace(s)
@@ -71,28 +45,28 @@ func next(p *xml.Decoder) (xml.Name, interface{}, error) {
case "false", "0":
b = false
default:
- e = errors.New("invalid boolean value")
+ return xml.Name{}, b, errors.New("invalid boolean value")
}
- return xml.Name{}, b, e
+ return xml.Name{}, b, nil
case "int", "i1", "i2", "i4", "i8":
var s string
var i int
- if e = p.DecodeElement(&s, &se); e != nil {
+ if e := p.DecodeElement(&s, &se); e != nil {
return xml.Name{}, nil, e
}
- i, e = strconv.Atoi(strings.TrimSpace(s))
+ i, e := strconv.Atoi(strings.TrimSpace(s))
return xml.Name{}, i, e
case "double":
var s string
var f float64
- if e = p.DecodeElement(&s, &se); e != nil {
+ if e := p.DecodeElement(&s, &se); e != nil {
return xml.Name{}, nil, e
}
- f, e = strconv.ParseFloat(strings.TrimSpace(s), 64)
+ f, e := strconv.ParseFloat(strings.TrimSpace(s), 64)
return xml.Name{}, f, e
case "dateTime.iso8601":
var s string
- if e = p.DecodeElement(&s, &se); e != nil {
+ if e := p.DecodeElement(&s, &se); e != nil {
return xml.Name{}, nil, e
}
t, e := time.Parse("20060102T15:04:05", s)
@@ -105,7 +79,7 @@ func next(p *xml.Decoder) (xml.Name, interface{}, error) {
return xml.Name{}, t, e
case "base64":
var s string
- if e = p.DecodeElement(&s, &se); e != nil {
+ if e := p.DecodeElement(&s, &se); e != nil {
return xml.Name{}, nil, e
}
if b, e := base64.StdEncoding.DecodeString(s); e != nil {
@@ -116,15 +90,15 @@ func next(p *xml.Decoder) (xml.Name, interface{}, error) {
case "member":
nextStart(p)
return next(p)
- case "value":
- nextStart(p)
- return next(p)
- case "name":
+
+ case "value", "name", "param":
nextStart(p)
return next(p)
+
case "struct":
st := Struct{}
+ var e error
se, e = nextStart(p)
for e == nil && se.Name.Local == "member" {
// name
@@ -159,7 +133,8 @@ func next(p *xml.Decoder) (xml.Name, interface{}, error) {
break
}
}
- return xml.Name{}, st, nil
+ return xml.Name{}, st, e
+
case "array":
var ar Array
nextStart(p) // data
@@ -172,14 +147,40 @@ func next(p *xml.Decoder) (xml.Name, interface{}, error) {
ar = append(ar, value)
}
return xml.Name{}, ar, nil
+
case "nil":
return xml.Name{}, nil, nil
+
+ case "params":
+ var ar Array
+ for {
+ _, value, e := next(p)
+ if e != nil {
+ break
+ }
+ ar = append(ar, value)
+ }
+ return xml.Name{}, ar, nil
+
+ case "fault":
+ _, value, _ := next(p)
+ fs, ok := value.(Struct)
+ if !ok {
+ return xml.Name{}, value, fmt.Errorf("fault: wanted Struct, got %#v", value)
+ }
+ var f Fault
+ if s, ok := fs["faultCode"].(string); ok {
+ f.Code, _ = strconv.Atoi(s)
+ }
+ f.Message, _ = fs["faultString"].(string)
+ return xml.Name{}, nil, &f
+
}
- if e = p.DecodeElement(nv, &se); e != nil {
+ if e := p.DecodeElement(&nv, &se); e != nil {
return xml.Name{}, nil, e
}
- return se.Name, nv, e
+ return se.Name, nv, nil
}
func nextStart(p *xml.Decoder) (xml.StartElement, error) {
for {
@@ -192,91 +193,116 @@ func nextStart(p *xml.Decoder) (xml.StartElement, error) {
return t, nil
}
}
- panic("unreachable")
}
-func toXml(v interface{}, typ bool) (s string) {
+var UnsupportedType = errors.New("unsupported type")
+
+func writeXML(w io.Writer, v interface{}, typ bool) error {
if v == nil {
- return ""
+ _, err := io.WriteString(w, "")
+ return err
}
r := reflect.ValueOf(v)
t := r.Type()
k := t.Kind()
if b, ok := v.([]byte); ok {
- return "" + base64.StdEncoding.EncodeToString(b) + ""
+ io.WriteString(w, "")
+ _, err := base64.NewEncoder(base64.StdEncoding, w).Write(b)
+ io.WriteString(w, "")
+ return err
}
switch k {
case reflect.Invalid:
- panic("unsupported type")
+ return UnsupportedType
case reflect.Bool:
- return fmt.Sprintf("%v", v)
+ _, err := fmt.Fprintf(w, "%v", v)
+ return err
case reflect.Int,
reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint,
reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if typ {
- return fmt.Sprintf("%v", v)
+ _, err := fmt.Fprintf(w, "%v", v)
+ return err
}
- return fmt.Sprintf("%v", v)
+ _, err := fmt.Fprintf(w, "%v", v)
+ return err
case reflect.Uintptr:
- panic("unsupported type")
+ return UnsupportedType
case reflect.Float32, reflect.Float64:
if typ {
- return fmt.Sprintf("%v", v)
+ _, err := fmt.Fprintf(w, "%v", v)
+ return err
}
- return fmt.Sprintf("%v", v)
+ _, err := fmt.Fprintf(w, "%v", v)
+ return err
case reflect.Complex64, reflect.Complex128:
- panic("unsupported type")
+ return UnsupportedType
case reflect.Array:
- s = ""
+ io.WriteString(w, "")
for n := 0; n < r.Len(); n++ {
- s += ""
- s += toXml(r.Index(n).Interface(), typ)
- s += ""
+ io.WriteString(w, "")
+ err := writeXML(w, r.Index(n).Interface(), typ)
+ io.WriteString(w, "")
+ if err != nil {
+ return err
+ }
}
- s += ""
- return s
+ _, err := io.WriteString(w, "")
+ return err
case reflect.Chan:
- panic("unsupported type")
+ return UnsupportedType
case reflect.Func:
- panic("unsupported type")
+ return UnsupportedType
case reflect.Interface:
- return toXml(r.Elem(), typ)
+ return writeXML(w, r.Elem(), typ)
case reflect.Map:
- s = ""
+ io.WriteString(w, "")
for _, key := range r.MapKeys() {
- s += ""
- s += "" + xmlEscape(key.Interface().(string)) + ""
- s += "" + toXml(r.MapIndex(key).Interface(), typ) + ""
- s += ""
+ io.WriteString(w, "")
+ if err := xml.EscapeText(w, []byte(key.Interface().(string))); err != nil {
+ return err
+ }
+ io.WriteString(w, "")
+ if err := writeXML(w, r.MapIndex(key).Interface(), typ); err != nil {
+ return err
+ }
+ if _, err := io.WriteString(w, ""); err != nil {
+ return err
+ }
}
- s += ""
- return s
+ _, err := io.WriteString(w, "")
+ return err
case reflect.Ptr:
- panic("unsupported type")
+ return UnsupportedType
case reflect.Slice:
- panic("unsupported type")
+ return UnsupportedType
case reflect.String:
if typ {
- return fmt.Sprintf("%v", xmlEscape(v.(string)))
+ io.WriteString(w, "")
}
- return xmlEscape(v.(string))
+ err := xml.EscapeText(w, []byte(v.(string)))
+ if typ {
+ io.WriteString(w, "")
+ }
+ return err
case reflect.Struct:
- s = ""
+ io.WriteString(w, "")
for n := 0; n < r.NumField(); n++ {
- s += ""
- s += "" + t.Field(n).Name + ""
- s += "" + toXml(r.FieldByIndex([]int{n}).Interface(), true) + ""
- s += ""
+ fmt.Fprintf(w, "%s", t.Field(n).Name)
+ if err := writeXML(w, r.FieldByIndex([]int{n}).Interface(), true); err != nil {
+ return err
+ }
+ io.WriteString(w, "")
}
- s += ""
- return s
+ _, err := io.WriteString(w, "")
+ return err
case reflect.UnsafePointer:
- return toXml(r.Elem(), typ)
+ return writeXML(w, r.Elem(), typ)
}
- return
+ return nil
}
// Client is client of XMLRPC
@@ -293,22 +319,43 @@ func NewClient(url string) *Client {
}
}
-func makeRequest(name string, args ...interface{}) *bytes.Buffer {
- buf := new(bytes.Buffer)
- buf.WriteString(``)
- buf.WriteString("" + xmlEscape(name) + "")
- buf.WriteString("")
+func Marshal(w io.Writer, name string, args ...interface{}) error {
+ io.WriteString(w, ``)
+ var end string
+ if name == "" {
+ io.WriteString(w, "")
+ end = ""
+ } else {
+ io.WriteString(w, "")
+ if err := xml.EscapeText(w, []byte(name)); err != nil {
+ return err
+ }
+ io.WriteString(w, "")
+ end = ""
+ }
+
+ io.WriteString(w, "")
for _, arg := range args {
- buf.WriteString("")
- buf.WriteString(toXml(arg, true))
- buf.WriteString("")
+ io.WriteString(w, "")
+ if err := writeXML(w, arg, true); err != nil {
+ return err
+ }
+ io.WriteString(w, "")
}
- buf.WriteString("")
- return buf
+ io.WriteString(w, "")
+ _, err := io.WriteString(w, end)
+ return err
}
-func call(client *http.Client, url, name string, args ...interface{}) (v interface{}, e error) {
- r, e := httpClient.Post(url, "text/xml", makeRequest(name, args...))
+func makeRequest(name string, args ...interface{}) *bytes.Buffer {
+ var buf bytes.Buffer
+ if err := Marshal(&buf, name, args...); err != nil {
+ panic(err)
+ }
+ return &buf
+}
+func call(client *http.Client, url, name string, args ...interface{}) (v Array, e error) {
+ r, e := http.DefaultClient.Post(url, "text/xml", makeRequest(name, args...))
if e != nil {
return nil, e
}
@@ -322,37 +369,53 @@ func call(client *http.Client, url, name string, args ...interface{}) (v interfa
return nil, errors.New(http.StatusText(http.StatusBadRequest))
}
- p := xml.NewDecoder(r.Body)
+ _, v, e = Unmarshal(r.Body)
+ return v, e
+}
+
+func Unmarshal(r io.Reader) (string, Array, error) {
+ var name string
+ p := xml.NewDecoder(r)
se, e := nextStart(p) // methodResponse
- if se.Name.Local != "methodResponse" {
- return nil, errors.New("invalid response: missing methodResponse")
- }
- se, e = nextStart(p) // params
- if se.Name.Local != "params" {
- return nil, errors.New("invalid response: missing params")
+ if e != nil {
+ return name, nil, e
}
- se, e = nextStart(p) // param
- if se.Name.Local != "param" {
- return nil, errors.New("invalid response: missing param")
+ if se.Name.Local != "methodResponse" {
+ if se.Name.Local != "methodCall" {
+ return name, nil, errors.New("invalid response: missing methodResponse")
+ }
+ if se, e = nextStart(p); e != nil {
+ return name, nil, e
+ }
+ if se.Name.Local != "methodName" {
+ return name, nil, errors.New("invalid response: missing methodName")
+ }
+ if e = p.DecodeElement(&name, &se); e != nil {
+ return name, nil, e
+ }
}
- se, e = nextStart(p) // value
- if se.Name.Local != "value" {
- return nil, errors.New("invalid response: missing value")
+ _, v, e := next(p)
+ if a, ok := v.(Array); ok {
+ return name, a, e
+ } else if e == nil {
+ e = fmt.Errorf("wanted Array, got %#v", v)
}
- _, v, e = next(p)
- return v, e
+ return name, nil, e
}
+type Fault struct {
+ Code int
+ Message string
+}
+
+func (f *Fault) Error() string { return fmt.Sprintf("%d: %s", f.Code, f.Message) }
+
// Call call remote procedures function name with args
-func (c *Client) Call(name string, args ...interface{}) (v interface{}, e error) {
+func (c *Client) Call(name string, args ...interface{}) (v Array, e error) {
return call(c.HttpClient, c.url, name, args...)
}
-// Global httpClient allows us to pool/reuse connections and not wastefully
-// re-create transports for each request.
-var httpClient = &http.Client{Transport: http.DefaultTransport, Timeout: 10 * time.Second}
-
// Call call remote procedures function name with args
-func Call(url, name string, args ...interface{}) (v interface{}, e error) {
- return call(httpClient, url, name, args...)
+func Call(url, name string, args ...interface{}) (v Array, e error) {
+ return call(http.DefaultClient, url, name, args...)
}
diff --git a/xmlrpc_test.go b/xmlrpc_test.go
index c605ae8..30203a7 100644
--- a/xmlrpc_test.go
+++ b/xmlrpc_test.go
@@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
+ "strings"
"testing"
)
@@ -105,12 +106,12 @@ func TestAddInt(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- i, ok := v.(int)
+ i, ok := v[0].(int)
if !ok {
- t.Fatalf("want int but got %T: %v", v, v)
+ t.Fatalf("want int but got %#v", v)
}
if i != 3 {
- t.Fatalf("want %v but got %v", 3, v)
+ t.Fatalf("want %v but got %#v", 3, v)
}
}
@@ -138,11 +139,19 @@ func TestAddString(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- s, ok := v.(string)
+ s, ok := v[0].(string)
if !ok {
- t.Fatalf("want string but got %T: %v", v, v)
+ t.Fatalf("want string but got %#v", v)
}
if s != "helloworld" {
t.Fatalf("want %q but got %q", "helloworld", v)
}
}
+
+func toXml(v interface{}, typ bool) (s string) {
+ var buf strings.Builder
+ if err := writeXML(&buf, v, typ); err != nil {
+ panic(err)
+ }
+ return buf.String()
+}