Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,26 @@ fmt.Println(ub)
// UPDATE users SET level = level + ? WHERE id = ?
```

### Build `UPDATE ... FROM`

`UpdateBuilder.From` emits a `FROM` clause for PostgreSQL, SQLite, and SQLServer flavors (it is ignored by other flavors). When a CTE includes tables created with `CTETable`, those table names are emitted before any explicit `From(...)` tables.

```go
ub := PostgreSQL.NewUpdateBuilder()
ub.Update("users")
ub.Set(ub.Assign("name", "Huan Du"))
ub.From("people")
ub.Where("users.person_id = people.id")

sql, args := ub.Build()
fmt.Println(sql)
fmt.Println(args)

// Output:
// UPDATE users SET name = $1 FROM people WHERE users.person_id = people.id
// [Huan Du]
```

Refer to the [WhereClause](https://pkg.go.dev/github.com/huandu/go-sqlbuilder#WhereClause) examples to learn its usage.

### Build `ORDER BY` clause
Expand Down
25 changes: 24 additions & 1 deletion update.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const (
updateMarkerAfterWith
updateMarkerAfterUpdate
updateMarkerAfterSet
updateMarkerAfterFrom
updateMarkerAfterWhere
updateMarkerAfterOrderBy
updateMarkerAfterLimit
Expand Down Expand Up @@ -71,6 +72,7 @@ type UpdateBuilder struct {
cteBuilder *CTEBuilder

tables []string
fromTables []string
assignments []string
orderByCols []string
order string
Expand Down Expand Up @@ -140,6 +142,13 @@ func (ub *UpdateBuilder) SetMore(assignment ...string) *UpdateBuilder {
return ub
}

// From sets table names of FROM in UPDATE.
func (ub *UpdateBuilder) From(table ...string) *UpdateBuilder {
ub.fromTables = table
ub.marker = updateMarkerAfterFrom
return ub
}

// Where adds expressions to the WHERE clause in UPDATE.
//
// Multiple calls to Where will join expressions with AND.
Expand Down Expand Up @@ -339,7 +348,7 @@ func (ub *UpdateBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{
buf.WriteStringsPrefixed("INSERTED.", ub.returning, ", ")
}

ub.injection.WriteTo(buf, insertMarkerAfterReturning)
ub.injection.WriteTo(buf, updateMarkerAfterReturning)
}

if flavor != MySQL {
Expand All @@ -354,6 +363,20 @@ func (ub *UpdateBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{
}
}

if flavor == PostgreSQL || flavor == SQLite || flavor == SQLServer {
if len(ub.fromTables) > 0 {

if ub.cteBuilder == nil || len(ub.cteBuilder.tableNamesForFrom()) == 0 {
buf.WriteLeadingString("FROM ")
} else {
buf.WriteString(", ")
}

buf.WriteStrings(ub.fromTables, ", ")
ub.injection.WriteTo(buf, updateMarkerAfterFrom)
}
}

if ub.WhereClause != nil {
ub.whereClauseProxy.WhereClause = ub.WhereClause
defer func() {
Expand Down
92 changes: 92 additions & 0 deletions update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,95 @@ func TestUpdateBuilderClone(t *testing.T) {
clone.Asc().Limit(5)
a.NotEqual(ub.String(), clone.String())
}

func TestUpdateBuilderFrom(t *testing.T) {
a := assert.New(t)
ub := NewUpdateBuilder()
ub.Update("user")
ub.Set(ub.Assign("name", "Huan Du"))
ub.From("person")
ub.Where(ub.Equal("id", 123))

sql, _ := ub.BuildWithFlavor(MySQL)
a.Equal("UPDATE user SET name = ? WHERE id = ?", sql)

sql, _ = ub.BuildWithFlavor(PostgreSQL)
a.Equal("UPDATE user SET name = $1 FROM person WHERE id = $2", sql)

sql, _ = ub.BuildWithFlavor(SQLite)
a.Equal("UPDATE user SET name = ? FROM person WHERE id = ?", sql)

sql, _ = ub.BuildWithFlavor(SQLServer)
a.Equal("UPDATE user SET name = @p1 FROM person WHERE id = @p2", sql)

sql, _ = ub.BuildWithFlavor(CQL)
a.Equal("UPDATE user SET name = ? WHERE id = ?", sql)

sql, _ = ub.BuildWithFlavor(ClickHouse)
a.Equal("UPDATE user SET name = ? WHERE id = ?", sql)

sql, _ = ub.BuildWithFlavor(Presto)
a.Equal("UPDATE user SET name = ? WHERE id = ?", sql)

// Test with no from
ub2 := NewUpdateBuilder()
ub2.Update("user")
ub2.Set(ub2.Assign("name", "Test"))
ub2.From()
ub2.Where(ub2.Equal("id", 1))

sql, _ = ub2.BuildWithFlavor(PostgreSQL)
a.Equal("UPDATE user SET name = $1 WHERE id = $2", sql)

// Test with multiple from tables
ub3 := NewUpdateBuilder()
ub3.Update("user")
ub3.Set(ub3.Assign("name", "Test"))
ub3.From("person", "company")
ub3.Where(ub3.Equal("id", 1))

sql, _ = ub3.BuildWithFlavor(PostgreSQL)
a.Equal("UPDATE user SET name = $1 FROM person, company WHERE id = $2", sql)

// Test chaining
ub5 := NewUpdateBuilder().Update("user").Set("status = 1").From("person").From("company")
sql, _ = ub5.BuildWithFlavor(PostgreSQL)
a.Equal("UPDATE user SET status = 1 FROM company", sql) // Last From call overwrites

// Test SQL injection after FROM
ub6 := NewUpdateBuilder()
ub6.Update("user")
ub6.Set(ub6.Assign("name", "Test"))
ub6.From("person")
ub6.SQL("/* comment after from */")
ub6.Where(ub6.Equal("id", 1))

sql, _ = ub6.BuildWithFlavor(PostgreSQL)
a.Equal("UPDATE user SET name = $1 FROM person /* comment after from */ WHERE id = $2", sql)

// Test with CTE (WITH clause)
cte := With(CTETable("temp_user").As(Select("id").From("active_users")))
ub7 := cte.Update("user")
ub7.Set(ub7.Assign("status", "active"))
ub7.From("person")
ub7.Where("user.id IN (SELECT id FROM temp_user)")

sql, _ = ub7.BuildWithFlavor(PostgreSQL)
a.Equal("WITH temp_user AS (SELECT id FROM active_users) UPDATE user SET status = $1 FROM temp_user, person WHERE user.id IN (SELECT id FROM temp_user)", sql)

// Test with SQLServer Returning
ub8 := ub.Clone().Returning("id", "name")
sql, _ = ub8.BuildWithFlavor(SQLServer)
a.Equal("UPDATE user SET name = @p1 OUTPUT INSERTED.id, INSERTED.name FROM person WHERE id = @p2", sql)

// Test with SQL injection after WHERE
ub9 := NewUpdateBuilder()
ub9.Update("user")
ub9.Set(ub9.Assign("name", "Test"))
ub9.From("person")
ub9.Where("user.id = person.id")
ub9.SQL("/* comment after where */")

sql, _ = ub9.BuildWithFlavor(PostgreSQL)
a.Equal("UPDATE user SET name = $1 FROM person WHERE user.id = person.id /* comment after where */", sql)
}