Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions packages/core/src/database.ts
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,23 @@ export class Database extends Service {
return sel._action('set', update).execute()
}

async setOne<K extends Keys<Tables>>(
table: K,
query: Query<Tables[K]>,
update: Row.Computed<Tables[K], Update<Tables[K]>>,
): Promise<Tables[K] | undefined> {
const sel = this.select(table, query, null)
if (typeof update === 'function') update = update(sel.row)
const primary = makeArray(sel.model.primary)
if (primary.some(key => key in update)) {
throw new TypeError(`cannot modify primary key`)
}

update = sel.model.format(update)
if (Object.keys(update).length === 0) return (await this.get(table, query))[0]
return sel._action('setOne', update).execute()
}

async remove<K extends Keys<Tables>>(table: K, query: Query<Tables[K]>): Promise<Driver.WriteResult> {
const sel = this.select(table, query, null)
return sel._action('remove').execute()
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/driver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ export abstract class Driver<T = any> {
abstract remove(sel: Selection.Mutable): Promise<Driver.WriteResult>
abstract create(sel: Selection.Mutable, data: any): Promise<any>
abstract upsert(sel: Selection.Mutable, data: any[], keys: string[]): Promise<Driver.WriteResult>
abstract setOne(sel: Selection.Mutable, data: Update): Promise<any>
abstract withTransaction(callback: (session?: any) => Promise<void>): Promise<void>

abstract getIndexes(table: string): Promise<Driver.Index[]>
abstract createIndex(table: string, index: Driver.Index): Promise<void>
abstract dropIndex(table: string, name: string): Promise<void>
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/selection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export interface Modifier {
}

namespace Executable {
export type Action = 'get' | 'set' | 'remove' | 'create' | 'upsert' | 'eval'
export type Action = 'get' | 'set' | 'remove' | 'create' | 'upsert' | 'eval' | 'setOne'

export interface Payload {
type: Action
Expand Down
9 changes: 9 additions & 0 deletions packages/memory/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,15 @@ export class MemoryDriver extends Driver<MemoryDriver.Config> {
return { matched }
}

async setOne(sel: Selection.Mutable, data: {}) {
const { table, ref, query, model } = sel
const row = this.table(table).find(row => executeQuery(row, query, ref))
if (!row) return
executeUpdate(row, data, ref)
this.$save(table)
return model.parse(clone(row))
}

async remove(sel: Selection.Mutable) {
const { ref, query, table } = sel
const data = this.table(table)
Expand Down
30 changes: 30 additions & 0 deletions packages/mongo/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,36 @@ export class MongoDriver extends Driver<MongoDriver.Config> {
}
}

async setOne(sel: Selection.Mutable, update: {}) {
const { query, table, model } = sel
if (hasSubquery(sel.query) || Object.values(update).some(x => hasSubquery(x))) {
await this.set(sel, update)
const rows = await this.get(sel)
return rows[0]
}

const filter = this.transformQuery(sel, query, table)
if (!filter) return
const coll = this.db.collection(table)

const virtualKey = this.getVirtualKey(table)
const transformer = new Builder(this, Object.keys(sel.tables), virtualKey, '$' + tempKey + '.')
const $set = this.mapVirtualUpdate(update, virtualKey, (item, key) => transformer.toUpdateExpr(item, model.getType(key)))
const $unset = Object.entries($set)
.filter(([_, value]) => typeof value === 'object')
.map(([key, _]) => key)
const preset = Object.fromEntries(transformer.walkedKeys.map(key => [tempKey + '.' + key, '$' + key]))

const result = await coll.findOneAndUpdate(filter, [
...transformer.walkedKeys.length ? [{ $set: preset }] : [],
...$unset.length ? [{ $unset }] : [],
{ $set },
...transformer.walkedKeys.length ? [{ $unset: [tempKey] }] : [],
], { returnDocument: 'after', session: this.session })
if (!result) return
return this.builder.load(this.patchVirtual(table, result), model)
}

async remove(sel: Selection.Mutable) {
const { query, table } = sel
const filter = this.transformQuery(sel, query, table)
Expand Down
50 changes: 50 additions & 0 deletions packages/mysql/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,56 @@ RETURN UNHEX(REPLACE(u, '-', ''))`)
return { matched: result.affectedRows, modified: result.changedRows }
}

async setOne(sel: Selection.Mutable, data: {}) {
const { model, query, table, tables, ref } = sel
const builder = new MySQLBuilder(this, tables, this._compat)
const filter = builder.parseQuery(query)
const fields = model.availableFields()
if (filter === '0') return
const updateFields = [...new Set(Object.keys(data).map((key) => {
return Object.keys(fields).find(field => field === key || key.startsWith(field + '.'))!
}))]

const allFields = Object.keys(fields)
const varname = (f: string) => `@_${f.replace(/\./g, '_')}`

const setClauses = allFields.filter(f => f in fields).map((field) => {
if (updateFields.includes(field)) {
return `${escapeId(field)} = (${varname(field)} := ${builder.toUpdateExpr(data, field, fields[field], false)})`
} else {
return `${escapeId(field)} = (${varname(field)} := ${escapeId(field)})`
}
})

const selectExprs = allFields.map((field) => {
const alias = escapeId(field)
if (field in fields) {
return `${varname(field)} AS ${alias}`
}
const parent = Object.keys(fields).find(k => field.startsWith(k + '.'))
if (parent) {
const rest = field.slice(parent.length + 1)
return `json_extract(${varname(parent)}, '$.${rest.split('.').map((k: string) => `"${k}"`).join('.')}') AS ${alias}`
}
return `${varname(field)} AS ${alias}`
}).join(', ')

const sql = [
...builder.prequeries,
`UPDATE ${escapeId(table)} ${ref} SET ${setClauses.join(', ')} WHERE ${filter} LIMIT 1`,
`SELECT ${selectExprs}`,
].join('; ')
const results = await this.query(sql)
// UPDATE is always the last result with affectedRows (before SELECT)
const parts = Array.isArray(results[0]) || results[0]?.affectedRows !== undefined ? results : [results]
const updateResult = parts[parts.length === 1 ? 0 : parts.length - 2]
if (!updateResult || updateResult.affectedRows === 0) return
const selectResult = parts[parts.length - 1]
const row = Array.isArray(selectResult) ? selectResult[0] : selectResult
if (!row) return
return builder.load(row, model)
}

async remove(sel: Selection.Mutable) {
const { query, table, tables } = sel
const builder = new MySQLBuilder(this, tables, this._compat)
Expand Down
21 changes: 21 additions & 0 deletions packages/postgres/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,27 @@ export class PostgresDriver extends Driver<PostgresDriver.Config> {
return { matched: result.length }
}

async setOne(sel: Selection.Mutable, data: {}) {
const { model, query, table, tables, ref } = sel
const builder = new PostgresBuilder(this, tables)
const filter = builder.parseQuery(query)
const fields = model.availableFields()
if (filter === '0') return
const updateFields = [...new Set(Object.keys(data).map((key) => {
return Object.keys(fields).find(field => field === key || key.startsWith(field + '.'))!
}))]

const update = updateFields.map((field) => {
const escaped = builder.escapeId(field)
return `${escaped} = ${builder.toUpdateExpr(data, field, fields[field], false)}`
}).join(', ')
const primaryFields = makeArray(model.primary).map(k => builder.escapeId(k))
const primaryTuple = primaryFields.length === 1 ? primaryFields[0] : `(${primaryFields.join(', ')})`
const subquery = `SELECT ${primaryFields.join(', ')} FROM ${builder.escapeId(table)} WHERE ${filter} LIMIT 1`
const result = await this.query(`UPDATE ${builder.escapeId(table)} ${ref} SET ${update} WHERE ${primaryTuple} = (${subquery}) RETURNING *`)
return result[0] ? builder.load(result[0], model) : undefined
}

async remove(sel: Selection.Mutable) {
const builder = new PostgresBuilder(this, sel.tables)
const query = builder.parseQuery(sel.query)
Expand Down
37 changes: 36 additions & 1 deletion packages/sqlite/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Binary, deepEqual, Dict, difference, isNullable, makeArray, mapValues } from 'cosmokit'
import { Binary, deepEqual, Dict, difference, isNullable, makeArray, mapValues, pick } from 'cosmokit'
import { bufferToUuid, Driver, Eval, executeUpdate, Field, getCell, hasSubquery, isEvalExpr, Selection, uuidToBuffer } from '@cordisjs/plugin-database'
import { escapeId } from '@cordisjs/sql-utils'
import type { DatabaseSync, StatementSync } from 'node:sqlite'
Expand Down Expand Up @@ -356,6 +356,41 @@ export class SQLiteDriver extends Driver<SQLiteDriver.Config> {
}
}

async setOne(sel: Selection.Mutable, data: {}) {
const { model, table, query } = sel
const { primary } = model, fields = model.availableFields()
const updateFields = [...new Set(Object.keys(data).map((key) => {
return Object.keys(fields).find(field => field === key || key.startsWith(field + '.'))!
}))]
const primaryFields = makeArray(primary)

if (query.$expr || hasSubquery(sel.query) || Object.values(data).some(x => hasSubquery(x))) {
const sel2 = this.database.select(table as never, query)
sel2.tables[sel.ref] = sel2.tables[sel2.ref]
delete sel2.tables[sel2.ref]
sel2.ref = sel.ref
const project = mapValues(data as any, (value, key) => () => (isEvalExpr(value) ? value : Eval.literal(value, model.getType(key))))
const rawUpsert = await sel2.project({
...project,
...Object.fromEntries(primaryFields.map(x => [x, () => Eval('', [sel.ref, x], sel2.model.getType(x)!)])),
}).limit(1).execute()
if (!rawUpsert.length) return
const row = rawUpsert[0]
const upsert = [{
...mapValues(data, (_, key) => getCell(row, key)),
...Object.fromEntries(primaryFields.map(x => [x, getCell(row, x)])),
}]
await this.database.upsert(table as never, upsert)
return (await this.database.get(table as never, pick(upsert[0], primaryFields as any) as any))[0]
} else {
const existing = await this.database.get(table as never, query)
const row = existing[0]
if (!row) return
this._update(sel, primaryFields, updateFields, data, row)
return (await this.database.get(table as never, pick(row, primaryFields as any) as any))[0]
}
}

_create(table: string, data: {}) {
const model = this.model(table)
data = this.sql.dump(data, model)
Expand Down
97 changes: 97 additions & 0 deletions packages/tests/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,67 @@ namespace ModelOperations {
await expect(database.eval('dtypes', row => $.array(row.text2))).to.eventually.contain('foo')
await expect(database.get('dtypes', row => $.eq(row.text2, $.literal('foo', 'string2')))).to.eventually.have.length(1)
})

it('setOne plain value', async () => {
const table = await setup(database, 'dtypes', dtypeTable)
table[0].text = 'updated'
const result = await database.setOne('dtypes', { id: 1 }, { text: 'updated' })
expect(result).to.not.be.undefined
expect(result!.text).to.equal('updated')
expect(result!.id).to.equal(table[0].id)
})

it('setOne dot notation + custom types', async () => {
const table = await setup(database, 'dtypes', dtypeTable)
const result = await database.setOne('dtypes', { id: 4 }, { 'object.embed.bool': false, 'object.embed.bigint': 999n })
expect(result).to.not.be.undefined
expect(result!.object!.embed!.bool).to.equal(false)
expect(result!.object!.embed!.bigint).to.equal(999n)
})

it('setOne expression update', async () => {
const table = await setup(database, 'dtypes', dtypeTable)
const row = table.find(r => r.id === 3)!
row.num! += 1
row.object!.json!.num = (row.object!.json!.num ?? 0) + 100
await expect(database.setOne('dtypes', 3, row => ({
num: $.add(row.num, 1),
'object.json.num': $.add($.ifNull(row.object.json.num, 0), 100),
}))).to.eventually.have.shape(row)
})

it('setOne subquery update', async () => {
const table = await setup(database, 'dtypes', dtypeTable)
const row = table.find(r => r.id === 2)!
row.text = row.text ? row.text : ''
await expect(database.setOne('dtypes', 2, row => ({
text: database.select('dtypes', r => $.eq(r.id, 1)).evaluate(r => $.max(r.text)),
}))).to.eventually.have.property('text')
})

it('setOne bigint + binary + bnum custom type', async () => {
const table = await setup(database, 'dtypes', dtypeTable)
table[10].bigint = 888n
table[10].bnum = 999
table[10].binary = toBinary('world')
await expect(database.setOne('dtypes', 11, {
bigint: 888n,
bnum: 999,
binary: toBinary('world'),
})).to.eventually.have.shape(table[10])
})

it('setOne no match returns undefined', async () => {
await setup(database, 'dtypes', dtypeTable)
await expect(database.setOne('dtypes', 99999, { text: 'nope' })).to.eventually.be.undefined
})

it('setOne preserves all field types', async () => {
const table = await setup(database, 'dtypes', dtypeTable)
const row = table.find(r => r.id === 5)!
row.text = 'all-types-ok'
await expect(database.setOne('dtypes', { id: 5 }, { text: 'all-types-ok' })).to.eventually.have.shape(row)
})
}

export const object = function ObjectFields(database: Database, options: ModelOptions = {}) {
Expand Down Expand Up @@ -674,6 +735,42 @@ namespace ModelOperations {
await expect(database.get('dobjects', row => $.eq(row.baz[0].nested.array[0], 1))).to.eventually.have.length(2)
})

it('setOne nested dot notation', async () => {
const table = await setup(database, 'dobjects', dobjectTable)
table[1].foo!.nested!.timestamp = new Date('2009/10/01 15:40:00')
table[1].foo!.nested!.binary = toBinary('boom')
await expect(database.setOne('dobjects', 2, {
'foo.nested.timestamp': new Date('2009/10/01 15:40:00'),
'foo.nested.binary': toBinary('boom'),
})).to.eventually.have.shape(table[1])
})

it('setOne nested object replace', async () => {
const table = await setup(database, 'dobjects', dobjectTable)
table[0].foo = { nested: { id: 1 } }
await expect(database.setOne('dobjects', 1, {
'foo.nested': { id: 1 },
})).to.eventually.have.nested.property('foo.nested.id', 1)
})

it('setOne expression on nested field', async () => {
const table = await setup(database, 'dobjects', dobjectTable)
const row = table.find(r => r.bar?.nested?.id !== undefined)!
row.bar!.nested!.text = 'expr-updated'
const result = await database.setOne('dobjects', row.id, {
'bar.nested.text': $.concat(row.bar!.nested!.text ?? '', ''),
})
expect(result).to.not.be.undefined
expect(result!.bar!.nested!.text).to.equal(row.bar!.nested!.text)
})

it('setOne recursive type', async () => {
const table = await setup(database, 'recurxs', [{ id: 1, y: { id: 2, x: { id: 3, y: { id: 4, x: { id: 5 } } } } }])
const row = table[0]
row.y!.id = 999
await expect(database.setOne('recurxs', 1, { y: { id: 999, x: { id: 3, y: { id: 4, x: { id: 5 } } } } })).to.eventually.have.shape(row)
})

nullableComparator && it('decode uuid', async () => {
await setup(database, 'dobjects', dobjectTable)
await expect(database.get('dobjects', row => $.eq(row.foo!.nested!.uuid!, $.literal(u1, 'uuid')))).to.eventually.have.length(1)
Expand Down
44 changes: 44 additions & 0 deletions packages/tests/src/update.ts
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,50 @@ namespace OrmOperations {
})
}

export const setOne = function SetOne(database: Database) {
it('basic support', async () => {
const table = await setup(database, 'temp2', barTable)
const target = table.find(bar => bar.id === 2)!
target.text = 'updated'
await expect(database.setOne('temp2', { id: 2 }, { text: 'updated' })).to.eventually.have.shape(target)
target.text = 'updated2'
await expect(database.setOne('temp2', row => $.eq(row.id, 2), { text: 'updated2' })).to.eventually.have.shape(target)
})

it('no match returns undefined', async () => {
await setup(database, 'temp2', barTable)
await expect(database.setOne('temp2', { id: 9999 }, { text: 'nope' })).to.eventually.be.undefined
})

it('using expressions', async () => {
const table = await setup(database, 'temp2', barTable)
const target = table.find(bar => bar.id === 3)!
target.num = target.num! + 1
await expect(database.setOne('temp2', { id: 3 }, row => ({
num: $.add(row.num, 1),
}))).to.eventually.have.shape(target)
})

it('returns updated value not old', async () => {
await setup(database, 'temp2', barTable)
await expect(database.setOne('temp2', { id: 1 }, { text: 'new-value' })).to.eventually.have.property('text', 'new-value')
})

it('noop update returns existing row', async () => {
const table = await setup(database, 'temp2', barTable)
const target = table.find(bar => bar.id === 1)!
await expect(database.setOne('temp2', { id: 1 }, {})).to.eventually.have.shape(target)
})

it('advanced type', async () => {
const table = await setup(database, 'temp2', barTable)
const target = table.find(bar => bar.id === 1)!
target.binary = toBinary('world')
target.bigint = 1234567891011121314151617181920n
await expect(database.setOne('temp2', { id: 1 }, { binary: toBinary('world'), bigint: 1234567891011121314151617181920n })).to.eventually.have.shape(target)
})
}

export const upsert = function Upsert(database: Database) {
it('update existing records', async () => {
const table = await setup(database, 'temp2', barTable)
Expand Down