Skip to content

Commit df02a88

Browse files
authored
Merge pull request #1 from TheWaWaR/fix-replay-to-pass-proof-1
Fix replay to pass proof
2 parents 77c86b5 + 2b8cead commit df02a88

File tree

14 files changed

+961
-601
lines changed

14 files changed

+961
-601
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "sparse-merkle-tree"
3-
version = "0.4.0-rc1"
3+
version = "0.5.0-rc1"
44
authors = ["jjy <jjyruby@gmail.com>"]
55
edition = "2018"
66
license = "MIT"

c/ckb_smt.h

Lines changed: 375 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,375 @@
1+
//
2+
// C implementation of SMT verification:
3+
// https://github.com/nervosnetwork/sparse-merkle-tree
4+
//
5+
// origin from:
6+
// https://github.com/nervosnetwork/godwoken/blob/6c9b92b9b06068a8678864b35a3272545ed7909e/c/gw_smt.h#L1
7+
#ifndef _CKB_SPARSE_MERKLE_TREE_H_
8+
#define _CKB_SPARSE_MERKLE_TREE_H_
9+
10+
// users can define a new stack size if needed
11+
#ifndef SMT_STACK_SIZE
12+
#define SMT_STACK_SIZE 257
13+
#endif
14+
15+
#define SMT_KEY_BYTES 32
16+
#define SMT_VALUE_BYTES 32
17+
18+
const uint8_t SMT_ZERO[SMT_VALUE_BYTES] = {0};
19+
20+
enum SMTErrorCode {
21+
// SMT
22+
ERROR_INSUFFICIENT_CAPACITY = 80,
23+
ERROR_NOT_FOUND,
24+
ERROR_INVALID_STACK,
25+
ERROR_INVALID_SIBLING,
26+
ERROR_INVALID_PROOF
27+
};
28+
29+
/* Key Value Pairs */
30+
typedef struct {
31+
uint8_t key[SMT_KEY_BYTES];
32+
uint8_t value[SMT_VALUE_BYTES];
33+
uint32_t order;
34+
} smt_pair_t;
35+
36+
typedef struct {
37+
smt_pair_t *pairs;
38+
uint32_t len;
39+
uint32_t capacity;
40+
} smt_state_t;
41+
42+
void smt_state_init(smt_state_t *state, smt_pair_t *buffer, uint32_t capacity) {
43+
state->pairs = buffer;
44+
state->len = 0;
45+
state->capacity = capacity;
46+
}
47+
48+
int smt_state_insert(smt_state_t *state, const uint8_t *key,
49+
const uint8_t *value) {
50+
if (state->len < state->capacity) {
51+
/* shortcut, append at end */
52+
memcpy(state->pairs[state->len].key, key, SMT_KEY_BYTES);
53+
memcpy(state->pairs[state->len].value, value, SMT_KEY_BYTES);
54+
state->len++;
55+
return 0;
56+
}
57+
58+
/* Find a matched key and overwritten it */
59+
int32_t i = state->len - 1;
60+
for (; i >= 0; i--) {
61+
if (memcmp(key, state->pairs[i].key, SMT_KEY_BYTES) == 0) {
62+
break;
63+
}
64+
}
65+
66+
if (i < 0) {
67+
return ERROR_INSUFFICIENT_CAPACITY;
68+
}
69+
70+
memcpy(state->pairs[i].value, value, SMT_VALUE_BYTES);
71+
return 0;
72+
}
73+
74+
int smt_state_fetch(smt_state_t *state, const uint8_t *key, uint8_t *value) {
75+
int32_t i = state->len - 1;
76+
for (; i >= 0; i--) {
77+
if (memcmp(key, state->pairs[i].key, SMT_KEY_BYTES) == 0) {
78+
memcpy(value, state->pairs[i].value, SMT_VALUE_BYTES);
79+
return 0;
80+
}
81+
}
82+
return ERROR_NOT_FOUND;
83+
}
84+
85+
int _smt_pair_cmp(const void *a, const void *b) {
86+
const smt_pair_t *pa = (const smt_pair_t *)a;
87+
const smt_pair_t *pb = (const smt_pair_t *)b;
88+
89+
for (int i = SMT_KEY_BYTES - 1; i >= 0; i--) {
90+
int cmp_result = pa->key[i] - pb->key[i];
91+
if (cmp_result != 0) {
92+
return cmp_result;
93+
}
94+
}
95+
return pa->order - pb->order;
96+
}
97+
98+
void smt_state_normalize(smt_state_t *state) {
99+
for (uint32_t i = 0; i < state->len; i++) {
100+
state->pairs[i].order = state->len - i;
101+
}
102+
qsort(state->pairs, state->len, sizeof(smt_pair_t), _smt_pair_cmp);
103+
/* Remove duplicate ones */
104+
int32_t sorted = 0, next = 0;
105+
while (next < (int32_t)state->len) {
106+
int32_t item_index = next++;
107+
while (next < (int32_t)state->len &&
108+
memcmp(state->pairs[item_index].key, state->pairs[next].key,
109+
SMT_KEY_BYTES) == 0) {
110+
next++;
111+
}
112+
if (item_index != sorted) {
113+
memcpy(state->pairs[sorted].key, state->pairs[item_index].key,
114+
SMT_KEY_BYTES);
115+
memcpy(state->pairs[sorted].value, state->pairs[item_index].value,
116+
SMT_VALUE_BYTES);
117+
}
118+
sorted++;
119+
}
120+
state->len = sorted;
121+
}
122+
123+
/* SMT */
124+
125+
int _smt_get_bit(const uint8_t *data, int offset) {
126+
int byte_pos = offset / 8;
127+
int bit_pos = offset % 8;
128+
return ((data[byte_pos] >> bit_pos) & 1) != 0;
129+
}
130+
131+
void _smt_set_bit(uint8_t *data, int offset) {
132+
int byte_pos = offset / 8;
133+
int bit_pos = offset % 8;
134+
data[byte_pos] |= 1 << bit_pos;
135+
}
136+
137+
void _smt_clear_bit(uint8_t *data, int offset) {
138+
int byte_pos = offset / 8;
139+
int bit_pos = offset % 8;
140+
data[byte_pos] &= (uint8_t)(~(1 << bit_pos));
141+
}
142+
143+
void _smt_copy_bits(uint8_t *source, int first_kept_bit) {
144+
int first_byte = first_kept_bit / 8;
145+
for (int i = 0; i < first_byte; i++) {
146+
source[i] = 0;
147+
}
148+
for (int i = first_byte * 8; i < first_kept_bit; i++) {
149+
_smt_clear_bit(source, i);
150+
}
151+
}
152+
153+
void _smt_parent_path(uint8_t *key, uint8_t height) {
154+
if (height == 255) {
155+
memset(key, 0, 32);
156+
} else {
157+
_smt_copy_bits(key, height + 1);
158+
}
159+
}
160+
161+
int _smt_zero_value(const uint8_t *value) {
162+
for (int i = 0; i < 32; i++) {
163+
if (value[i] != 0) {
164+
return 0;
165+
}
166+
}
167+
return 1;
168+
}
169+
170+
/* Notice that output might collide with one of lhs, or rhs */
171+
void _smt_merge(uint8_t height, const uint8_t *node_key, const uint8_t *lhs,
172+
const uint8_t *rhs, uint8_t *output) {
173+
if (_smt_zero_value(lhs) && _smt_zero_value(rhs)) {
174+
memcpy(output, SMT_ZERO, SMT_VALUE_BYTES);
175+
} else {
176+
blake2b_state blake2b_ctx;
177+
blake2b_init(&blake2b_ctx, 32);
178+
179+
blake2b_update(&blake2b_ctx, &height, 1);
180+
blake2b_update(&blake2b_ctx, node_key, 32);
181+
blake2b_update(&blake2b_ctx, lhs, 32);
182+
blake2b_update(&blake2b_ctx, rhs, 32);
183+
184+
blake2b_final(&blake2b_ctx, output, 32);
185+
}
186+
}
187+
188+
/*
189+
* Theoretically, a stack size of x should be able to process as many as
190+
* 2 ** (x - 1) updates. In this case with a stack size of 32, we can deal
191+
* with 2 ** 31 == 2147483648 updates, which is more than enough.
192+
*/
193+
int smt_calculate_root(uint8_t *buffer, const smt_state_t *pairs,
194+
const uint8_t *proof, uint32_t proof_length) {
195+
uint8_t stack_keys[SMT_STACK_SIZE][SMT_KEY_BYTES];
196+
uint8_t stack_values[SMT_STACK_SIZE][SMT_VALUE_BYTES];
197+
uint16_t stack_heights[SMT_STACK_SIZE] = {0};
198+
199+
uint32_t proof_index = 0;
200+
uint32_t leave_index = 0;
201+
uint32_t stack_top = 0;
202+
203+
while (proof_index < proof_length) {
204+
switch (proof[proof_index++]) {
205+
case 0x4C: {
206+
if (stack_top >= SMT_STACK_SIZE) {
207+
return ERROR_INVALID_STACK;
208+
}
209+
if (leave_index >= pairs->len) {
210+
return ERROR_INVALID_PROOF;
211+
}
212+
memcpy(stack_keys[stack_top], pairs->pairs[leave_index].key,
213+
SMT_KEY_BYTES);
214+
memcpy(stack_values[stack_top], pairs->pairs[leave_index].value,
215+
SMT_VALUE_BYTES);
216+
stack_heights[stack_top] = 0;
217+
stack_top++;
218+
leave_index++;
219+
} break;
220+
case 0x50: {
221+
if (stack_top == 0) {
222+
return ERROR_INVALID_STACK;
223+
}
224+
if (proof_index + 32 > proof_length) {
225+
return ERROR_INVALID_PROOF;
226+
}
227+
const uint8_t *sibling_node = &proof[proof_index];
228+
proof_index += 32;
229+
uint8_t *key = stack_keys[stack_top - 1];
230+
uint8_t *value = stack_values[stack_top - 1];
231+
uint16_t height = stack_heights[stack_top - 1];
232+
uint16_t *height_ptr = &stack_heights[stack_top - 1];
233+
if (height > 255) {
234+
return ERROR_INVALID_PROOF;
235+
}
236+
uint8_t parent_key[SMT_KEY_BYTES];
237+
memcpy(parent_key, key, SMT_KEY_BYTES);
238+
_smt_parent_path(parent_key, height);
239+
240+
// push value
241+
if (_smt_get_bit(key, height)) {
242+
_smt_merge((uint8_t)height, parent_key, sibling_node, value, value);
243+
} else {
244+
_smt_merge((uint8_t)height, parent_key, value, sibling_node, value);
245+
}
246+
// push key
247+
_smt_parent_path(key, height);
248+
// push height
249+
*height_ptr = height + 1;
250+
} break;
251+
case 0x48: {
252+
if (stack_top < 2) {
253+
return ERROR_INVALID_STACK;
254+
}
255+
if (proof_index >= proof_length) {
256+
return ERROR_INVALID_PROOF;
257+
}
258+
uint16_t *height_a_ptr = &stack_heights[stack_top - 2];
259+
260+
uint16_t height_a = stack_heights[stack_top - 2];
261+
uint8_t *key_a = stack_keys[stack_top - 2];
262+
uint8_t *value_a = stack_values[stack_top - 2];
263+
264+
uint16_t height_b = stack_heights[stack_top - 1];
265+
uint8_t *key_b = stack_keys[stack_top - 1];
266+
uint8_t *value_b = stack_values[stack_top - 1];
267+
stack_top -= 2;
268+
if (height_a != height_b) {
269+
return ERROR_INVALID_PROOF;
270+
}
271+
if (height_a > 255) {
272+
return ERROR_INVALID_PROOF;
273+
}
274+
uint8_t parent_key[SMT_KEY_BYTES];
275+
memcpy(parent_key, key_a, SMT_KEY_BYTES);
276+
_smt_parent_path(parent_key, (uint8_t)height_a);
277+
278+
// 2 keys should have same parent keys
279+
_smt_parent_path(key_b, (uint8_t)height_b);
280+
if (memcmp(parent_key, key_b, SMT_KEY_BYTES) != 0) {
281+
return ERROR_INVALID_PROOF;
282+
}
283+
// push value
284+
if (_smt_get_bit(key_a, height_a)) {
285+
_smt_merge(height_a, parent_key, value_b, value_a, value_a);
286+
} else {
287+
_smt_merge(height_a, parent_key, value_a, value_b, value_a);
288+
}
289+
// push key
290+
memcpy(key_a, parent_key, SMT_KEY_BYTES);
291+
// push height
292+
*height_a_ptr = height_a + 1;
293+
stack_top++;
294+
} break;
295+
case 0x4F: {
296+
if (stack_top < 1) {
297+
return ERROR_INVALID_STACK;
298+
}
299+
if (proof_index >= proof_length) {
300+
return ERROR_INVALID_PROOF;
301+
}
302+
uint16_t n = proof[proof_index];
303+
proof_index++;
304+
uint16_t zero_count = 0;
305+
if (n == 0) {
306+
zero_count = 256;
307+
} else {
308+
zero_count = n;
309+
}
310+
uint16_t *base_height_ptr = &stack_heights[stack_top - 1];
311+
uint16_t base_height = stack_heights[stack_top - 1];
312+
uint8_t *key = stack_keys[stack_top - 1];
313+
uint8_t *value = stack_values[stack_top - 1];
314+
if (base_height > 255) {
315+
return ERROR_INVALID_PROOF;
316+
}
317+
uint8_t parent_key[SMT_KEY_BYTES];
318+
memcpy(parent_key, key, SMT_KEY_BYTES);
319+
uint16_t height_u16 = base_height;
320+
for (uint16_t idx = 0; idx < zero_count; idx++) {
321+
height_u16 = base_height + idx;
322+
if (height_u16 > 255) {
323+
return ERROR_INVALID_PROOF;
324+
}
325+
// the following code can be omitted:
326+
// memcpy(parent_key, key, SMT_KEY_BYTES);
327+
// A key's parent's parent can be calculated from parent.
328+
// it's not needed to do it from scratch.
329+
// Make sure height_u16 is in increase order
330+
_smt_parent_path(parent_key, (uint8_t)height_u16);
331+
// push value
332+
if (_smt_get_bit(key, (uint8_t)height_u16)) {
333+
_smt_merge((uint8_t)height_u16, parent_key, SMT_ZERO, value, value);
334+
} else {
335+
_smt_merge((uint8_t)height_u16, parent_key, value, SMT_ZERO, value);
336+
}
337+
}
338+
// push key
339+
memcpy(key, parent_key, SMT_KEY_BYTES);
340+
// push height
341+
*base_height_ptr = height_u16 + 1;
342+
} break;
343+
default:
344+
return ERROR_INVALID_PROOF;
345+
}
346+
}
347+
if (stack_top != 1) {
348+
return ERROR_INVALID_STACK;
349+
}
350+
if (stack_heights[0] != 256) {
351+
return ERROR_INVALID_PROOF;
352+
}
353+
/* All leaves must be used */
354+
if (leave_index != pairs->len) {
355+
return ERROR_INVALID_PROOF;
356+
}
357+
358+
memcpy(buffer, stack_values[0], 32);
359+
return 0;
360+
}
361+
362+
int smt_verify(const uint8_t *hash, const smt_state_t *state,
363+
const uint8_t *proof, uint32_t proof_length) {
364+
uint8_t buffer[32];
365+
int ret = smt_calculate_root(buffer, state, proof, proof_length);
366+
if (ret != 0) {
367+
return ret;
368+
}
369+
if (memcmp(buffer, hash, 32) != 0) {
370+
return ERROR_INVALID_PROOF;
371+
}
372+
return 0;
373+
}
374+
375+
#endif

proptest-regressions/tests/tree.txt

Lines changed: 11 additions & 0 deletions
Large diffs are not rendered by default.

src/blake2b.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ impl Hasher for Blake2bHasher {
2121
fn write_h256(&mut self, h: &H256) {
2222
self.0.update(h.as_slice());
2323
}
24+
fn write_byte(&mut self, b: u8) {
25+
self.0.update(&[b][..]);
26+
}
2427
fn finish(self) -> H256 {
2528
let mut hash = [0u8; 32];
2629
self.0.finalize(&mut hash);

0 commit comments

Comments
 (0)