Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ ktlint_code_style = intellij_idea
# Allow trailing commas in function parameters and arguments
ij_kotlin_allow_trailing_comma = true
ij_kotlin_allow_trailing_comma_on_call_site = true
# Disable class-signature rule (added in ktlint 14). The default forces multi-line
# class headers when the supertype call has a body (e.g. `class T : StringSpec({...})`).
# That would re-indent every Kotest spec across 300+ lines for pure cosmetics —
# not worth it. Existing `max_line_length = 120` already keeps headers readable.
ktlint_standard_class-signature = disabled

[*.md]
trim_trailing_whitespace = false
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ plugins {
kotlin("jvm") version "2.3.21" apply false
kotlin("plugin.serialization") version "2.3.21" apply false
id("org.jetbrains.kotlinx.kover") version "0.9.8" apply false
id("org.jlleitschuh.gradle.ktlint") version "12.1.2" apply false
id("org.jlleitschuh.gradle.ktlint") version "14.2.0" apply false
id("io.gitlab.arturbosch.detekt") version "1.23.8" apply false
}

Expand Down
32 changes: 15 additions & 17 deletions mosaic-cli/src/main/kotlin/dev/mosaic/cli/CreateCommand.kt
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,22 @@ internal object CreateCommand {
return 0
}

private fun buildInitializer(name: String, dim: Int, seed: Long, parsed: Args): Initializer {
return when (name) {
"uniform" -> {
val bound = parsed.optionalFloat("--bound")
if (bound != null) Initializer.uniform(bound, seed) else Initializer.uniformDefault(seed)
}
"xavier" -> Initializer.xavier(fanIn = dim, fanOut = dim, seed = seed)
"he" -> Initializer.he(fanIn = dim, seed = seed)
"zeros" -> Initializer.zeros()
"constant" -> {
val value = parsed.optionalFloat("--value")
?: throw UsageError("Initializer 'constant' requires --value")
Initializer.constant(value)
}
else -> throw UsageError(
"Unknown initializer '$name'. Valid: uniform, xavier, he, zeros, constant",
)
private fun buildInitializer(name: String, dim: Int, seed: Long, parsed: Args): Initializer = when (name) {
"uniform" -> {
val bound = parsed.optionalFloat("--bound")
if (bound != null) Initializer.uniform(bound, seed) else Initializer.uniformDefault(seed)
}
"xavier" -> Initializer.xavier(fanIn = dim, fanOut = dim, seed = seed)
"he" -> Initializer.he(fanIn = dim, seed = seed)
"zeros" -> Initializer.zeros()
"constant" -> {
val value = parsed.optionalFloat("--value")
?: throw UsageError("Initializer 'constant' requires --value")
Initializer.constant(value)
}
else -> throw UsageError(
"Unknown initializer '$name'. Valid: uniform, xavier, he, zeros, constant",
)
}

private fun help(): String = """
Expand Down
4 changes: 1 addition & 3 deletions mosaic-core/src/main/kotlin/dev/mosaic/EmbeddingTable.kt
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ public class EmbeddingTable internal constructor(
}

/** Returns copies of the rows at the given [ids], in order. */
public fun get(ids: IntArray): Array<FloatArray> {
return Array(ids.size) { idx -> get(ids[idx]) }
}
public fun get(ids: IntArray): Array<FloatArray> = Array(ids.size) { idx -> get(ids[idx]) }

/** Writes [vector] into the row at [id]. The source array is copied; the caller may mutate it freely afterwards. */
public fun set(id: Int, vector: FloatArray) {
Expand Down
12 changes: 5 additions & 7 deletions mosaic-core/src/test/kotlin/dev/mosaic/EmbeddingTableTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,8 @@ class EmbeddingTableTest : StringSpec({
}
})

private fun embeddingsWithSeed(vocabSize: Int, embeddingDim: Int, seed: Long): EmbeddingTable {
return EmbeddingTable.create(
vocabSize = vocabSize,
embeddingDim = embeddingDim,
initializer = Initializer.uniform(bound = 1f, seed = seed),
)
}
private fun embeddingsWithSeed(vocabSize: Int, embeddingDim: Int, seed: Long): EmbeddingTable = EmbeddingTable.create(
vocabSize = vocabSize,
embeddingDim = embeddingDim,
initializer = Initializer.uniform(bound = 1f, seed = seed),
)