1
1
import inspect
2
- from typing import Any , Dict , Type
2
+
3
+ from typing import Any , Callable , Dict , List , Optional , TypeVar , Type
3
4
4
5
from dependency_injection .registration import Registration
5
6
from dependency_injection .scope import DEFAULT_SCOPE_NAME , Scope
6
7
from dependency_injection .utils .singleton_meta import SingletonMeta
7
8
9
+ Self = TypeVar ('Self' , bound = 'DependencyContainer' )
10
+
8
11
9
12
DEFAULT_CONTAINER_NAME = "default_container"
10
13
11
14
class DependencyContainer (metaclass = SingletonMeta ):
12
15
13
- def __init__ (self , name = None ):
16
+ def __init__ (self , name : str = None ):
14
17
self .name = name if name is not None else DEFAULT_CONTAINER_NAME
15
18
self ._registrations = {}
16
19
self ._singleton_instances = {}
17
20
self ._scoped_instances = {}
18
21
19
22
@classmethod
20
- def get_instance (cls , name = None ):
23
+ def get_instance (cls , name : str = None ) -> Self :
21
24
if name is None :
22
25
name = DEFAULT_CONTAINER_NAME
23
26
@@ -26,88 +29,118 @@ def get_instance(cls, name=None):
26
29
27
30
return cls ._instances [(cls , name )]
28
31
29
- def register_transient (self , interface , class_ , constructor_args = None ):
30
- if interface in self ._registrations :
31
- raise ValueError (f"Dependency { interface } is already registered." )
32
- self ._registrations [interface ] = Registration (interface , class_ , Scope .TRANSIENT , constructor_args )
33
-
34
- def register_scoped (self , interface , class_ , constructor_args = None ):
35
- if interface in self ._registrations :
36
- raise ValueError (f"Dependency { interface } is already registered." )
37
- self ._registrations [interface ] = Registration (interface , class_ , Scope .SCOPED , constructor_args )
38
-
39
- def register_singleton (self , interface , class_ , constructor_args = None ):
40
- if interface in self ._registrations :
41
- raise ValueError (f"Dependency { interface } is already registered." )
42
- self ._registrations [interface ] = Registration (interface , class_ , Scope .SINGLETON , constructor_args )
43
-
44
- def resolve (self , interface , scope_name = DEFAULT_SCOPE_NAME ):
32
+ def register_transient (self , dependency : Type , implementation : Optional [Type ] = None , tags : Optional [set ] = None , constructor_args : Optional [Dict [str , Any ]] = None ) -> None :
33
+ if implementation is None :
34
+ implementation = dependency
35
+ if dependency in self ._registrations :
36
+ raise ValueError (f"Dependency { dependency } is already registered." )
37
+ self ._registrations [dependency ] = Registration (dependency , implementation , Scope .TRANSIENT , tags , constructor_args )
38
+
39
+ def register_scoped (self , dependency : Type , implementation : Optional [Type ] = None , tags : Optional [set ] = None , constructor_args : Optional [Dict [str , Any ]] = None ) -> None :
40
+ if implementation is None :
41
+ implementation = dependency
42
+ if dependency in self ._registrations :
43
+ raise ValueError (f"Dependency { dependency } is already registered." )
44
+ self ._registrations [dependency ] = Registration (dependency , implementation , Scope .SCOPED , tags , constructor_args )
45
+
46
+ def register_singleton (self , dependency : Type , implementation : Optional [Type ] = None , tags : Optional [set ] = None , constructor_args : Optional [Dict [str , Any ]] = None ) -> None :
47
+ if implementation is None :
48
+ implementation = dependency
49
+ if dependency in self ._registrations :
50
+ raise ValueError (f"Dependency { dependency } is already registered." )
51
+ self ._registrations [dependency ] = Registration (dependency , implementation , Scope .SINGLETON , tags , constructor_args )
52
+
53
+ def register_factory (self , dependency : Type , factory : Callable [[Any ], Any ], factory_args : Optional [Dict [str , Any ]] = None , tags : Optional [set ] = None ) -> None :
54
+ if dependency in self ._registrations :
55
+ raise ValueError (f"Dependency { dependency } is already registered." )
56
+ self ._registrations [dependency ] = Registration (dependency , None , Scope .FACTORY , None , tags , factory , factory_args )
57
+
58
+ def register_instance (self , dependency : Type , instance : Any , tags : Optional [set ] = None ) -> None :
59
+ if dependency in self ._registrations :
60
+ raise ValueError (f"Dependency { dependency } is already registered." )
61
+ self ._registrations [dependency ] = Registration (dependency , type (instance ), Scope .SINGLETON , constructor_args = {}, tags = tags )
62
+ self ._singleton_instances [dependency ] = instance
63
+
64
+ def resolve (self , dependency : Type , scope_name : str = DEFAULT_SCOPE_NAME ) -> Type :
45
65
if scope_name not in self ._scoped_instances :
46
66
self ._scoped_instances [scope_name ] = {}
47
67
48
- if interface not in self ._registrations :
49
- raise KeyError (f"Dependency { interface .__name__ } is not registered." )
68
+ if dependency not in self ._registrations :
69
+ raise KeyError (f"Dependency { dependency .__name__ } is not registered." )
50
70
51
- registration = self ._registrations [interface ]
52
- dependency_scope = registration .scope
53
- dependency_class = registration .class_
71
+ registration = self ._registrations [dependency ]
72
+ scope = registration .scope
73
+ implementation = registration .implementation
54
74
constructor_args = registration .constructor_args
55
75
56
- self ._validate_constructor_args (constructor_args = constructor_args , class_ = dependency_class )
76
+ self ._validate_constructor_args (constructor_args = constructor_args , implementation = implementation )
57
77
58
- if dependency_scope == Scope .TRANSIENT :
78
+ if scope == Scope .TRANSIENT :
59
79
return self ._inject_dependencies (
60
- class_ = dependency_class ,
80
+ implementation = implementation ,
61
81
constructor_args = constructor_args
62
82
)
63
- elif dependency_scope == Scope .SCOPED :
64
- if interface not in self ._scoped_instances [scope_name ]:
65
- self ._scoped_instances [scope_name ][interface ] = (
83
+ elif scope == Scope .SCOPED :
84
+ if dependency not in self ._scoped_instances [scope_name ]:
85
+ self ._scoped_instances [scope_name ][dependency ] = (
66
86
self ._inject_dependencies (
67
- class_ = dependency_class ,
87
+ implementation = implementation ,
68
88
scope_name = scope_name ,
69
89
constructor_args = constructor_args ,
70
90
))
71
- return self ._scoped_instances [scope_name ][interface ]
72
- elif dependency_scope == Scope .SINGLETON :
73
- if interface not in self ._singleton_instances :
74
- self ._singleton_instances [interface ] = (
91
+ return self ._scoped_instances [scope_name ][dependency ]
92
+ elif scope == Scope .SINGLETON :
93
+ if dependency not in self ._singleton_instances :
94
+ self ._singleton_instances [dependency ] = (
75
95
self ._inject_dependencies (
76
- class_ = dependency_class ,
96
+ implementation = implementation ,
77
97
constructor_args = constructor_args
78
98
)
79
99
)
80
- return self ._singleton_instances [interface ]
81
-
82
- raise ValueError (f"Invalid dependency scope: { dependency_scope } " )
83
-
84
- def _validate_constructor_args (self , constructor_args : Dict [str , Any ], class_ : Type ) -> None :
85
- class_constructor = inspect .signature (class_ .__init__ ).parameters
100
+ return self ._singleton_instances [dependency ]
101
+ elif scope == Scope .FACTORY :
102
+ factory = registration .factory
103
+ factory_args = registration .factory_args or {}
104
+ return factory (** factory_args )
105
+
106
+ raise ValueError (f"Invalid dependency scope: { scope } " )
107
+
108
+ def resolve_all (self , tags : Optional [set ] = None ) -> List [Any ]:
109
+ tags = tags or []
110
+ resolved_dependencies = []
111
+ for registration in self ._registrations .values ():
112
+ if not len (tags ) or tags .intersection (registration .tags ):
113
+ resolved_dependencies .append (
114
+ self .resolve (registration .dependency ))
115
+ return resolved_dependencies
116
+
117
+ def _validate_constructor_args (self , constructor_args : Dict [str , Any ], implementation : Type ) -> None :
118
+ constructor = inspect .signature (implementation .__init__ ).parameters
86
119
87
120
# Check if any required parameter is missing
88
- missing_params = [param for param in class_constructor .keys () if
121
+ missing_params = [param for param in constructor .keys () if
89
122
param not in ["self" , "cls" , "args" , "kwargs" ] and
90
123
param not in constructor_args ]
91
124
if missing_params :
92
125
raise ValueError (
93
126
f"Missing required constructor arguments: "
94
- f"{ ', ' .join (missing_params )} for class '{ class_ .__name__ } '." )
127
+ f"{ ', ' .join (missing_params )} for class '{ implementation .__name__ } '." )
95
128
96
129
for arg_name , arg_value in constructor_args .items ():
97
- if arg_name not in class_constructor :
130
+ if arg_name not in constructor :
98
131
raise ValueError (
99
- f"Invalid constructor argument '{ arg_name } ' for class '{ class_ .__name__ } '. "
132
+ f"Invalid constructor argument '{ arg_name } ' for class '{ implementation .__name__ } '. "
100
133
f"The class does not have a constructor parameter with this name." )
101
134
102
- expected_type = class_constructor [arg_name ].annotation
135
+ expected_type = constructor [arg_name ].annotation
103
136
if expected_type != inspect .Parameter .empty :
104
137
if not isinstance (arg_value , expected_type ):
105
138
raise TypeError (
106
139
f"Constructor argument '{ arg_name } ' has an incompatible type. "
107
140
f"Expected type: { expected_type } , provided type: { type (arg_value )} ." )
108
141
109
- def _inject_dependencies (self , class_ , scope_name = None , constructor_args = None ):
110
- constructor = inspect .signature (class_ .__init__ )
142
+ def _inject_dependencies (self , implementation : Type , scope_name : str = None , constructor_args : Optional [ Dict [ str , Any ]] = None ) -> Type :
143
+ constructor = inspect .signature (implementation .__init__ )
111
144
params = constructor .parameters
112
145
113
146
dependencies = {}
@@ -127,4 +160,4 @@ def _inject_dependencies(self, class_, scope_name=None, constructor_args=None):
127
160
else :
128
161
dependencies [param_name ] = self .resolve (param_info .annotation , scope_name = scope_name )
129
162
130
- return class_ (** dependencies )
163
+ return implementation (** dependencies )
0 commit comments