Skip to content

Commit 8798daa

Browse files
imalyginthenswan
andauthored
feat: 21644 Implemented getMerkleProof (#21765)
Signed-off-by: Ivan Malygin <ivan@swirldslabs.com> Co-authored-by: Nikita Lebedev <nikita.lebedev@limechain.tech>
1 parent 8c8b8e7 commit 8798daa

File tree

10 files changed

+727
-4
lines changed

10 files changed

+727
-4
lines changed

hedera-node/hedera-app/src/testFixtures/java/com/hedera/node/app/fixtures/state/FakeState.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import com.hedera.pbj.runtime.io.buffer.Bytes;
1010
import com.swirlds.common.merkle.MerkleNode;
1111
import com.swirlds.state.MerkleNodeState;
12+
import com.swirlds.state.MerkleProof;
1213
import com.swirlds.state.State;
1314
import com.swirlds.state.StateChangeListener;
1415
import com.swirlds.state.lifecycle.StateMetadata;
@@ -288,6 +289,11 @@ public Hash getHashForPath(long path) {
288289
throw new UnsupportedOperationException();
289290
}
290291

292+
@Override
293+
public MerkleProof getMerkleProof(long path) {
294+
throw new UnsupportedOperationException();
295+
}
296+
291297
@Override
292298
public long queueElementPath(final int stateId, @NonNull final Bytes expectedValue) {
293299
throw new UnsupportedOperationException();

platform-sdk/platform-apps/demos/StatsDemo/src/main/java/com/swirlds/demo/stats/StatsDemoState.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import com.swirlds.common.merkle.crypto.MerkleCryptographyFactory;
2020
import com.swirlds.common.metrics.noop.NoOpMetrics;
2121
import com.swirlds.state.MerkleNodeState;
22+
import com.swirlds.state.MerkleProof;
2223
import com.swirlds.state.test.fixtures.merkle.MerkleStateRoot;
2324
import edu.umd.cs.findbugs.annotations.NonNull;
2425
import org.hiero.base.constructable.ConstructableIgnored;
@@ -128,4 +129,9 @@ public long kvPath(final int stateId, @NonNull Bytes key) {
128129
public Hash getHashForPath(long path) {
129130
throw new UnsupportedOperationException();
130131
}
132+
133+
@Override
134+
public MerkleProof getMerkleProof(long path) {
135+
throw new UnsupportedOperationException();
136+
}
131137
}

platform-sdk/platform-apps/tests/PlatformTestingTool/src/main/java/com/swirlds/demo/platform/PlatformTestingToolState.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import com.swirlds.demo.platform.nft.ReferenceNftLedger;
2828
import com.swirlds.merkle.test.fixtures.map.pta.MapKey;
2929
import com.swirlds.state.MerkleNodeState;
30+
import com.swirlds.state.MerkleProof;
3031
import com.swirlds.state.test.fixtures.merkle.MerkleStateRoot;
3132
import com.swirlds.virtualmap.VirtualMap;
3233
import edu.umd.cs.findbugs.annotations.NonNull;
@@ -521,6 +522,11 @@ public Hash getHashForPath(long path) {
521522
throw new UnsupportedOperationException();
522523
}
523524

525+
@Override
526+
public MerkleProof getMerkleProof(long path) {
527+
throw new UnsupportedOperationException();
528+
}
529+
524530
/**
525531
* The version history of this class. Versions that have been released must NEVER be given a different value.
526532
*/

platform-sdk/swirlds-state-api/src/main/java/com/swirlds/state/MerkleNodeState.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,11 @@ default <V> long kvPath(final int stateId, @NonNull final V key, @NonNull final
134134
* @return hash of the merkle node at the given path or null if the path is non-existent
135135
*/
136136
Hash getHashForPath(long path);
137+
138+
/**
139+
* Prepares a Merkle proof for the given path.
140+
* @param path merkle path
141+
* @return Merkle proof for the given path or null if the path is non-existent
142+
*/
143+
MerkleProof getMerkleProof(long path);
137144
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
package com.swirlds.state;
3+
4+
import com.hedera.pbj.runtime.io.buffer.Bytes;
5+
import java.util.List;
6+
import org.hiero.base.crypto.Hash;
7+
8+
/**
9+
* Represents a Merkle proof containing all necessary information to verify a state item.
10+
*
11+
* @param stateItem byte representation of {@code StateItem}
12+
* @param siblingHashes a list of sibling hashes used in the Merkle proof from the leaf of {@code stateItem} to the root of the state
13+
* @param innerParentHashes a list of byte arrays representing inner parent hashes, where:
14+
* <ul>
15+
* <li><code>innerParentHashes.get(0)</code> is the hash of the Merkle leaf
16+
* <li><code>innerParentHashes.get(1)</code> is a hash of a parent</li>
17+
* <li><code>innerParentHashes.get(2)</code> is a hash of a grandparent</li>
18+
* <li>and so on</li>
19+
* </ul>
20+
*/
21+
public record MerkleProof(Bytes stateItem, List<SiblingHash> siblingHashes, List<Hash> innerParentHashes) {}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
package com.swirlds.state;
3+
4+
import org.hiero.base.crypto.Hash;
5+
6+
/**
7+
* A record for storing sibling hashes.
8+
* @param isRight true if this is a right sibling, false if this is a left sibling.
9+
* @param hash the hash of the sibling.
10+
*/
11+
public record SiblingHash(boolean isRight, Hash hash) {}
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
package com.swirlds.state.merkle;
3+
4+
import static com.hedera.pbj.runtime.ProtoWriterTools.sizeOfTag;
5+
import static com.hedera.pbj.runtime.ProtoWriterTools.sizeOfVarInt32;
6+
import static com.hedera.pbj.runtime.ProtoWriterTools.writeDelimited;
7+
import static java.lang.StrictMath.toIntExact;
8+
import static java.util.Objects.requireNonNull;
9+
10+
import com.hedera.pbj.runtime.Codec;
11+
import com.hedera.pbj.runtime.FieldDefinition;
12+
import com.hedera.pbj.runtime.FieldType;
13+
import com.hedera.pbj.runtime.ParseException;
14+
import com.hedera.pbj.runtime.ProtoConstants;
15+
import com.hedera.pbj.runtime.ProtoParserTools;
16+
import com.hedera.pbj.runtime.ProtoWriterTools;
17+
import com.hedera.pbj.runtime.io.ReadableSequentialData;
18+
import com.hedera.pbj.runtime.io.WritableSequentialData;
19+
import com.hedera.pbj.runtime.io.buffer.Bytes;
20+
import edu.umd.cs.findbugs.annotations.NonNull;
21+
import java.io.IOException;
22+
23+
/**
24+
* A record to store state items.
25+
*
26+
* <p>This class is very similar to a class with the same name
27+
* generated from HAPI sources, com.hedera.hapi.platform.state.StateItem. The
28+
* generated class is not used in the current module to avoid a compile-time
29+
* dependency on HAPI.
30+
*
31+
* <p>At the bytes level, these two classes must be bit to bit identical. It means,
32+
* bytes for a state value record serialized using {@link StateItem.StateItemCodec} must be
33+
* identical to bytes created using HAPI StateItem and its codec. See StateValue definition in
34+
* virtual_map_state.proto for details.
35+
*
36+
* @param key key bytes
37+
* @param value state value object wrapping a domain value object
38+
*/
39+
public record StateItem(@NonNull Bytes key, @NonNull Bytes value) {
40+
public static final Codec<StateItem> CODEC = new StateItemCodec();
41+
42+
public StateItem {
43+
requireNonNull(key, "Null key");
44+
requireNonNull(value, "Null value");
45+
}
46+
47+
/**
48+
* Protobuf Codec for StateItem model object. Generated based on protobuf schema.
49+
*/
50+
public static final class StateItemCodec implements Codec<StateItem> {
51+
52+
static final FieldDefinition FIELD_KEY = new FieldDefinition("keyBytes", FieldType.BYTES, false, 2);
53+
static final FieldDefinition FIELD_VALUE = new FieldDefinition("keyBytes", FieldType.BYTES, false, 3);
54+
55+
/**
56+
* Parses a StateItem object from ProtoBuf bytes in a {@link ReadableSequentialData}. Throws if in strict mode ONLY.
57+
*
58+
* @param input The data input to parse data from, it is assumed to be in a state ready to read with position at start
59+
* of data to read and limit set at the end of data to read. The data inputs limit will be changed by this
60+
* method. If there are no bytes remaining in the data input,
61+
* then the method also returns immediately.
62+
* @param strictMode This parameter has no effect as {@code StateItem} has bytes only as fields
63+
* @param parseUnknownFields This parameter has no effect as {@code StateItem} has bytes only as fields
64+
* @param maxDepth This parameter has no effect as {@code StateItem} has no nested fields
65+
* @return Parsed StateItem model object
66+
* @throws ParseException If parsing fails
67+
*/
68+
public @NonNull StateItem parse(
69+
@NonNull final ReadableSequentialData input,
70+
final boolean strictMode,
71+
final boolean parseUnknownFields,
72+
final int maxDepth,
73+
final int maxSize)
74+
throws ParseException {
75+
76+
// read key tag
77+
final int firstFieldNum = extractFieldNum(input);
78+
Bytes keyBytes = null;
79+
Bytes valueBytes = null;
80+
if (firstFieldNum == FIELD_KEY.number()) {
81+
keyBytes = readBytes(input, FIELD_KEY);
82+
} else if (firstFieldNum == FIELD_VALUE.number()) {
83+
valueBytes = readBytes(input, FIELD_VALUE);
84+
} else {
85+
throw new ParseException("StateItem unknown field num: " + firstFieldNum);
86+
}
87+
88+
final int secondFieldNum = extractFieldNum(input);
89+
if (secondFieldNum == FIELD_KEY.number()) {
90+
keyBytes = readBytes(input, FIELD_KEY);
91+
} else if (secondFieldNum == FIELD_VALUE.number()) {
92+
valueBytes = readBytes(input, FIELD_VALUE);
93+
} else {
94+
throw new ParseException("StateItem unknown field num: " + secondFieldNum);
95+
}
96+
97+
assert keyBytes != null;
98+
assert valueBytes != null;
99+
100+
return new StateItem(keyBytes, valueBytes);
101+
}
102+
103+
private static int extractFieldNum(ReadableSequentialData input) throws ParseException {
104+
final int tag = input.readVarInt(false);
105+
final int wireType = tag & ProtoConstants.TAG_WIRE_TYPE_MASK;
106+
if (wireType != ProtoConstants.WIRE_TYPE_DELIMITED.ordinal()) {
107+
throw new ParseException("StateItem key wire type mismatch: expected="
108+
+ ProtoConstants.WIRE_TYPE_DELIMITED.ordinal() + ", actual=" + wireType);
109+
}
110+
return tag >> ProtoParserTools.TAG_FIELD_OFFSET;
111+
}
112+
113+
private static Bytes readBytes(ReadableSequentialData input, FieldDefinition fieldDefinition)
114+
throws ParseException {
115+
final ProtoConstants wireType = ProtoWriterTools.wireType(fieldDefinition);
116+
if (wireType != ProtoConstants.WIRE_TYPE_DELIMITED) {
117+
throw new ParseException("StateItem key wire type mismatch: expected="
118+
+ ProtoConstants.WIRE_TYPE_DELIMITED.ordinal() + ", actual=" + wireType);
119+
}
120+
121+
Bytes keyBytes;
122+
final int keySize = input.readVarInt(false);
123+
if (keySize == 0) {
124+
keyBytes = Bytes.EMPTY;
125+
} else {
126+
keyBytes = input.readBytes(keySize);
127+
}
128+
return keyBytes;
129+
}
130+
131+
/**
132+
* Write out a StateItem model to output stream in protobuf format.
133+
*
134+
* @param data The input model data to write
135+
* @param out The output stream to write to
136+
* @throws IOException If there is a problem writing
137+
*/
138+
public void write(@NonNull StateItem data, @NonNull final WritableSequentialData out) throws IOException {
139+
writeDelimited(out, FIELD_KEY, toIntExact(data.key.length()), v -> v.writeBytes(data.key));
140+
writeDelimited(out, FIELD_VALUE, toIntExact(data.value.length()), v -> v.writeBytes(data.value));
141+
}
142+
143+
/**
144+
* {@inheritDoc}
145+
*/
146+
public int measure(@NonNull final ReadableSequentialData input) throws ParseException {
147+
final var start = input.position();
148+
parse(input);
149+
final var end = input.position();
150+
return (int) (end - start);
151+
}
152+
153+
/**
154+
* {@inheritDoc}
155+
*/
156+
@Override
157+
public int measureRecord(StateItem item) {
158+
int size = 0;
159+
160+
size += sizeOfTag(FIELD_KEY);
161+
// key size counter size
162+
size += sizeOfVarInt32(toIntExact(item.key.length()));
163+
// Key size
164+
size += toIntExact(item.key.length());
165+
166+
size += sizeOfTag(FIELD_VALUE);
167+
// value size counter size
168+
size += sizeOfVarInt32(toIntExact(item.value().length()));
169+
// value size
170+
size += toIntExact(item.value.length());
171+
172+
return size;
173+
}
174+
175+
/**
176+
* {@inheritDoc}
177+
*/
178+
@Override
179+
public boolean fastEquals(@NonNull StateItem item, @NonNull ReadableSequentialData input)
180+
throws ParseException {
181+
return item.equals(parse(input));
182+
}
183+
184+
/**
185+
* {@inheritDoc}
186+
*/
187+
@Override
188+
public StateItem getDefaultInstance() {
189+
return new StateItem(Bytes.EMPTY, Bytes.EMPTY);
190+
}
191+
}
192+
}

platform-sdk/swirlds-state-impl/src/main/java/com/swirlds/state/merkle/VirtualMapState.java

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,14 @@
55
import static com.swirlds.state.StateChangeListener.StateType.QUEUE;
66
import static com.swirlds.state.StateChangeListener.StateType.SINGLETON;
77
import static com.swirlds.state.lifecycle.StateMetadata.computeLabel;
8+
import static com.swirlds.state.merkle.StateItem.CODEC;
89
import static com.swirlds.state.merkle.disk.OnDiskQueueHelper.QUEUE_STATE_VALUE_CODEC;
910
import static com.swirlds.virtualmap.internal.Path.INVALID_PATH;
11+
import static com.swirlds.virtualmap.internal.Path.getParentPath;
12+
import static com.swirlds.virtualmap.internal.Path.getSiblingPath;
13+
import static com.swirlds.virtualmap.internal.Path.isRight;
1014
import static java.util.Objects.requireNonNull;
15+
import static org.hiero.base.crypto.Cryptography.NULL_HASH;
1116

1217
import com.hedera.pbj.runtime.Codec;
1318
import com.hedera.pbj.runtime.io.buffer.Bytes;
@@ -21,6 +26,8 @@
2126
import com.swirlds.merkledb.config.MerkleDbConfig;
2227
import com.swirlds.metrics.api.Metrics;
2328
import com.swirlds.state.MerkleNodeState;
29+
import com.swirlds.state.MerkleProof;
30+
import com.swirlds.state.SiblingHash;
2431
import com.swirlds.state.State;
2532
import com.swirlds.state.StateChangeListener;
2633
import com.swirlds.state.lifecycle.StateMetadata;
@@ -169,6 +176,7 @@ protected VirtualMapState(@NonNull final VirtualMapState<T> from) {
169176

170177
/**
171178
* Creates a copy of the instance.
179+
*
172180
* @return a copy of the instance
173181
*/
174182
protected abstract T copyingConstructor();
@@ -335,7 +343,7 @@ public void unregisterService(@NonNull final String serviceName) {
335343
* Removes the node and metadata from the state merkle tree.
336344
*
337345
* @param serviceName The service name. Cannot be null.
338-
* @param stateId The state ID
346+
* @param stateId The state ID
339347
*/
340348
public void removeServiceState(@NonNull final String serviceName, final int stateId) {
341349
virtualMap.throwIfImmutable();
@@ -431,6 +439,7 @@ public boolean isDestroyed() {
431439
/**
432440
* Release a reservation on a Virtual Map.
433441
* For more detailed docs, see {@link Reservable#release()}.
442+
*
434443
* @return true if this call to release() caused the Virtual Map to become destroyed
435444
*/
436445
public boolean release() {
@@ -628,7 +637,7 @@ public final class MerkleWritableStates extends MerkleStates implements Writable
628637
/**
629638
* Create a new instance
630639
*
631-
* @param serviceName cannot be null
640+
* @param serviceName cannot be null
632641
* @param stateMetadata cannot be null
633642
*/
634643
MerkleWritableStates(
@@ -855,6 +864,43 @@ public Hash getHashForPath(long path) {
855864
return virtualMap.getRecords().findHash(path);
856865
}
857866

867+
@Override
868+
public MerkleProof getMerkleProof(final long path) {
869+
if (!isHashed()) {
870+
throw new IllegalStateException("Cannot get Merkle proof for unhashed virtual map");
871+
}
872+
873+
VirtualLeafBytes<?> leafRecord = virtualMap.getRecords().findLeafRecord(path);
874+
if (leafRecord == null) {
875+
return null;
876+
}
877+
878+
final List<SiblingHash> siblingHashes = new ArrayList<>();
879+
final List<Hash> innerParentHashes = new ArrayList<>();
880+
881+
long currentPath = path;
882+
while (currentPath > 0) {
883+
final long siblingPath = getSiblingPath(currentPath);
884+
final boolean isSiblingRight = isRight(siblingPath);
885+
final Hash hashForPath = getHashForPath(siblingPath);
886+
final Hash normalizedHashForPath = hashForPath == null ? NULL_HASH : hashForPath;
887+
888+
siblingHashes.add(new SiblingHash(isSiblingRight, normalizedHashForPath));
889+
890+
innerParentHashes.add(getHashForPath(currentPath));
891+
892+
currentPath = getParentPath(currentPath);
893+
}
894+
895+
assert virtualMap.getHash() != null;
896+
897+
// add root hash
898+
innerParentHashes.add(virtualMap.getHash());
899+
900+
StateItem stateItem = new StateItem(leafRecord.keyBytes(), leafRecord.valueBytes());
901+
return new MerkleProof(CODEC.toBytes(stateItem), siblingHashes, innerParentHashes);
902+
}
903+
858904
/**
859905
* {@inheritDoc}
860906
*/

0 commit comments

Comments
 (0)