Skip to content

Commit c0dbff4

Browse files
committed
parallel container startup with deferred values
This commit tries to be as unobtrusive as possible, attaching new behavior to existing types where possible rather than building out new infrastructure. constructorNode returns a deferred value when called. On the first call, it asks paramList to start building an arg slice, which may also be deferred. Once the arg slice is resolved, constructorNode schedules its constructor function to be called. Once it's called, it resolves its own deferral. Multiple paramSingles can observe the same constructorNode before it's ready. If there's an error, they may all see the same error, which is a change in behavior. There are two schedulers: synchronous and parallel. The synchronous scheduler returns things in the same order as before. The parallel may not (and the tests that rely on shuffle order will fail). The scheduler needs to be flushed after deferred values are created. The synchronous scheduler does nothing on when flushing, but the parallel scheduler runs a pool of goroutines to resolve constructors. Calls to dig functions always happen on the same goroutine as Scope.Invoke(). Calls to constructor functions can happen on pooled goroutines. The choice of scheduler is up to the Scope. Whether constructor functions are safe to call in parallel seems most logically to be a property of the scope, and the scope is passed down the constructor/param call chain.
1 parent 1d9f0f1 commit c0dbff4

File tree

11 files changed

+546
-100
lines changed

11 files changed

+546
-100
lines changed

