inlined libsql dialect, rewrote d1 to use generic sqlite

This commit is contained in:
dswbx
2025-06-25 09:35:47 +02:00
parent b2086c4da7
commit 57ae2f333c
16 changed files with 491 additions and 120 deletions

View File

@@ -1,42 +1,76 @@
/// <reference types="@cloudflare/workers-types" />
import { SqliteConnection } from "bknd/data";
import type { ConnQuery, ConnQueryResults } from "data/connection/Connection";
import { D1Dialect } from "kysely-d1";
import {
genericSqlite,
type GenericSqliteConnection,
} from "data/connection/sqlite/GenericSqliteConnection";
import type { QueryResult } from "kysely";
export type D1SqliteConnection = GenericSqliteConnection<D1Database>;
export type D1ConnectionConfig<DB extends D1Database | D1DatabaseSession = D1Database> = {
binding: DB;
};
export class D1Connection<
DB extends D1Database | D1DatabaseSession = D1Database,
> extends SqliteConnection<DB> {
override name = "sqlite-d1";
export function d1Sqlite(config: D1ConnectionConfig<D1Database>) {
const db = config.binding;
protected override readonly supported = {
batching: true,
softscans: false,
};
return genericSqlite(
"d1-sqlite",
db,
(utils) => {
const getStmt = (sql: string, parameters?: any[] | readonly any[]) =>
db.prepare(sql).bind(...(parameters || []));
constructor(private config: D1ConnectionConfig<DB>) {
super({
const mapResult = (res: D1Result<any>): QueryResult<any> => {
if (res.error) {
throw new Error(res.error);
}
const numAffectedRows =
res.meta.changes > 0 ? utils.parseBigInt(res.meta.changes) : undefined;
const insertId = res.meta.last_row_id
? utils.parseBigInt(res.meta.last_row_id)
: undefined;
return {
insertId,
numAffectedRows,
rows: res.results,
// @ts-ignore
meta: res.meta,
};
};
return {
db,
batch: async (stmts) => {
const res = await db.batch(
stmts.map(({ sql, parameters }) => {
return getStmt(sql, parameters);
}),
);
return res.map(mapResult);
},
query: utils.buildQueryFn({
all: async (sql, parameters) => {
const prep = getStmt(sql, parameters);
return mapResult(await prep.all()).rows;
},
run: async (sql, parameters) => {
const prep = getStmt(sql, parameters);
return mapResult(await prep.run());
},
}),
close: () => {},
};
},
{
supports: {
batching: true,
softscans: false,
},
excludeTables: ["_cf_KV", "_cf_METADATA"],
dialect: D1Dialect,
dialectArgs: [{ database: config.binding as D1Database }],
});
}
override async executeQueries<O extends ConnQuery[]>(...qbs: O): Promise<ConnQueryResults<O>> {
const compiled = this.getCompiled(...qbs);
const db = this.config.binding;
const res = await db.batch(
compiled.map(({ sql, parameters }) => {
return db.prepare(sql).bind(...parameters);
}),
);
return this.withTransformedRows(res, "results") as any;
}
},
);
}

View File

@@ -0,0 +1,54 @@
import { describe, test, expect } from "vitest";
import { viTestRunner } from "adapter/node/vitest";
import { connectionTestSuite } from "data/connection/connection-test-suite";
import { Miniflare } from "miniflare";
import { d1Sqlite } from "./D1Connection";
import { sql } from "kysely";
describe("d1Sqlite", async () => {
const mf = new Miniflare({
modules: true,
script: "export default { async fetch() { return new Response(null); } }",
d1Databases: ["DB"],
});
const binding = (await mf.getD1Database("DB")) as D1Database;
test("connection", async () => {
const conn = d1Sqlite({ binding });
expect(conn.supports("batching")).toBe(true);
expect(conn.supports("softscans")).toBe(false);
});
test("query details", async () => {
const conn = d1Sqlite({ binding });
const res = await conn.executeQuery(sql`select 1`.compile(conn.kysely));
expect(res.rows).toEqual([{ "1": 1 }]);
expect(res.numAffectedRows).toBe(undefined);
expect(res.insertId).toBe(undefined);
// @ts-expect-error
expect(res.meta.changed_db).toBe(false);
// @ts-expect-error
expect(res.meta.rows_read).toBe(0);
const batchResult = await conn.executeQueries(
sql`select 1`.compile(conn.kysely),
sql`select 2`.compile(conn.kysely),
);
// rewrite to get index
for (const [index, result] of batchResult.entries()) {
expect(result.rows).toEqual([{ [String(index + 1)]: index + 1 }]);
expect(result.numAffectedRows).toBe(undefined);
expect(result.insertId).toBe(undefined);
// @ts-expect-error
expect(result.meta.changed_db).toBe(false);
}
});
connectionTestSuite(viTestRunner, {
makeConnection: () => d1Sqlite({ binding }),
rawDialectDetails: [],
});
});

View File

@@ -0,0 +1,138 @@
import {
SqliteAdapter,
SqliteIntrospector,
SqliteQueryCompiler,
type CompiledQuery,
type DatabaseConnection,
type DatabaseIntrospector,
type Dialect,
type Driver,
type Kysely,
type QueryCompiler,
type QueryResult,
} from "kysely";
/**
 * Config for the D1 dialect. Pass your D1 instance to this object that you bound in `wrangler.toml`.
 */
export interface D1DialectConfig {
   // the D1 binding from the Workers environment (e.g. `env.DB`)
   database: D1Database;
}
/**
 * D1 dialect that adds support for [Cloudflare D1][0] in [Kysely][1].
 * The constructor takes the instance of your D1 database that you bound in `wrangler.toml`.
 *
 * ```typescript
 * new D1Dialect({
 *    database: env.DB,
 * })
 * ```
 *
 * [0]: https://blog.cloudflare.com/introducing-d1/
 * [1]: https://github.com/koskimas/kysely
 */
export class D1Dialect implements Dialect {
   readonly #config: D1DialectConfig;

   constructor(config: D1DialectConfig) {
      this.#config = config;
   }

   // D1 is sqlite under the hood, so the stock sqlite pieces apply directly
   createAdapter() {
      return new SqliteAdapter();
   }

   createQueryCompiler(): QueryCompiler {
      return new SqliteQueryCompiler();
   }

   createIntrospector(db: Kysely<any>): DatabaseIntrospector {
      return new SqliteIntrospector(db);
   }

   createDriver(): Driver {
      return new D1Driver(this.#config);
   }
}
/** Driver that hands out connections wrapping the configured D1 binding. */
class D1Driver implements Driver {
   readonly #config: D1DialectConfig;

   constructor(config: D1DialectConfig) {
      this.#config = config;
   }

   async init(): Promise<void> {
      // no setup required — the binding is ready to use
   }

   async acquireConnection(): Promise<DatabaseConnection> {
      // D1 has no connection pool; each "connection" wraps the same binding
      return new D1Connection(this.#config);
   }

   // transaction lifecycle is delegated to the connection (which rejects it)
   beginTransaction(conn: D1Connection): Promise<void> {
      return conn.beginTransaction();
   }

   commitTransaction(conn: D1Connection): Promise<void> {
      return conn.commitTransaction();
   }

   rollbackTransaction(conn: D1Connection): Promise<void> {
      return conn.rollbackTransaction();
   }

   async releaseConnection(_conn: D1Connection): Promise<void> {
      // nothing to release
   }

   async destroy(): Promise<void> {
      // nothing to tear down
   }
}
/** A single D1-backed kysely connection. Transactions and streaming are unsupported. */
class D1Connection implements DatabaseConnection {
   readonly #config: D1DialectConfig;

   constructor(config: D1DialectConfig) {
      this.#config = config;
   }

   /**
    * Runs a compiled query against the bound D1 database and maps D1's
    * result shape onto kysely's `QueryResult`.
    */
   async executeQuery<O>(compiledQuery: CompiledQuery): Promise<QueryResult<O>> {
      const results = await this.#config.database
         .prepare(compiledQuery.sql)
         .bind(...compiledQuery.parameters)
         .all();

      if (results.error) {
         throw new Error(results.error);
      }

      const { changes, last_row_id } = results.meta;
      // only surface affected rows when something actually changed
      const numAffectedRows = changes > 0 ? BigInt(changes) : undefined;
      // `== null` covers both null and undefined while still reporting a
      // last_row_id of 0 (a plain truthiness check would drop it)
      const insertId = last_row_id == null ? undefined : BigInt(last_row_id);

      return {
         insertId,
         rows: (results?.results as O[]) || [],
         numAffectedRows,
         // @ts-ignore deprecated in kysely >= 0.23, keep for backward compatibility.
         numUpdatedOrDeletedRows: numAffectedRows,
      };
   }

   async beginTransaction() {
      throw new Error("Transactions are not supported yet.");
   }

   async commitTransaction() {
      throw new Error("Transactions are not supported yet.");
   }

   async rollbackTransaction() {
      throw new Error("Transactions are not supported yet.");
   }

   // biome-ignore lint/correctness/useYield: <explanation>
   async *streamQuery<O>(
      _compiledQuery: CompiledQuery,
      _chunkSize: number,
   ): AsyncIterableIterator<QueryResult<O>> {
      throw new Error("D1 Driver does not support streaming");
   }
}

View File

@@ -17,32 +17,41 @@ export function nodeSqlite(config?: NodeSqliteConnectionConfig | { url: string }
db = new DatabaseSync(":memory:");
}
return genericSqlite("node-sqlite", db, (utils) => {
const getStmt = (sql: string) => {
const stmt = db.prepare(sql);
//stmt.setReadBigInts(true);
return stmt;
};
return genericSqlite(
"node-sqlite",
db,
(utils) => {
const getStmt = (sql: string) => {
const stmt = db.prepare(sql);
//stmt.setReadBigInts(true);
return stmt;
};
return {
db,
query: utils.buildQueryFn({
all: (sql, parameters = []) => getStmt(sql).all(...parameters),
run: (sql, parameters = []) => {
const { changes, lastInsertRowid } = getStmt(sql).run(...parameters);
return {
insertId: utils.parseBigInt(lastInsertRowid),
numAffectedRows: utils.parseBigInt(changes),
};
return {
db,
query: utils.buildQueryFn({
all: (sql, parameters = []) => getStmt(sql).all(...parameters),
run: (sql, parameters = []) => {
const { changes, lastInsertRowid } = getStmt(sql).run(...parameters);
return {
insertId: utils.parseBigInt(lastInsertRowid),
numAffectedRows: utils.parseBigInt(changes),
};
},
}),
close: () => db.close(),
iterator: (isSelect, sql, parameters = []) => {
if (!isSelect) {
throw new Error("Only support select in stream()");
}
return getStmt(sql).iterate(...parameters) as any;
},
}),
close: () => db.close(),
iterator: (isSelect, sql, parameters = []) => {
if (!isSelect) {
throw new Error("Only support select in stream()");
}
return getStmt(sql).iterate(...parameters) as any;
};
},
{
supports: {
batching: false,
},
};
});
},
);
}

View File

@@ -1,5 +1,5 @@
import type { Connection } from "bknd/data";
import { libsql } from "../../data/connection/sqlite/LibsqlConnection";
import { libsql } from "../../data/connection/sqlite/libsql/LibsqlConnection";
export function sqlite(config: { url: string }): Connection {
return libsql(config);

View File

@@ -1,3 +0,0 @@
import type { Connection } from "bknd/data";
export type SqliteConnection = (config: { url: string }) => Connection;

View File

@@ -1,4 +1,4 @@
import type { KyselyPlugin } from "kysely";
import type { KyselyPlugin, QueryResult } from "kysely";
import {
type IGenericSqlite,
type OnCreateConnection,
@@ -8,11 +8,16 @@ import {
GenericSqliteDialect,
} from "kysely-generic-sqlite";
import { SqliteConnection } from "./SqliteConnection";
import type { Features } from "../Connection";
import type { ConnQuery, ConnQueryResults, Features } from "../Connection";
export type { IGenericSqlite };
export type TStatement = { sql: string; parameters?: any[] | readonly any[] };
export interface IGenericCustomSqlite<DB = unknown> extends IGenericSqlite<DB> {
batch?: (stmts: TStatement[]) => Promisable<QueryResult<any>[]>;
}
export type GenericSqliteConnectionConfig = {
name: string;
name?: string;
additionalPlugins?: KyselyPlugin[];
excludeTables?: string[];
onCreateConnection?: OnCreateConnection;
@@ -21,10 +26,11 @@ export type GenericSqliteConnectionConfig = {
export class GenericSqliteConnection<DB = unknown> extends SqliteConnection<DB> {
override name = "generic-sqlite";
#executor: IGenericCustomSqlite<DB> | undefined;
constructor(
db: DB,
executor: () => Promisable<IGenericSqlite>,
public db: DB,
private executor: () => Promisable<IGenericCustomSqlite<DB>>,
config?: GenericSqliteConnectionConfig,
) {
super({
@@ -39,18 +45,43 @@ export class GenericSqliteConnection<DB = unknown> extends SqliteConnection<DB>
}
if (config?.supports) {
for (const [key, value] of Object.entries(config.supports)) {
if (value) {
if (value !== undefined) {
this.supported[key] = value;
}
}
}
}
private async getExecutor() {
if (!this.#executor) {
this.#executor = await this.executor();
}
return this.#executor;
}
override async executeQueries<O extends ConnQuery[]>(...qbs: O): Promise<ConnQueryResults<O>> {
const executor = await this.getExecutor();
if (!executor.batch) {
console.warn("Batching is not supported by this database");
return super.executeQueries(...qbs);
}
const compiled = this.getCompiled(...qbs);
const stms: TStatement[] = compiled.map((q) => {
return {
sql: q.sql,
parameters: q.parameters as any[],
};
});
const results = await executor.batch(stms);
return this.withTransformedRows(results) as any;
}
}
export function genericSqlite<DB>(
name: string,
db: DB,
executor: (utils: typeof genericSqliteUtils) => Promisable<IGenericSqlite<DB>>,
executor: (utils: typeof genericSqliteUtils) => Promisable<IGenericCustomSqlite<DB>>,
config?: GenericSqliteConnectionConfig,
) {
return new GenericSqliteConnection(db, () => executor(genericSqliteUtils), {

View File

@@ -1,4 +1,4 @@
import { connectionTestSuite } from "../connection-test-suite";
import { connectionTestSuite } from "../../connection-test-suite";
import { LibsqlConnection } from "./LibsqlConnection";
import { bunTestRunner } from "adapter/bun/test";
import { describe } from "bun:test";

View File

@@ -1,23 +1,14 @@
import type { Client, Config, InStatement } from "@libsql/client";
import { createClient } from "libsql-stateless-easy";
import { LibsqlDialect } from "@libsql/kysely-libsql";
import { LibsqlDialect } from "./LibsqlDialect";
import { FilterNumericKeysPlugin } from "data/plugins/FilterNumericKeysPlugin";
import { type ConnQuery, type ConnQueryResults, SqliteConnection } from "bknd/data";
export const LIBSQL_PROTOCOLS = ["wss", "https", "libsql"] as const;
export type LibSqlCredentials = Config & {
protocol?: (typeof LIBSQL_PROTOCOLS)[number];
};
export type LibSqlCredentials = Config;
function getClient(clientOrCredentials: Client | LibSqlCredentials): Client {
if (clientOrCredentials && "url" in clientOrCredentials) {
let { url, authToken, protocol } = clientOrCredentials;
if (protocol && LIBSQL_PROTOCOLS.includes(protocol)) {
console.info("changing protocol to", protocol);
const [, rest] = url.split("://");
url = `${protocol}://${rest}`;
}
const { url, authToken } = clientOrCredentials;
return createClient({ url, authToken });
}

View File

@@ -0,0 +1,145 @@
import type { Client, Transaction, InValue } from "@libsql/client";
import {
SqliteAdapter,
SqliteIntrospector,
SqliteQueryCompiler,
type Kysely,
type Dialect,
type DialectAdapter,
type Driver,
type DatabaseIntrospector,
type QueryCompiler,
type TransactionSettings,
type DatabaseConnection,
type QueryResult,
type CompiledQuery,
} from "kysely";
// Config for the libsql dialect; only a pre-constructed `Client` is accepted.
export type LibsqlDialectConfig = {
   // libsql client instance, owned (and eventually closed) by the caller
   client: Client;
};
/**
 * Kysely dialect for libsql, inlined from `@libsql/kysely-libsql` and reduced
 * to the client-only configuration used here (no `url` variant).
 */
export class LibsqlDialect implements Dialect {
   #config: LibsqlDialectConfig;

   constructor(config: LibsqlDialectConfig) {
      this.#config = config;
   }

   createAdapter(): DialectAdapter {
      return new SqliteAdapter();
   }

   createDriver(): Driver {
      // runtime guard for misconfigured (untyped) callers; the config type
      // only supports `client`, so don't advertise a `url` option here
      if (!("client" in this.#config)) {
         throw new Error("Please specify `client` in the LibsqlDialect config");
      }
      // the client is owned by the caller, so the driver must never close it
      return new LibsqlDriver(this.#config.client, false);
   }

   createIntrospector(db: Kysely<any>): DatabaseIntrospector {
      return new SqliteIntrospector(db);
   }

   createQueryCompiler(): QueryCompiler {
      return new SqliteQueryCompiler();
   }
}
/** Driver wrapper around a single shared libsql `Client`. */
export class LibsqlDriver implements Driver {
   client: Client;
   #closeClient: boolean;

   constructor(client: Client, closeClient: boolean) {
      this.client = client;
      this.#closeClient = closeClient;
   }

   async init(): Promise<void> {
      // the client needs no warm-up
   }

   async acquireConnection(): Promise<LibsqlConnection> {
      // no pooling: every acquisition shares the one client
      return new LibsqlConnection(this.client);
   }

   // transaction control is forwarded to the connection
   async beginTransaction(
      connection: LibsqlConnection,
      _settings: TransactionSettings,
   ): Promise<void> {
      await connection.beginTransaction();
   }

   async commitTransaction(connection: LibsqlConnection): Promise<void> {
      await connection.commitTransaction();
   }

   async rollbackTransaction(connection: LibsqlConnection): Promise<void> {
      await connection.rollbackTransaction();
   }

   async releaseConnection(_conn: LibsqlConnection): Promise<void> {
      // nothing to release
   }

   async destroy(): Promise<void> {
      // only close the client when this driver owns it
      if (this.#closeClient) {
         this.client.close();
      }
   }
}
/**
 * DatabaseConnection backed by a libsql `Client`, with interactive
 * transaction support (at most one transaction at a time per connection).
 */
export class LibsqlConnection implements DatabaseConnection {
   client: Client;
   #transaction?: Transaction;

   constructor(client: Client) {
      this.client = client;
   }

   async executeQuery<R>(compiledQuery: CompiledQuery): Promise<QueryResult<R>> {
      // while a transaction is open, all statements run inside it
      const executor = this.#transaction ?? this.client;
      const { sql, parameters } = compiledQuery;
      const result = await executor.execute({
         sql,
         args: parameters as Array<InValue>,
      });
      return {
         insertId: result.lastInsertRowid,
         numAffectedRows: BigInt(result.rowsAffected),
         rows: result.rows as Array<R>,
      };
   }

   async beginTransaction() {
      // nested transactions are not supported
      if (this.#transaction) {
         throw new Error("Transaction already in progress");
      }
      this.#transaction = await this.client.transaction();
   }

   async commitTransaction() {
      if (!this.#transaction) {
         throw new Error("No transaction to commit");
      }
      await this.#transaction.commit();
      this.#transaction = undefined;
   }

   async rollbackTransaction() {
      if (!this.#transaction) {
         throw new Error("No transaction to rollback");
      }
      await this.#transaction.rollback();
      this.#transaction = undefined;
   }

   // biome-ignore lint/correctness/useYield: <explanation>
   async *streamQuery<R>(
      _compiledQuery: CompiledQuery,
      _chunkSize: number,
   ): AsyncIterableIterator<QueryResult<R>> {
      throw new Error("Libsql Driver does not support streaming yet");
   }
}

View File

@@ -57,7 +57,7 @@ export class Repository<TBD extends object = DefaultDB, TB extends keyof TBD = a
}
}
getValidOptions(options?: RepoQuery): RepoQuery {
getValidOptions(options?: Partial<RepoQuery>): RepoQuery {
const entity = this.entity;
// @todo: if not cloned deep, it will keep references and error if multiple requests come in
const validated = {

View File

@@ -30,7 +30,7 @@ export * as DataPermissions from "./permissions";
export { MediaField, type MediaFieldConfig, type MediaItem } from "media/MediaField";
export { libsql } from "./connection/sqlite/LibsqlConnection";
export { libsql } from "./connection/sqlite/libsql/LibsqlConnection";
export {
genericSqlite,
genericSqliteUtils,

View File

@@ -150,4 +150,6 @@ export type RepoQueryIn = {
join?: string[];
where?: WhereQuery;
};
export type RepoQuery = s.StaticCoerced<typeof repoQuery>;
export type RepoQuery = s.StaticCoerced<typeof repoQuery> & {
sort: SortSchema;
};