Skip to content

Commit a01ef4e

Browse files
Code to handle multiple return values as tuples now logically complete, even though it doesn't work currently.
1 parent bedfb5c commit a01ef4e

File tree

2 files changed

+74
-27
lines changed

2 files changed

+74
-27
lines changed

bind/gen.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,28 @@ static inline void gopy_err_handle() {
8686
PyErr_Print();
8787
}
8888
}
89+
90+
static PyObject* Py_BuildValue1(const char *format, void* arg0)
91+
{
92+
return Py_BuildValue(format, arg0);
93+
}
94+
static PyObject* Py_BuildValue2(const char *format, void* arg0, void* arg1)
95+
{
96+
return Py_BuildValue(format, arg0, arg1);
97+
}
98+
static PyObject * Py_BuildValue3(const char *format, void* arg0, void* arg1, void* arg2)
99+
{
100+
return Py_BuildValue(format, arg0, arg1, arg2);
101+
}
102+
static PyObject * Py_BuildValue4(const char *format, void* arg0, void* arg1, void* arg2, void* arg3)
103+
{
104+
return Py_BuildValue(format, arg0, arg1, arg2, arg3);
105+
}
106+
static PyObject * Py_BuildValue5(const char *format, void* arg0, void* arg1, void* arg2, void* arg3, void* arg4)
107+
{
108+
return Py_BuildValue(format, arg0, arg1, arg2, arg3, arg4);
109+
}
110+
89111
%[9]s
90112
*/
91113
import "C"

bind/gen_func.go

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,17 @@ import (
1212
"strings"
1313
)
1414

15+
func buildPyTuple(fsym *Func) bool {
16+
npyres := len(fsym.sig.Results())
17+
if fsym.haserr {
18+
if !NoPyExceptions {
19+
npyres -= 1
20+
}
21+
}
22+
23+
return (npyres > 1)
24+
}
25+
1526
func (g *pyGen) recurse(gotype types.Type, prefix, name string) {
1627
switch t := gotype.(type) {
1728
case *types.Basic:
@@ -61,11 +72,8 @@ func (g *pyGen) genFuncSig(sym *symbol, fsym *Func) bool {
6172
nargs := 0
6273
nres := len(res)
6374
npyres := nres
64-
rvHasErr := false // set to true if the main return is an error
6575
if fsym.haserr {
66-
if NoPyExceptions {
67-
rvHasErr = true
68-
} else {
76+
if !NoPyExceptions {
6977
npyres -= 1
7078
}
7179
}
@@ -160,27 +168,16 @@ func (g *pyGen) genFuncSig(sym *symbol, fsym *Func) bool {
160168
goRet := ""
161169
if npyres == 0 {
162170
g.pybuild.Printf("None")
163-
} else if npyres == 1 {
164-
ret := res[0]
165-
sret := current.symtype(ret.GoType())
166-
if sret == nil {
167-
panic(fmt.Errorf(
168-
"gopy: could not find symbol for %q",
169-
ret.Name(),
170-
))
171-
}
172-
173-
if sret.cpyname == "PyObject*" {
174-
g.pybuild.Printf("retval('%s', caller_owns_return=True)", sret.cpyname)
175-
} else {
176-
g.pybuild.Printf("retval('%s')", sret.cpyname)
177-
}
178-
goRet = fmt.Sprintf("%s", sret.cgoname)
179-
} else {
180-
// On Python side, we are returning PyTuple.
171+
} else if buildPyTuple(fsym) {
172+
// We are returning PyTuple*. Setup pybindgen accordingly.
181173
g.pybuild.Printf("retval('PyObject*', caller_owns_return=True)")
182174

183-
// On Go side, we are returning multiple values.
175+
// On Go side, return *C.PyObject.
176+
goRet = "unsafe.Pointer"
177+
} else {
178+
ownership := ""
179+
pyrets := make([]string, npyres, npyres)
180+
gorets := make([]string, npyres, npyres)
184181
for i := 0; i < npyres; i++ {
185182
sret := current.symtype(res[i].GoType())
186183
if sret == nil {
@@ -189,11 +186,16 @@ func (g *pyGen) genFuncSig(sym *symbol, fsym *Func) bool {
189186
res[i].Name(),
190187
))
191188
}
192-
goRet += sret.cgoname
193-
if i != npyres-1 {
194-
goRet += ", "
189+
gorets[i] = sret.cgoname
190+
pyrets[i] = "'" + sret.cpyname + "'"
191+
if sret.cpyname == "PyObject*" {
192+
ownership = "caller_owns_return=True"
195193
}
196194
}
195+
196+
g.pybuild.Printf("retval(%s%s)", strings.Join(pyrets, ", "), ownership)
197+
198+
goRet = strings.Join(gorets, ", ")
197199
if npyres > 1 {
198200
goRet = "(" + goRet + ")"
199201
}
@@ -471,7 +473,30 @@ if __err != nil {
471473
}
472474
}
473475

474-
g.gofile.Printf("return %s", strings.Join(retvals[0:npyres], ", "))
476+
if buildPyTuple(fsym) {
477+
g.gofile.Printf("\n")
478+
formatStr := ""
479+
for i := 0; i < npyres; i++ {
480+
sret := current.symtype(res[i].GoType())
481+
if sret == nil {
482+
panic(fmt.Errorf(
483+
"gopy: could not find symbol for %q",
484+
res[i].Name(),
485+
))
486+
}
487+
if sret.pyfmt == "" {
488+
formatStr += "?"
489+
} else {
490+
formatStr += sret.pyfmt
491+
}
492+
}
493+
g.gofile.Printf("return unsafe.Pointer(C.Py_BuildValue%d(\"%s\", %s))\n",
494+
npyres,
495+
formatStr,
496+
strings.Join(retvals[0:npyres], ", "))
497+
} else {
498+
g.gofile.Printf("return %s\n", strings.Join(retvals[0:npyres], ", "))
499+
}
475500
} else {
476501
g.gofile.Printf("if boolPyToGo(goRun) {\n")
477502
g.gofile.Indent()

0 commit comments

Comments
 (0)