diff --git a/packages/core/src/database.ts b/packages/core/src/database.ts index ed12daa6..ccc18bbb 100644 --- a/packages/core/src/database.ts +++ b/packages/core/src/database.ts @@ -513,6 +513,23 @@ export class Database extends Service { return sel._action('set', update).execute() } + async setOne>( + table: K, + query: Query, + update: Row.Computed>, + ): Promise { + const sel = this.select(table, query, null) + if (typeof update === 'function') update = update(sel.row) + const primary = makeArray(sel.model.primary) + if (primary.some(key => key in update)) { + throw new TypeError(`cannot modify primary key`) + } + + update = sel.model.format(update) + if (Object.keys(update).length === 0) return (await this.get(table, query))[0] + return sel._action('setOne', update).execute() + } + async remove>(table: K, query: Query): Promise { const sel = this.select(table, query, null) return sel._action('remove').execute() diff --git a/packages/core/src/driver.ts b/packages/core/src/driver.ts index 8c8cac0e..c8c4af72 100644 --- a/packages/core/src/driver.ts +++ b/packages/core/src/driver.ts @@ -69,7 +69,9 @@ export abstract class Driver { abstract remove(sel: Selection.Mutable): Promise abstract create(sel: Selection.Mutable, data: any): Promise abstract upsert(sel: Selection.Mutable, data: any[], keys: string[]): Promise + abstract setOne(sel: Selection.Mutable, data: Update): Promise abstract withTransaction(callback: (session?: any) => Promise): Promise + abstract getIndexes(table: string): Promise abstract createIndex(table: string, index: Driver.Index): Promise abstract dropIndex(table: string, name: string): Promise diff --git a/packages/core/src/selection.ts b/packages/core/src/selection.ts index 63bbe049..55248aa9 100644 --- a/packages/core/src/selection.ts +++ b/packages/core/src/selection.ts @@ -27,7 +27,7 @@ export interface Modifier { } namespace Executable { - export type Action = 'get' | 'set' | 'remove' | 'create' | 'upsert' | 'eval' + export type Action = 'get' | 'set' | 'remove' | 'create' | 'upsert' | 'eval' | 'setOne' export interface Payload { type: Action diff --git a/packages/memory/src/index.ts b/packages/memory/src/index.ts index 1c65a2d1..fd1a019e 100644 --- a/packages/memory/src/index.ts +++ b/packages/memory/src/index.ts @@ -132,6 +132,15 @@ export class MemoryDriver extends Driver { return { matched } } + async setOne(sel: Selection.Mutable, data: {}) { + const { table, ref, query, model } = sel + const row = this.table(table).find(row => executeQuery(row, query, ref)) + if (!row) return + executeUpdate(row, data, ref) + this.$save(table) + return model.parse(clone(row)) + } + async remove(sel: Selection.Mutable) { const { ref, query, table } = sel const data = this.table(table) diff --git a/packages/mongo/src/index.ts b/packages/mongo/src/index.ts index 2351b678..caf854cc 100644 --- a/packages/mongo/src/index.ts +++ b/packages/mongo/src/index.ts @@ -397,6 +397,36 @@ export class MongoDriver extends Driver { } } + async setOne(sel: Selection.Mutable, update: {}) { + const { query, table, model } = sel + if (hasSubquery(sel.query) || Object.values(update).some(x => hasSubquery(x))) { + await this.set(sel, update) + const rows = await this.get(sel) + return rows[0] + } + + const filter = this.transformQuery(sel, query, table) + if (!filter) return + const coll = this.db.collection(table) + + const virtualKey = this.getVirtualKey(table) + const transformer = new Builder(this, Object.keys(sel.tables), virtualKey, '$' + tempKey + '.') + const $set = this.mapVirtualUpdate(update, virtualKey, (item, key) => transformer.toUpdateExpr(item, model.getType(key))) + const $unset = Object.entries($set) + .filter(([_, value]) => typeof value === 'object') + .map(([key, _]) => key) + const preset = Object.fromEntries(transformer.walkedKeys.map(key => [tempKey + '.' + key, '$' + key])) + + const result = await coll.findOneAndUpdate(filter, [ + ...transformer.walkedKeys.length ? [{ $set: preset }] : [], + ...$unset.length ? [{ $unset }] : [], + { $set }, + ...transformer.walkedKeys.length ? [{ $unset: [tempKey] }] : [], + ], { returnDocument: 'after', session: this.session }) + if (!result) return + return this.builder.load(this.patchVirtual(table, result), model) + } + async remove(sel: Selection.Mutable) { const { query, table } = sel const filter = this.transformQuery(sel, query, table) diff --git a/packages/mysql/src/index.ts b/packages/mysql/src/index.ts index 63dda6b7..dc00d5b3 100644 --- a/packages/mysql/src/index.ts +++ b/packages/mysql/src/index.ts @@ -452,6 +452,56 @@ RETURN UNHEX(REPLACE(u, '-', ''))`) return { matched: result.affectedRows, modified: result.changedRows } } + async setOne(sel: Selection.Mutable, data: {}) { + const { model, query, table, tables, ref } = sel + const builder = new MySQLBuilder(this, tables, this._compat) + const filter = builder.parseQuery(query) + const fields = model.availableFields() + if (filter === '0') return + const updateFields = [...new Set(Object.keys(data).map((key) => { + return Object.keys(fields).find(field => field === key || key.startsWith(field + '.'))! + }))] + + const allFields = Object.keys(fields) + const varname = (f: string) => `@_${f.replace(/\./g, '_')}` + + const setClauses = allFields.filter(f => f in fields).map((field) => { + if (updateFields.includes(field)) { + return `${escapeId(field)} = (${varname(field)} := ${builder.toUpdateExpr(data, field, fields[field], false)})` + } else { + return `${escapeId(field)} = (${varname(field)} := ${escapeId(field)})` + } + }) + + const selectExprs = allFields.map((field) => { + const alias = escapeId(field) + if (field in fields) { + return `${varname(field)} AS ${alias}` + } + const parent = Object.keys(fields).find(k => field.startsWith(k + '.')) + if (parent) { + const rest = field.slice(parent.length + 1) + return `json_extract(${varname(parent)}, '$.${rest.split('.').map((k: string) => `"${k}"`).join('.')}') AS ${alias}` + } + return `${varname(field)} AS ${alias}` + }).join(', ') + + const sql = [ + ...builder.prequeries, + `UPDATE ${escapeId(table)} ${ref} SET ${setClauses.join(', ')} WHERE ${filter} LIMIT 1`, + `SELECT ${selectExprs}`, + ].join('; ') + const results = await this.query(sql) + // UPDATE is always the last result with affectedRows (before SELECT) + const parts = Array.isArray(results[0]) || results[0]?.affectedRows !== undefined ? results : [results] + const updateResult = parts[parts.length === 1 ? 0 : parts.length - 2] + if (!updateResult || updateResult.affectedRows === 0) return + const selectResult = parts[parts.length - 1] + const row = Array.isArray(selectResult) ? selectResult[0] : selectResult + if (!row) return + return builder.load(row, model) + } + async remove(sel: Selection.Mutable) { const { query, table, tables } = sel const builder = new MySQLBuilder(this, tables, this._compat) diff --git a/packages/postgres/src/index.ts b/packages/postgres/src/index.ts index a6697267..167e4bfc 100644 --- a/packages/postgres/src/index.ts +++ b/packages/postgres/src/index.ts @@ -314,6 +314,27 @@ export class PostgresDriver extends Driver { return { matched: result.length } } + async setOne(sel: Selection.Mutable, data: {}) { + const { model, query, table, tables, ref } = sel + const builder = new PostgresBuilder(this, tables) + const filter = builder.parseQuery(query) + const fields = model.availableFields() + if (filter === '0') return + const updateFields = [...new Set(Object.keys(data).map((key) => { + return Object.keys(fields).find(field => field === key || key.startsWith(field + '.'))! + }))] + + const update = updateFields.map((field) => { + const escaped = builder.escapeId(field) + return `${escaped} = ${builder.toUpdateExpr(data, field, fields[field], false)}` + }).join(', ') + const primaryFields = makeArray(model.primary).map(k => builder.escapeId(k)) + const primaryTuple = primaryFields.length === 1 ? primaryFields[0] : `(${primaryFields.join(', ')})` + const subquery = `SELECT ${primaryFields.join(', ')} FROM ${builder.escapeId(table)} WHERE ${filter} LIMIT 1` + const result = await this.query(`UPDATE ${builder.escapeId(table)} ${ref} SET ${update} WHERE ${primaryTuple} = (${subquery}) RETURNING *`) + return result[0] ? builder.load(result[0], model) : undefined + } + async remove(sel: Selection.Mutable) { const builder = new PostgresBuilder(this, sel.tables) const query = builder.parseQuery(sel.query) diff --git a/packages/sqlite/src/index.ts b/packages/sqlite/src/index.ts index 3ec194a0..dae4d13b 100644 --- a/packages/sqlite/src/index.ts +++ b/packages/sqlite/src/index.ts @@ -1,4 +1,4 @@ -import { Binary, deepEqual, Dict, difference, isNullable, makeArray, mapValues } from 'cosmokit' +import { Binary, deepEqual, Dict, difference, isNullable, makeArray, mapValues, pick } from 'cosmokit' import { bufferToUuid, Driver, Eval, executeUpdate, Field, getCell, hasSubquery, isEvalExpr, Selection, uuidToBuffer } from '@cordisjs/plugin-database' import { escapeId } from '@cordisjs/sql-utils' import type { DatabaseSync, StatementSync } from 'node:sqlite' @@ -356,6 +356,41 @@ export class SQLiteDriver extends Driver { } } + async setOne(sel: Selection.Mutable, data: {}) { + const { model, table, query } = sel + const { primary } = model, fields = model.availableFields() + const updateFields = [...new Set(Object.keys(data).map((key) => { + return Object.keys(fields).find(field => field === key || key.startsWith(field + '.'))! + }))] + const primaryFields = makeArray(primary) + + if (query.$expr || hasSubquery(sel.query) || Object.values(data).some(x => hasSubquery(x))) { + const sel2 = this.database.select(table as never, query) + sel2.tables[sel.ref] = sel2.tables[sel2.ref] + delete sel2.tables[sel2.ref] + sel2.ref = sel.ref + const project = mapValues(data as any, (value, key) => () => (isEvalExpr(value) ? value : Eval.literal(value, model.getType(key)))) + const rawUpsert = await sel2.project({ + ...project, + ...Object.fromEntries(primaryFields.map(x => [x, () => Eval('', [sel.ref, x], sel2.model.getType(x)!)])), + }).limit(1).execute() + if (!rawUpsert.length) return + const row = rawUpsert[0] + const upsert = [{ + ...mapValues(data, (_, key) => getCell(row, key)), + ...Object.fromEntries(primaryFields.map(x => [x, getCell(row, x)])), + }] + await this.database.upsert(table as never, upsert) + return (await this.database.get(table as never, pick(upsert[0], primaryFields as any) as any))[0] + } else { + const existing = await this.database.get(table as never, query) + const row = existing[0] + if (!row) return + this._update(sel, primaryFields, updateFields, data, row) + return (await this.database.get(table as never, pick(row, primaryFields as any) as any))[0] + } + } + _create(table: string, data: {}) { const model = this.model(table) data = this.sql.dump(data, model) diff --git a/packages/tests/src/model.ts b/packages/tests/src/model.ts index a3940655..1dcd8728 100644 --- a/packages/tests/src/model.ts +++ b/packages/tests/src/model.ts @@ -542,6 +542,67 @@ namespace ModelOperations { await expect(database.eval('dtypes', row => $.array(row.text2))).to.eventually.contain('foo') await expect(database.get('dtypes', row => $.eq(row.text2, $.literal('foo', 'string2')))).to.eventually.have.length(1) }) + + it('setOne plain value', async () => { + const table = await setup(database, 'dtypes', dtypeTable) + table[0].text = 'updated' + const result = await database.setOne('dtypes', { id: 1 }, { text: 'updated' }) + expect(result).to.not.be.undefined + expect(result!.text).to.equal('updated') + expect(result!.id).to.equal(table[0].id) + }) + + it('setOne dot notation + custom types', async () => { + const table = await setup(database, 'dtypes', dtypeTable) + const result = await database.setOne('dtypes', { id: 4 }, { 'object.embed.bool': false, 'object.embed.bigint': 999n }) + expect(result).to.not.be.undefined + expect(result!.object!.embed!.bool).to.equal(false) + expect(result!.object!.embed!.bigint).to.equal(999n) + }) + + it('setOne expression update', async () => { + const table = await setup(database, 'dtypes', dtypeTable) + const row = table.find(r => r.id === 3)! + row.num! += 1 + row.object!.json!.num = (row.object!.json!.num ?? 0) + 100 + await expect(database.setOne('dtypes', 3, row => ({ + num: $.add(row.num, 1), + 'object.json.num': $.add($.ifNull(row.object.json.num, 0), 100), + }))).to.eventually.have.shape(row) + }) + + it('setOne subquery update', async () => { + const table = await setup(database, 'dtypes', dtypeTable) + const row = table.find(r => r.id === 2)! + row.text = row.text ? row.text : '' + await expect(database.setOne('dtypes', 2, row => ({ + text: database.select('dtypes', r => $.eq(r.id, 1)).evaluate(r => $.max(r.text)), + }))).to.eventually.have.property('text') + }) + + it('setOne bigint + binary + bnum custom type', async () => { + const table = await setup(database, 'dtypes', dtypeTable) + table[10].bigint = 888n + table[10].bnum = 999 + table[10].binary = toBinary('world') + await expect(database.setOne('dtypes', 11, { + bigint: 888n, + bnum: 999, + binary: toBinary('world'), + })).to.eventually.have.shape(table[10]) + }) + + it('setOne no match returns undefined', async () => { + await setup(database, 'dtypes', dtypeTable) + await expect(database.setOne('dtypes', 99999, { text: 'nope' })).to.eventually.be.undefined + }) + + it('setOne preserves all field types', async () => { + const table = await setup(database, 'dtypes', dtypeTable) + const row = table.find(r => r.id === 5)! + row.text = 'all-types-ok' + await expect(database.setOne('dtypes', { id: 5 }, { text: 'all-types-ok' })).to.eventually.have.shape(row) + }) } export const object = function ObjectFields(database: Database, options: ModelOptions = {}) { @@ -674,6 +735,42 @@ namespace ModelOperations { await expect(database.get('dobjects', row => $.eq(row.baz[0].nested.array[0], 1))).to.eventually.have.length(2) }) + it('setOne nested dot notation', async () => { + const table = await setup(database, 'dobjects', dobjectTable) + table[1].foo!.nested!.timestamp = new Date('2009/10/01 15:40:00') + table[1].foo!.nested!.binary = toBinary('boom') + await expect(database.setOne('dobjects', 2, { + 'foo.nested.timestamp': new Date('2009/10/01 15:40:00'), + 'foo.nested.binary': toBinary('boom'), + })).to.eventually.have.shape(table[1]) + }) + + it('setOne nested object replace', async () => { + const table = await setup(database, 'dobjects', dobjectTable) + table[0].foo = { nested: { id: 1 } } + await expect(database.setOne('dobjects', 1, { + 'foo.nested': { id: 1 }, + })).to.eventually.have.nested.property('foo.nested.id', 1) + }) + + it('setOne expression on nested field', async () => { + const table = await setup(database, 'dobjects', dobjectTable) + const row = table.find(r => r.bar?.nested?.id !== undefined)! + row.bar!.nested!.text = 'expr-updated' + const result = await database.setOne('dobjects', row.id, { + 'bar.nested.text': $.concat(row.bar!.nested!.text ?? '', ''), + }) + expect(result).to.not.be.undefined + expect(result!.bar!.nested!.text).to.equal(row.bar!.nested!.text) + }) + + it('setOne recursive type', async () => { + const table = await setup(database, 'recurxs', [{ id: 1, y: { id: 2, x: { id: 3, y: { id: 4, x: { id: 5 } } } } }]) + const row = table[0] + row.y!.id = 999 + await expect(database.setOne('recurxs', 1, { y: { id: 999, x: { id: 3, y: { id: 4, x: { id: 5 } } } } })).to.eventually.have.shape(row) + }) + nullableComparator && it('decode uuid', async () => { await setup(database, 'dobjects', dobjectTable) await expect(database.get('dobjects', row => $.eq(row.foo!.nested!.uuid!, $.literal(u1, 'uuid')))).to.eventually.have.length(1) diff --git a/packages/tests/src/update.ts b/packages/tests/src/update.ts index 8cf1b580..b2c64dae 100644 --- a/packages/tests/src/update.ts +++ b/packages/tests/src/update.ts @@ -241,6 +241,50 @@ namespace OrmOperations { }) } + export const setOne = function SetOne(database: Database) { + it('basic support', async () => { + const table = await setup(database, 'temp2', barTable) + const target = table.find(bar => bar.id === 2)! + target.text = 'updated' + await expect(database.setOne('temp2', { id: 2 }, { text: 'updated' })).to.eventually.have.shape(target) + target.text = 'updated2' + await expect(database.setOne('temp2', row => $.eq(row.id, 2), { text: 'updated2' })).to.eventually.have.shape(target) + }) + + it('no match returns undefined', async () => { + await setup(database, 'temp2', barTable) + await expect(database.setOne('temp2', { id: 9999 }, { text: 'nope' })).to.eventually.be.undefined + }) + + it('using expressions', async () => { + const table = await setup(database, 'temp2', barTable) + const target = table.find(bar => bar.id === 3)! + target.num = target.num! + 1 + await expect(database.setOne('temp2', { id: 3 }, row => ({ + num: $.add(row.num, 1), + }))).to.eventually.have.shape(target) + }) + + it('returns updated value not old', async () => { + await setup(database, 'temp2', barTable) + await expect(database.setOne('temp2', { id: 1 }, { text: 'new-value' })).to.eventually.have.property('text', 'new-value') + }) + + it('noop update returns existing row', async () => { + const table = await setup(database, 'temp2', barTable) + const target = table.find(bar => bar.id === 1)! + await expect(database.setOne('temp2', { id: 1 }, {})).to.eventually.have.shape(target) + }) + + it('advanced type', async () => { + const table = await setup(database, 'temp2', barTable) + const target = table.find(bar => bar.id === 1)! + target.binary = toBinary('world') + target.bigint = 1234567891011121314151617181920n + await expect(database.setOne('temp2', { id: 1 }, { binary: toBinary('world'), bigint: 1234567891011121314151617181920n })).to.eventually.have.shape(target) + }) + } + export const upsert = function Upsert(database: Database) { it('update existing records', async () => { const table = await setup(database, 'temp2', barTable)