diff --git a/hibernate-core/src/main/java/org/hibernate/engine/spi/ExtensionStorage.java b/hibernate-core/src/main/java/org/hibernate/engine/spi/ExtensionStorage.java new file mode 100644 index 000000000000..eecdb740a7ad --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/engine/spi/ExtensionStorage.java @@ -0,0 +1,18 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.engine.spi; + +import org.hibernate.Incubating; + +import java.util.function.Supplier; + +/** + * Marker interface for extensions to register themselves within a session instance. + * + * @see SharedSessionContractImplementor#getExtensionStorage(Class, Supplier) + */ +@Incubating +public interface ExtensionStorage { +} diff --git a/hibernate-core/src/main/java/org/hibernate/engine/spi/SessionDelegatorBaseImpl.java b/hibernate-core/src/main/java/org/hibernate/engine/spi/SessionDelegatorBaseImpl.java index d66c9cc18f6c..adf2e8ea6bb5 100644 --- a/hibernate-core/src/main/java/org/hibernate/engine/spi/SessionDelegatorBaseImpl.java +++ b/hibernate-core/src/main/java/org/hibernate/engine/spi/SessionDelegatorBaseImpl.java @@ -75,6 +75,7 @@ import java.util.Set; import java.util.TimeZone; import java.util.UUID; +import java.util.function.Supplier; /** * A wrapper class that delegates all method invocations to a delegate instance of @@ -517,6 +518,11 @@ public RootGraphImplementor getEntityGraph(String graphName) { return delegate.getEntityGraph( graphName ); } + @Override + public T getExtensionStorage(Class extension, Supplier createIfMissing) { + return delegate.getExtensionStorage( extension, createIfMissing ); + } + @Override public QueryImplementor createQuery(CriteriaSelect selectQuery) { return delegate.createQuery( selectQuery ); diff --git a/hibernate-core/src/main/java/org/hibernate/engine/spi/SharedSessionContractImplementor.java b/hibernate-core/src/main/java/org/hibernate/engine/spi/SharedSessionContractImplementor.java index b441357630f2..311653c1feb0 100644 --- a/hibernate-core/src/main/java/org/hibernate/engine/spi/SharedSessionContractImplementor.java +++ b/hibernate-core/src/main/java/org/hibernate/engine/spi/SharedSessionContractImplementor.java @@ -6,6 +6,8 @@ import java.util.Set; import java.util.UUID; +import java.util.function.Supplier; + import jakarta.persistence.TransactionRequiredException; import org.checkerframework.checker.nullness.qual.Nullable; @@ -621,4 +623,16 @@ default boolean isStatelessSession() { @Override RootGraphImplementor getEntityGraph(String graphName); + + /** + * Allows accessing session scoped extension storages of the particular session instance. + * + * @param extension The extension storage attached to the current session. + * @param createIfMissing Creates a storage extension using the supplier, + * if the current session does not yet have the particular storage type attached to this session. + * @param The type of the extension storage. + */ + @Incubating + T getExtensionStorage(Class extension, Supplier createIfMissing); + } diff --git a/hibernate-core/src/main/java/org/hibernate/engine/spi/SharedSessionDelegatorBaseImpl.java b/hibernate-core/src/main/java/org/hibernate/engine/spi/SharedSessionDelegatorBaseImpl.java index 5a9893154c61..1f31543f7964 100644 --- a/hibernate-core/src/main/java/org/hibernate/engine/spi/SharedSessionDelegatorBaseImpl.java +++ b/hibernate-core/src/main/java/org/hibernate/engine/spi/SharedSessionDelegatorBaseImpl.java @@ -52,6 +52,7 @@ import java.util.Set; import java.util.TimeZone; import java.util.UUID; +import java.util.function.Supplier; /** * A wrapper class that delegates all method invocations to a delegate instance of @@ -663,6 +664,11 @@ public RootGraphImplementor getEntityGraph(String graphName) { return delegate.getEntityGraph( graphName ); } + @Override + public T getExtensionStorage(Class extension, Supplier createIfMissing) { + return delegate.getExtensionStorage( extension, createIfMissing ); + } + @Override public List> getEntityGraphs(Class entityClass) { return delegate.getEntityGraphs( entityClass ); diff --git a/hibernate-core/src/main/java/org/hibernate/internal/AbstractSharedSessionContract.java b/hibernate-core/src/main/java/org/hibernate/internal/AbstractSharedSessionContract.java index fa06b7453326..d262afc69287 100644 --- a/hibernate-core/src/main/java/org/hibernate/internal/AbstractSharedSessionContract.java +++ b/hibernate-core/src/main/java/org/hibernate/internal/AbstractSharedSessionContract.java @@ -41,6 +41,7 @@ import org.hibernate.engine.jdbc.spi.JdbcServices; import org.hibernate.engine.spi.EntityKey; import org.hibernate.engine.spi.ExceptionConverter; +import org.hibernate.engine.spi.ExtensionStorage; import org.hibernate.engine.spi.LoadQueryInfluencers; import org.hibernate.engine.spi.SessionEventListenerManager; import org.hibernate.engine.spi.SessionFactoryImplementor; @@ -111,12 +112,15 @@ import java.io.Serial; import java.sql.Connection; import java.sql.SQLException; +import java.util.HashMap; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Objects; import java.util.TimeZone; import java.util.UUID; import java.util.function.Function; +import java.util.function.Supplier; import static java.lang.Boolean.TRUE; import static org.hibernate.boot.model.naming.Identifier.toIdentifier; @@ -186,6 +190,8 @@ public abstract class AbstractSharedSessionContract implements SharedSessionCont private transient ExceptionConverter exceptionConverter; private transient SessionAssociationMarkers sessionAssociationMarkers; + private transient Map, Object> extensionStorages; + public AbstractSharedSessionContract(SessionFactoryImpl factory, SessionCreationOptions options) { this.factory = factory; @@ -1704,6 +1710,20 @@ public SessionAssociationMarkers getSessionAssociationMarkers() { return sessionAssociationMarkers; } + @Override + public T getExtensionStorage(Class extension, Supplier createIfMissing) { + if ( extensionStorages == null ) { + extensionStorages = new HashMap<>(); + } + Object storage = extensionStorages.get( extension ); + if ( storage == null ) { + storage = createIfMissing.get(); + extensionStorages.put( extension, storage ); + } + + return extension.cast( storage ); + } + @Serial private void writeObject(ObjectOutputStream oos) throws IOException { SESSION_LOGGER.serializingSession( getSessionIdentifier() ); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/engine/spi/SessionExtensionTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/engine/spi/SessionExtensionTest.java new file mode 100644 index 000000000000..eb8bd9ff8555 --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/engine/spi/SessionExtensionTest.java @@ -0,0 +1,101 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.orm.test.engine.spi; + +import jakarta.persistence.Id; +import org.hibernate.engine.spi.ExtensionStorage; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +@DomainModel(annotatedClasses = { + SessionExtensionTest.UselessEntity.class, +}) +@SessionFactory +public class SessionExtensionTest { + + @Test + public void failing(SessionFactoryScope scope) { + scope.inSession( sessionImplementor -> { + assertThatThrownBy( + () -> sessionImplementor.getExtensionStorage( MySometimesFailingExtensionStorage.class, MySometimesFailingExtensionStorage::new ) ) + .isInstanceOf( UnsupportedOperationException.class ); + } ); + + scope.inStatelessSession( sessionImplementor -> { + assertThatThrownBy( + () -> sessionImplementor.getExtensionStorage( MySometimesFailingExtensionStorage.class, MySometimesFailingExtensionStorage::new ) ) + .isInstanceOf( UnsupportedOperationException.class ); + } ); + } + + @Test + public void supplier(SessionFactoryScope scope) { + scope.inSession( sessionImplementor -> { + sessionImplementor.getExtensionStorage( MySometimesFailingExtensionStorage.class, + () -> new MySometimesFailingExtensionStorage( new HashMap<>() ) ) + .add( new Extension( 1 ) ); + + assertThat( sessionImplementor.getExtensionStorage( MySometimesFailingExtensionStorage.class, MySometimesFailingExtensionStorage::new ).get( 1 ) ) + .isNotNull() + .isEqualTo( new Extension( 1 ) ); + + assertThat( sessionImplementor.getExtensionStorage( MySometimesFailingExtensionStorage.class, + () -> new MySometimesFailingExtensionStorage( new HashMap<>() ) ).get( 1 ) ) + .isNotNull() + .isEqualTo( new Extension( 1 ) ); + } ); + + scope.inStatelessSession( sessionImplementor -> { + sessionImplementor.getExtensionStorage( MySometimesFailingExtensionStorage.class, + () -> new MySometimesFailingExtensionStorage( new HashMap<>() ) ) + .add( new Extension( 1 ) ); + + assertThat( sessionImplementor.getExtensionStorage( MySometimesFailingExtensionStorage.class, MySometimesFailingExtensionStorage::new ).get( 1 ) ) + .isNotNull() + .isEqualTo( new Extension( 1 ) ); + + assertThat( sessionImplementor.getExtensionStorage( MySometimesFailingExtensionStorage.class, + () -> new MySometimesFailingExtensionStorage( new HashMap<>() ) ).get( 1 ) ) + .isNotNull() + .isEqualTo( new Extension( 1 ) ); + } ); + } + + public static class MySometimesFailingExtensionStorage implements ExtensionStorage { + Map extensions = new HashMap<>(); + + public MySometimesFailingExtensionStorage() { + throw new UnsupportedOperationException(); + } + + MySometimesFailingExtensionStorage(Map extensions) { + this.extensions = extensions; + } + + public void add(Extension extension) { + extensions.put( extension.number, extension ); + } + + public Extension get(int number) { + return extensions.get( number ); + } + } + + public record Extension(int number) { + } + + static class UselessEntity { + @Id + Long id; + } +}