Skip to content

Commit 666414a

Browse files
committed
Make Statement.get() async
1 parent abceb90 commit 666414a

File tree

7 files changed

+93
-54
lines changed

7 files changed

+93
-54
lines changed

index.d.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ export declare function databasePrepareSync(db: Database, sql: string): Statemen
2727
export declare function databaseSyncSync(db: Database): SyncResult
2828
/** Executes SQL in blocking mode. */
2929
export declare function databaseExecSync(db: Database, sql: string): void
30+
/** Gets first row from statement in blocking mode. */
31+
export declare function statementGetSync(stmt: Statement, params?: unknown | undefined | null): unknown
3032
/** Runs a statement in blocking mode. */
3133
export declare function statementRunSync(stmt: Statement, params?: unknown | undefined | null): RunResult
3234
export declare function statementIterateSync(stmt: Statement, params?: unknown | undefined | null): RowsIterator
@@ -159,7 +161,7 @@ export declare class Statement {
159161
* * `env` - The environment.
160162
* * `params` - The parameters to bind to the statement.
161163
*/
162-
get(params?: unknown | undefined | null): unknown
164+
get(params?: unknown | undefined | null): object
163165
/**
164166
* Create an iterator over the rows of a statement.
165167
*

index.js

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,13 +310,14 @@ if (!nativeBinding) {
310310
throw new Error(`Failed to load native binding`)
311311
}
312312

313-
const { Database, databasePrepareSync, databaseSyncSync, databaseExecSync, Statement, statementRunSync, statementIterateSync, RowsIterator, iteratorNextSync, Record } = nativeBinding
313+
const { Database, databasePrepareSync, databaseSyncSync, databaseExecSync, Statement, statementGetSync, statementRunSync, statementIterateSync, RowsIterator, iteratorNextSync, Record } = nativeBinding
314314

315315
module.exports.Database = Database
316316
module.exports.databasePrepareSync = databasePrepareSync
317317
module.exports.databaseSyncSync = databaseSyncSync
318318
module.exports.databaseExecSync = databaseExecSync
319319
module.exports.Statement = Statement
320+
module.exports.statementGetSync = statementGetSync
320321
module.exports.statementRunSync = statementRunSync
321322
module.exports.statementIterateSync = statementIterateSync
322323
module.exports.RowsIterator = RowsIterator

integration-tests/tests/async.test.js

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,15 @@ test.serial("Statement.run() [positional]", async (t) => {
5353

5454
// Verify that the data is inserted
5555
const stmt2 = await db.prepare("SELECT * FROM users WHERE id = 3");
56-
t.is(stmt2.get().name, "Carol");
57-
t.is(stmt2.get().email, "carol@example.net");
56+
t.is((await stmt2.get()).name, "Carol");
57+
t.is((await stmt2.get()).email, "carol@example.net");
5858
});
5959

6060
test.serial("Statement.get() returns no rows", async (t) => {
6161
const db = t.context.db;
6262

6363
const stmt = await db.prepare("SELECT * FROM users WHERE id = 0");
64-
t.is(stmt.get(), undefined);
64+
t.is((await stmt.get()), undefined);
6565
});
6666

6767
test.serial("Statement.get() [no parameters]", async (t) => {
@@ -70,7 +70,7 @@ test.serial("Statement.get() [no parameters]", async (t) => {
7070
var stmt = 0;
7171

7272
stmt = await db.prepare("SELECT * FROM users");
73-
t.is(stmt.get().name, "Alice");
73+
t.is((await stmt.get()).name, "Alice");
7474
t.deepEqual(await stmt.raw().get(), [1, 'Alice', 'alice@example.org']);
7575
});
7676

@@ -80,15 +80,15 @@ test.serial("Statement.get() [positional]", async (t) => {
8080
var stmt = 0;
8181

8282
stmt = await db.prepare("SELECT * FROM users WHERE id = ?");
83-
t.is(stmt.get(0), undefined);
84-
t.is(stmt.get([0]), undefined);
85-
t.is(stmt.get(1).name, "Alice");
86-
t.is(stmt.get(2).name, "Bob");
83+
t.is((await stmt.get(0)), undefined);
84+
t.is((await stmt.get([0])), undefined);
85+
t.is((await stmt.get(1)).name, "Alice");
86+
t.is((await stmt.get(2)).name, "Bob");
8787

8888
stmt = await db.prepare("SELECT * FROM users WHERE id = ?1");
89-
t.is(stmt.get({1: 0}), undefined);
90-
t.is(stmt.get({1: 1}).name, "Alice");
91-
t.is(stmt.get({1: 2}).name, "Bob");
89+
t.is((await stmt.get({1: 0})), undefined);
90+
t.is((await stmt.get({1: 1})).name, "Alice");
91+
t.is((await stmt.get({1: 2})).name, "Bob");
9292
});
9393

9494
test.serial("Statement.get() [named]", async (t) => {
@@ -97,27 +97,27 @@ test.serial("Statement.get() [named]", async (t) => {
9797
var stmt = undefined;
9898

9999
stmt = await db.prepare("SELECT * FROM users WHERE id = :id");
100-
t.is(stmt.get({ id: 0 }), undefined);
101-
t.is(stmt.get({ id: 1 }).name, "Alice");
102-
t.is(stmt.get({ id: 2 }).name, "Bob");
100+
t.is((await stmt.get({ id: 0 })), undefined);
101+
t.is((await stmt.get({ id: 1 })).name, "Alice");
102+
t.is((await stmt.get({ id: 2 })).name, "Bob");
103103

104104
stmt = await db.prepare("SELECT * FROM users WHERE id = @id");
105-
t.is(stmt.get({ id: 0 }), undefined);
106-
t.is(stmt.get({ id: 1 }).name, "Alice");
107-
t.is(stmt.get({ id: 2 }).name, "Bob");
105+
t.is((await stmt.get({ id: 0 })), undefined);
106+
t.is((await stmt.get({ id: 1 })).name, "Alice");
107+
t.is((await stmt.get({ id: 2 })).name, "Bob");
108108

109109
stmt = await db.prepare("SELECT * FROM users WHERE id = $id");
110-
t.is(stmt.get({ id: 0 }), undefined);
111-
t.is(stmt.get({ id: 1 }).name, "Alice");
112-
t.is(stmt.get({ id: 2 }).name, "Bob");
110+
t.is((await stmt.get({ id: 0 })), undefined);
111+
t.is((await stmt.get({ id: 1 })).name, "Alice");
112+
t.is((await stmt.get({ id: 2 })).name, "Bob");
113113
});
114114

115115

116116
test.serial("Statement.get() [raw]", async (t) => {
117117
const db = t.context.db;
118118

119119
const stmt = await db.prepare("SELECT * FROM users WHERE id = ?");
120-
t.deepEqual(stmt.raw().get(1), [1, "Alice", "alice@example.org"]);
120+
t.deepEqual(await stmt.raw().get(1), [1, "Alice", "alice@example.org"]);
121121
});
122122

123123
test.serial("Statement.iterate() [empty]", async (t) => {
@@ -267,9 +267,9 @@ test.serial("Database.transaction()", async (t) => {
267267
t.is(db.inTransaction, false);
268268

269269
const stmt = await db.prepare("SELECT * FROM users WHERE id = ?");
270-
t.is(stmt.get(3).name, "Joey");
271-
t.is(stmt.get(4).name, "Sally");
272-
t.is(stmt.get(5).name, "Junior");
270+
t.is((await stmt.get(3)).name, "Joey");
271+
t.is((await stmt.get(4)).name, "Sally");
272+
t.is((await stmt.get(5)).name, "Junior");
273273
});
274274

275275
test.serial("Database.transaction().immediate()", async (t) => {

integration-tests/tests/concurrency.test.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ test("Concurrent reads", async (t) => {
3232

3333
const promises = [];
3434
for (let i = 0; i < 100; i++) {
35-
promises.push(stmt.get(t.context.aliceId));
36-
promises.push(stmt.get(t.context.bobId));
35+
promises.push(await stmt.get(t.context.aliceId));
36+
promises.push(await stmt.get(t.context.bobId));
3737
}
3838

3939
const results = await Promise.all(promises);

promise.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,9 @@ class Statement {
270270
*
271271
* @param bindParameters - The bind parameters for executing the statement.
272272
*/
273-
get(...bindParameters) {
273+
async get(...bindParameters) {
274274
try {
275-
return this.stmt.get(...bindParameters);
275+
return await this.stmt.get(...bindParameters);
276276
} catch (err) {
277277
throw convertError(err);
278278
}

src/lib.rs

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ impl Statement {
660660
let start = std::time::Instant::now();
661661
let stmt = self.stmt.clone();
662662
let conn = self.conn.clone();
663-
663+
664664
let future = async move {
665665
stmt.run(params).await.map_err(Error::from)?;
666666
let changes = if conn.total_changes() == total_changes_before {
@@ -676,10 +676,8 @@ impl Statement {
676676
lastInsertRowid: last_insert_row_id,
677677
})
678678
};
679-
680-
env.execute_tokio_future(future, move |&mut _env, result| {
681-
Ok(result)
682-
})
679+
680+
env.execute_tokio_future(future, move |&mut _env, result| Ok(result))
683681
}
684682

685683
/// Executes a SQL statement and returns the first row.
@@ -689,35 +687,35 @@ impl Statement {
689687
/// * `env` - The environment.
690688
/// * `params` - The parameters to bind to the statement.
691689
#[napi]
692-
pub fn get(&self, env: Env, params: Option<napi::JsUnknown>) -> Result<napi::JsUnknown> {
693-
let rt = runtime()?;
694-
690+
pub fn get(&self, env: Env, params: Option<napi::JsUnknown>) -> Result<napi::JsObject> {
695691
let safe_ints = self.mode.safe_ints.load(Ordering::SeqCst);
696692
let raw = self.mode.raw.load(Ordering::SeqCst);
697693
let pluck = self.mode.pluck.load(Ordering::SeqCst);
698694
let timed = self.mode.timing.load(Ordering::SeqCst);
699695

696+
let params = map_params(&self.stmt, params)?;
697+
let stmt = self.stmt.clone();
698+
let column_names = self.column_names.clone();
699+
700700
let start = if timed {
701701
Some(std::time::Instant::now())
702702
} else {
703703
None
704704
};
705-
rt.block_on(async move {
706-
let params = map_params(&self.stmt, params)?;
707-
let mut rows = self.stmt.query(params).await.map_err(Error::from)?;
705+
706+
let stmt_fut = stmt.clone();
707+
let future = async move {
708+
let mut rows = stmt_fut.query(params).await.map_err(Error::from)?;
708709
let row = rows.next().await.map_err(Error::from)?;
709710
let duration: Option<f64> = start.map(|start| start.elapsed().as_secs_f64());
710-
let result = Self::get_internal(
711-
&env,
712-
&row,
713-
&self.column_names,
714-
safe_ints,
715-
raw,
716-
pluck,
717-
duration,
718-
);
719-
self.stmt.reset();
720-
result
711+
Ok((row, duration))
712+
};
713+
714+
env.execute_tokio_future(future, move |&mut env, (row, duration)| {
715+
let result =
716+
Self::get_internal(&env, &row, &column_names, safe_ints, raw, pluck, duration);
717+
stmt.reset();
718+
Ok(result)
721719
})
722720
}
723721

@@ -871,6 +869,44 @@ impl Statement {
871869
}
872870
}
873871

872+
/// Gets first row from statement in blocking mode.
873+
#[napi]
874+
pub fn statement_get_sync(
875+
stmt: &Statement,
876+
env: Env,
877+
params: Option<napi::JsUnknown>,
878+
) -> Result<napi::JsUnknown> {
879+
let safe_ints = stmt.mode.safe_ints.load(Ordering::SeqCst);
880+
let raw = stmt.mode.raw.load(Ordering::SeqCst);
881+
let pluck = stmt.mode.pluck.load(Ordering::SeqCst);
882+
let timed = stmt.mode.timing.load(Ordering::SeqCst);
883+
884+
let start = if timed {
885+
Some(std::time::Instant::now())
886+
} else {
887+
None
888+
};
889+
890+
let rt = runtime()?;
891+
rt.block_on(async move {
892+
let params = map_params(&stmt.stmt, params)?;
893+
let mut rows = stmt.stmt.query(params).await.map_err(Error::from)?;
894+
let row = rows.next().await.map_err(Error::from)?;
895+
let duration: Option<f64> = start.map(|start| start.elapsed().as_secs_f64());
896+
let result = Statement::get_internal(
897+
&env,
898+
&row,
899+
&stmt.column_names,
900+
safe_ints,
901+
raw,
902+
pluck,
903+
duration,
904+
);
905+
stmt.stmt.reset();
906+
result
907+
})
908+
}
909+
874910
/// Runs a statement in blocking mode.
875911
#[napi]
876912
pub fn statement_run_sync(stmt: &Statement, params: Option<napi::JsUnknown>) -> Result<RunResult> {

wrapper.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"use strict";
22

3-
const { Database: NativeDb, databasePrepareSync, databaseSyncSync, databaseExecSync, statementRunSync, statementIterateSync, iteratorNextSync } = require("./index.js");
3+
const { Database: NativeDb, databasePrepareSync, databaseSyncSync, databaseExecSync, statementRunSync, statementGetSync, statementIterateSync, iteratorNextSync } = require("./index.js");
44
const SqliteError = require("./sqlite-error.js");
55
const Authorization = require("./auth");
66

@@ -276,7 +276,7 @@ class Statement {
276276
*/
277277
get(...bindParameters) {
278278
try {
279-
return this.stmt.get(...bindParameters);
279+
return statementGetSync(this.stmt, ...bindParameters);
280280
} catch (err) {
281281
throw convertError(err);
282282
}

0 commit comments

Comments
 (0)