constructor.go

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,18 @@ type constructorNode struct {
4545
// id uniquely identifies the constructor that produces a node.
4646
id dot.CtorID
4747

48+
// Whether this node is already building its paramList and calling the constructor
49+
calling bool
50+
4851
// Whether the constructor owned by this node was already called.
4952
called bool
5053

5154
// Type information about constructor parameters.
5255
paramList paramList
5356

57+
// The result of calling the constructor
58+
deferred deferred
59+
5460
// Type information about constructor results.
5561
resultList resultList
5662

@@ -122,42 +128,66 @@ func (n *constructorNode) String() string {
122128
return fmt.Sprintf("deps: %v, ctor: %v", n.paramList, n.ctype)
123129
}
124130

125-
// Call calls this constructor if it hasn't already been called and
126-
// injects any values produced by it into the provided container.
127-
func (n *constructorNode) Call(c containerStore) error {
128-
if n.called {
129-
return nil
131+
// Call calls this constructor if it hasn't already been called and injects any values produced by it into the container
132+
// passed to newConstructorNode.
133+
//
134+
// If constructorNode has a unresolved deferred already in the process of building, it will return that one. If it has
135+
// already been successfully called, it will return an already-resolved deferred. Together these mean it will try the
136+
// call again if it failed last time.
137+
//
138+
// On failure, the returned pointer is not guaranteed to stay in a failed state; another call will reset it back to its
139+
// zero value; don't store the returned pointer. (It will still call each observer only once.)
140+
func (n *constructorNode) Call(c containerStore) *deferred {
141+
if n.calling || n.called {
142+
return &n.deferred
130143
}
131144

145+
n.calling = true
146+
n.deferred = deferred{}
147+
132148
if err := shallowCheckDependencies(c, n.paramList); err != nil {
133-
return errMissingDependencies{
149+
n.deferred.resolve(errMissingDependencies{
134150
Func: n.location,
135151
Reason: err,
136-
}
152+
})
137153
}
138154

139-
args, err := n.paramList.BuildList(c, false /* decorating */)
140-
if err != nil {
141-
return errArgumentsFailed{
142-
Func: n.location,
143-
Reason: err,
155+
var args []reflect.Value
156+
d := n.paramList.BuildList(c, false /* decorating */, &args)
157+
158+
d.observe(func(err error) {
159+
if err != nil {
160+
n.calling = false
161+
n.deferred.resolve(errArgumentsFailed{
162+
Func: n.location,
163+
Reason: err,
164+
})
165+
return
144166
}
145-
}
146-
147-
receiver := newStagingContainerWriter()
148-
results := c.invoker()(reflect.ValueOf(n.ctor), args)
149-
if err := n.resultList.ExtractList(receiver, false /* decorating */, results); err != nil {
150-
return errConstructorFailed{Func: n.location, Reason: err}
151-
}
152-
153-
// Commit the result to the original container that this constructor
154-
// was supplied to. The provided constructor is only used for a view of
155-
// the rest of the graph to instantiate the dependencies of this
156-
// container.
157-
receiver.Commit(n.s)
158-
n.called = true
159167

160-
return nil
168+
var results []reflect.Value
169+
170+
c.scheduler().schedule(func() {
171+
results = c.invoker()(reflect.ValueOf(n.ctor), args)
172+
}).observe(func(_ error) {
173+
n.calling = false
174+
receiver := newStagingContainerWriter()
175+
if err := n.resultList.ExtractList(receiver, false /* decorating */, results); err != nil {
176+
n.deferred.resolve(errConstructorFailed{Func: n.location, Reason: err})
177+
return
178+
}
179+
180+
// Commit the result to the original container that this constructor
181+
// was supplied to. The provided container is only used for a view of
182+
// the rest of the graph to instantiate the dependencies of this
183+
// container.
184+
receiver.Commit(n.s)
185+
n.called = true
186+
n.deferred.resolve(nil)
187+
})
188+
})
189+
190+
return &n.deferred
161191
}
162192

163193
// stagingContainerWriter is a containerWriter that records the changes that

constructor_test.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,11 @@ func TestNodeAlreadyCalled(t *testing.T) {
5959
require.False(t, n.called, "node must not have been called")
6060

6161
c := New()
62-
require.NoError(t, n.Call(c.scope), "invoke failed")
62+
d := n.Call(c.scope)
63+
c.scope.sched.flush()
64+
require.NoError(t, d.err, "invoke failed")
6365
require.True(t, n.called, "node must be called")
64-
require.NoError(t, n.Call(c.scope), "calling again should be okay")
66+
d = n.Call(c.scope)
67+
c.scope.sched.flush()
68+
require.NoError(t, d.err, "calling again should be okay")
6569
}

container.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ type containerStore interface {
142142

143143
// Returns invokerFn function to use when calling arguments.
144144
invoker() invokerFn
145+
146+
// Returns the scheduler to use for this scope.
147+
scheduler() scheduler
145148
}
146149

147150
// New constructs a Container.
@@ -231,6 +234,29 @@ func dryInvoker(fn reflect.Value, _ []reflect.Value) []reflect.Value {
231234
return results
232235
}
233236

237+
type maxConcurrencyOption int
238+
239+
// MaxConcurrency run constructors in this container with a fixed pool of executor
240+
// goroutines. max is the number of goroutines to start.
241+
func MaxConcurrency(max int) Option {
242+
return maxConcurrencyOption(max)
243+
}
244+
245+
func (m maxConcurrencyOption) applyOption(container *Container) {
246+
container.scope.sched = &parallelScheduler{concurrency: int(m)}
247+
}
248+
249+
type unboundedConcurrency struct{}
250+
251+
// UnboundedConcurrency run constructors in this container as concurrently as possible.
252+
// Go's resource limits like GOMAXPROCS will inherently limit how much can happen in
253+
// parallel.
254+
var UnboundedConcurrency Option = unboundedConcurrency{}
255+
256+
func (u unboundedConcurrency) applyOption(container *Container) {
257+
container.scope.sched = &unboundedScheduler{}
258+
}
259+
234260
// String representation of the entire Container
235261
func (c *Container) String() string {
236262
return c.scope.String()

deferred.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
package dig
2+
3+
type observer func(error)
4+
5+
// A deferred is an observable future result that may fail. Its zero value is unresolved and has no observers. It can
6+
// be resolved once, at which point every observer will be called.
7+
type deferred struct {
8+
observers []observer
9+
settled bool
10+
err error
11+
}
12+
13+
// alreadyResolved is a deferred that has already been resolved with a nil error.
14+
var alreadyResolved = deferred{settled: true}
15+
16+
// failedDeferred returns a deferred that is resolved with the given error.
17+
func failedDeferred(err error) *deferred {
18+
return &deferred{settled: true, err: err}
19+
}
20+
21+
// observe registers an observer to receive a callback when this deferred is resolved. It will be called at most one
22+
// time. If this deferred is already resolved, the observer is called immediately, before observe returns.
23+
func (d *deferred) observe(obs observer) {
24+
if d.settled {
25+
obs(d.err)
26+
return
27+
}
28+
29+
d.observers = append(d.observers, obs)
30+
}
31+
32+
// resolve sets the status of this deferred and notifies all observers if it's not already resolved.
33+
func (d *deferred) resolve(err error) {
34+
if d.settled {
35+
return
36+
}
37+
38+
d.settled = true
39+
d.err = err
40+
for _, obs := range d.observers {
41+
obs(err)
42+
}
43+
d.observers = nil
44+
}
45+
46+
// then returns a new deferred that is either resolved with the same error as this deferred, or any error returned from
47+
// the supplied function. The supplied function is only called if this deferred is resolved without error.
48+
func (d *deferred) then(res func() error) *deferred {
49+
d2 := new(deferred)
50+
d.observe(func(err error) {
51+
if err != nil {
52+
d2.resolve(err)
53+
return
54+
}
55+
d2.resolve(res())
56+
})
57+
return d2
58+
}
59+
60+
// catch maps any error from this deferred using the supplied function. The supplied function is only called if this
61+
// deferred is resolved with an error. If the supplied function returns a nil error, the new deferred will resolve
62+
// successfully.
63+
func (d *deferred) catch(rej func(error) error) *deferred {
64+
d2 := new(deferred)
65+
d.observe(func(err error) {
66+
if err != nil {
67+
err = rej(err)
68+
}
69+
d2.resolve(err)
70+
})
71+
return d2
72+
}
73+
74+
// whenAll returns a new deferred that resolves when all the supplied deferreds resolve. It resolves with the first
75+
// error reported by any deferred, or nil if they all succeed.
76+
func whenAll(others ...*deferred) *deferred {
77+
if len(others) == 0 {
78+
return &alreadyResolved
79+
}
80+
81+
d := new(deferred)
82+
count := len(others)
83+
84+
onResolved := func(err error) {
85+
if d.settled {
86+
return
87+
}
88+
89+
if err != nil {
90+
d.resolve(err)
91+
}
92+
93+
count--
94+
if count == 0 {
95+
d.resolve(nil)
96+
}
97+
}
98+
99+
for _, other := range others {
100+
other.observe(onResolved)
101+
}
102+
103+
return d
104+
}

dig_test.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"math/rand"
3030
"os"
3131
"reflect"
32+
"sync/atomic"
3233
"testing"
3334
"time"
3435

@@ -3566,3 +3567,91 @@ func TestEndToEndSuccessWithAliases(t *testing.T) {
35663567
})
35673568

35683569
}
3570+
3571+
func TestConcurrency(t *testing.T) {
3572+
// Ensures providers will run at the same time
3573+
t.Run("TestMaxConcurrency", func(t *testing.T) {
3574+
t.Parallel()
3575+
3576+
type (
3577+
A int
3578+
B int
3579+
C int
3580+
)
3581+
3582+
var (
3583+
timer = time.NewTimer(10 * time.Second)
3584+
max int32 = 3
3585+
done = make(chan struct{})
3586+
running int32 = 0
3587+
waitForUs = func() error {
3588+
if atomic.AddInt32(&running, 1) == max {
3589+
close(done)
3590+
}
3591+
select {
3592+
case <-timer.C:
3593+
return errors.New("timeout expired")
3594+
case <-done:
3595+
return nil
3596+
}
3597+
}
3598+
c = digtest.New(t, dig.MaxConcurrency(int(max)))
3599+
)
3600+
3601+
c.RequireProvide(func() (A, error) { return 0, waitForUs() })
3602+
c.RequireProvide(func() (B, error) { return 1, waitForUs() })
3603+
c.RequireProvide(func() (C, error) { return 2, waitForUs() })
3604+
3605+
c.RequireInvoke(func(a A, b B, c C) {
3606+
require.Equal(t, a, A(0))
3607+
require.Equal(t, b, B(1))
3608+
require.Equal(t, c, C(2))
3609+
require.Equal(t, running, int32(3))
3610+
})
3611+
})
3612+
3613+
t.Run("TestUnboundConcurrency", func(t *testing.T) {
3614+
t.Parallel()
3615+
3616+
var (
3617+
timer = time.NewTimer(10 * time.Second)
3618+
max int32 = 20
3619+
done = make(chan struct{})
3620+
running int32 = 0
3621+
waitForUs = func() error {
3622+
if atomic.AddInt32(&running, 1) >= max {
3623+
close(done)
3624+
}
3625+
select {
3626+
case <-timer.C:
3627+
return errors.New("timeout expired")
3628+
case <-done:
3629+
return nil
3630+
}
3631+
}
3632+
c = digtest.New(t, dig.UnboundedConcurrency)
3633+
expected []int
3634+
)
3635+
3636+
for i := 0; i < int(max); i++ {
3637+
i := i
3638+
expected = append(expected, i)
3639+
type out struct {
3640+
dig.Out
3641+
3642+
Value int `group:"a"`
3643+
}
3644+
c.RequireProvide(func() (out, error) { return out{Value: i}, waitForUs() })
3645+
}
3646+
3647+
type in struct {
3648+
dig.In
3649+
3650+
Values []int `group:"a"`
3651+
}
3652+
3653+
c.RequireInvoke(func(i in) {
3654+
require.ElementsMatch(t, expected, i.Values)
3655+
})
3656+
})
3657+
}

invoke.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,14 @@ func (s *Scope) Invoke(function interface{}, opts ...InvokeOption) error {
8282
s.isVerifiedAcyclic = true
8383
}
8484

85-
args, err := pl.BuildList(s, false /* decorating */)
85+
var args []reflect.Value
86+
87+
d := pl.BuildList(s, &args, false /* decorating */)
88+
d.observe(func(err2 error) {
89+
err = err2
90+
})
91+
s.sched.flush()
92+
8693
if err != nil {
8794
return errArgumentsFailed{
8895
Func: digreflect.InspectFunc(function),

0 commit comments

Comments
 (0)