Merged
pramen/core/src/main/scala/za/co/absa/pramen/core/reader/TableReaderJdbc.scala
@@ -162,18 +162,18 @@ class TableReaderJdbc(jdbcReaderConfig: TableReaderJdbcConfig,
.load()

if (jdbcReaderConfig.correctDecimalsInSchema || jdbcReaderConfig.correctDecimalsFixPrecision) {
if (isDataQuery) {
df = SparkUtils.sanitizeDfColumns(df, jdbcReaderConfig.specialCharacters)
}

JdbcSparkUtils.getCorrectedDecimalsSchema(df, jdbcReaderConfig.correctDecimalsFixPrecision).foreach(schema =>
JdbcSparkUtils.getCorrectedDecimalsSchema(df, jdbcReaderConfig.correctDecimalsFixPrecision).foreach { schema =>
df = spark
.read
.format("jdbc")
.options(connectionOptions)
.option("customSchema", schema)
.load()
)
}

if (isDataQuery) {
df = SparkUtils.sanitizeDfColumns(df, jdbcReaderConfig.specialCharacters)
}
Comment on lines 164 to +176
coderabbitai[bot] (Contributor)
⚠️ Potential issue | 🔴 Critical

Move sanitization out of the decimal-correction branch.

sanitizeDfColumns() now runs only when one of the decimal-correction flags is enabled. With both flags off, data queries stop sanitizing entirely, so the duplicate-safe renames and original_name metadata from this PR never get applied.

💡 Proposed fix
     if (jdbcReaderConfig.correctDecimalsInSchema || jdbcReaderConfig.correctDecimalsFixPrecision) {
       JdbcSparkUtils.getCorrectedDecimalsSchema(df, jdbcReaderConfig.correctDecimalsFixPrecision).foreach { schema =>
         df = spark
           .read
           .format("jdbc")
           .options(connectionOptions)
           .option("customSchema", schema)
           .load()
       }
-
-      if (isDataQuery) {
-        df = SparkUtils.sanitizeDfColumns(df, jdbcReaderConfig.specialCharacters)
-      }
+    }
+
+    if (isDataQuery) {
+      df = SparkUtils.sanitizeDfColumns(df, jdbcReaderConfig.specialCharacters)
     }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@pramen/core/src/main/scala/za/co/absa/pramen/core/reader/TableReaderJdbc.scala`
around lines 164-176, the sanitization call (SparkUtils.sanitizeDfColumns) is
currently gated by the decimal-correction if-block so it only runs when
jdbcReaderConfig.correctDecimalsInSchema || correctDecimalsFixPrecision is true;
move the sanitization so it always runs for data queries. Specifically, keep the
existing decimal-correction logic using
JdbcSparkUtils.getCorrectedDecimalsSchema and the reload into df, but remove the
SparkUtils.sanitizeDfColumns call from inside that if and instead call
SparkUtils.sanitizeDfColumns(df, jdbcReaderConfig.specialCharacters) once after
the decimal-correction block when isDataQuery is true (so sanitize runs
regardless of the decimal flags).
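To make the intended flow concrete, here is a minimal sketch (hypothetical method and parameter names, not code from this PR; SparkUtils' package assumed from the surrounding imports) of the structure the review asks for: decimal correction stays behind its flags, while sanitization runs for every data query.

import org.apache.spark.sql.DataFrame
import za.co.absa.pramen.core.utils.SparkUtils

// Hypothetical simplification of TableReaderJdbc's post-load logic;
// the decimal-schema reload is elided, only the gating structure matters.
def postProcess(initial: DataFrame,
                correctDecimals: Boolean,
                isDataQuery: Boolean,
                specialCharacters: String): DataFrame = {
  var df = initial
  if (correctDecimals) {
    // ... reload df with the corrected decimal schema (omitted) ...
  }
  // Independent of the decimal flags: always sanitize data queries.
  if (isDataQuery) {
    df = SparkUtils.sanitizeDfColumns(df, specialCharacters)
  }
  df
}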

}

if (jdbcReaderConfig.saveTimestampsAsDates) {
SparkUtils.scala
@@ -32,15 +32,21 @@ import za.co.absa.pramen.core.utils.SparkMaster.Databricks
import java.io.ByteArrayOutputStream
import java.time.format.DateTimeFormatter
import java.time.{Instant, LocalDate}
import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.reflect.runtime.universe._
import scala.util.control.NonFatal
import scala.util.{Failure, Success, Try}

object SparkUtils {
private val log = LoggerFactory.getLogger(this.getClass)
private val charVarcharTypePattern = """(?i)\s*(char|varchar)\((\d+)\)\s*""".r
private val charVarcharLengthPattern = """(?:char|varchar)\((\d+)\)""".r

val MAX_LENGTH_METADATA_KEY = "maxLength"
val CHAR_VARCHAR_METADATA_KEY = "__CHAR_VARCHAR_TYPE_STRING"
val COMMENT_METADATA_KEY = "comment"
val ORIGINAL_NAME_METADATA_KEY = "original_name"

// This seems to be limitation for multiple catalogs, like Glue and Hive.
val MAX_COMMENT_LENGTH = 255
@@ -103,6 +109,9 @@ object SparkUtils {
* @param characters A set of characters considered special
*/
def sanitizeDfColumns(df: DataFrame, characters: String): DataFrame = {
val namesLowercase = new mutable.HashSet[String]
namesLowercase ++= df.schema.fields.map(_.name.toLowerCase)

def replaceSpecialChars(s: String): String = {
s.map(c => if (characters.contains(c)) '_' else c)
}
@@ -121,14 +130,39 @@
}.distinct.length == 1
}

def getUniqueName(baseName: String): String = {
var uniqueName = baseName
var counter = 1
while (namesLowercase.contains(uniqueName.toLowerCase)) {
uniqueName = s"${baseName}_$counter"
counter += 1
}
uniqueName
}

val hasTablePrefix = hasUniformTablePrefix(df.schema.fields)

val fieldsToSelect = df.schema.fields.map(field => {
val srcName = field.name
val trgName = replaceSpecialChars(if(hasTablePrefix) removeTablePrefix(srcName.trim) else srcName.trim)
val trgName = replaceSpecialChars(if (hasTablePrefix) removeTablePrefix(srcName.trim) else srcName.trim)
if (srcName != trgName) {
log.info(s"Renamed column: '$srcName' -> '$trgName''")
col(s"`$srcName`").as(trgName)
val uniqueName = if (namesLowercase.contains(trgName.toLowerCase)) {
val newName = getUniqueName(trgName)
namesLowercase.remove(srcName.toLowerCase)
namesLowercase.add(newName.toLowerCase)
newName
} else {
namesLowercase.remove(srcName.toLowerCase)
namesLowercase.add(trgName.toLowerCase)
trgName
}

log.info(s"Renamed column: '$srcName' -> '$uniqueName'")
val newMetadata = new MetadataBuilder()
.withMetadata(field.metadata)
.putString(ORIGINAL_NAME_METADATA_KEY, srcName)
.build()
col(s"`$srcName`").as(uniqueName, newMetadata)
} else {
col(s"`$srcName`")
}
@@ -325,10 +359,7 @@ object SparkUtils {
def transformPrimitive(dataType: DataType, field: StructField): DataType = {
dataType match {
case _: StringType =>
getLengthFromMetadata(field.metadata) match {
case Some(n) => VarcharType(n)
case None => StringType
}
getStringTypeFromMetadata(field.metadata)
case _ =>
dataType
}
@@ -371,11 +402,46 @@ object SparkUtils {
try1
}
try2.getOrElse(None)
} else if (metadata.contains(CHAR_VARCHAR_METADATA_KEY)) {
val typeString = metadata.getString(CHAR_VARCHAR_METADATA_KEY).toLowerCase
try {
charVarcharLengthPattern.findFirstMatchIn(typeString).map(_.group(1).toInt)
} catch {
case NonFatal(_) => None
}
} else {
None
}
}
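As a quick illustration of the new fallback (metadata values chosen arbitrarily; SparkUtils' package assumed): when maxLength is absent, getLengthFromMetadata now parses the length out of the char/varchar type string.

import org.apache.spark.sql.types.MetadataBuilder
import za.co.absa.pramen.core.utils.SparkUtils

val meta = new MetadataBuilder()
  .putString(SparkUtils.CHAR_VARCHAR_METADATA_KEY, "varchar(11)")
  .build()

// No MAX_LENGTH_METADATA_KEY present, so the length comes from "varchar(11)".
assert(SparkUtils.getLengthFromMetadata(meta).contains(11))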

/**
* Extracts a string-based data type (CharType, VarcharType, or StringType) from field metadata.
*
* First checks for the presence of the char/varchar type metadata key. If found, parses the metadata
* value to determine whether it represents a CHAR or VARCHAR type with a specific length.
* If the metadata key is not present, attempts to extract a length value from the metadata and creates
* a VarcharType if successful. Falls back to StringType if no specific type information is found
* or if the metadata cannot be parsed.
*
* @param metadata the metadata object containing type information for a field
* @return the resolved string-based DataType, which can be CharType, VarcharType, or StringType
*/
def getStringTypeFromMetadata(metadata: Metadata): DataType = {
if (metadata.contains(CHAR_VARCHAR_METADATA_KEY)) {
metadata.getString(CHAR_VARCHAR_METADATA_KEY) match {
case charVarcharTypePattern(kind, len) if kind.equalsIgnoreCase("char") =>
CharType(len.toInt)
case charVarcharTypePattern(_, len) =>
VarcharType(len.toInt)
case _ =>
getLengthFromMetadata(metadata).map(VarcharType.apply).getOrElse(StringType)
}
} else {
getLengthFromMetadata(metadata).map(VarcharType.apply).getOrElse(StringType)
}
}
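A small usage sketch of the resolution order this method implements (package path assumed from the surrounding code): the char/varchar type string wins, then maxLength, then plain StringType.

import org.apache.spark.sql.types._
import za.co.absa.pramen.core.utils.SparkUtils

val charMeta = new MetadataBuilder()
  .putString(SparkUtils.CHAR_VARCHAR_METADATA_KEY, "char(12)")
  .build()
val lengthMeta = new MetadataBuilder()
  .putLong(SparkUtils.MAX_LENGTH_METADATA_KEY, 10L)
  .build()

assert(SparkUtils.getStringTypeFromMetadata(charMeta) == CharType(12))      // explicit char(n)
assert(SparkUtils.getStringTypeFromMetadata(lengthMeta) == VarcharType(10)) // falls back to maxLength
assert(SparkUtils.getStringTypeFromMetadata(Metadata.empty) == StringType)  // no type info at all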


/**
* Sanitizes a comment for Hive DDL. Ideally this should be done by Spark, but because there are many versions
* of Hive and other catalogs, it is sometimes hard to have a general solution.
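Before the test changes below, a short usage sketch of the new collision handling (assumes an active SparkSession named spark and its implicits): colliding sanitized names receive numeric suffixes, and every renamed column records its source name under the original_name metadata key.

import spark.implicits._
import za.co.absa.pramen.core.utils.SparkUtils

// "a:a" and "a a" both sanitize to "a_a"; the second occurrence becomes "a_a_1".
val df = List(("A", 1)).toDF("a:a", "a a")
val sanitized = SparkUtils.sanitizeDfColumns(df, " :")

sanitized.schema.fields.foreach { f =>
  val original = f.metadata.getString(SparkUtils.ORIGINAL_NAME_METADATA_KEY)
  println(s"${f.name} <- $original")  // prints: a_a <- a:a, then a_a_1 <- a a
}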
SparkUtilsSuite.scala
@@ -152,6 +152,38 @@ class SparkUtilsSuite extends AnyWordSpec with SparkTestBase with TempDirFixture
assert(stripLineEndings(actual) == stripLineEndings(expected))
}

"add suffixes when there is a name collision" in {
val expected =
"""[ { "a_a" : "A", "a_a_1" : 1, "b_b_1" : 1, "b_b" : 1, "c_c" : 1, "c_c_1" : 1, "c_c_2" : 1}, { "a_a" : "B", "a_a_1" : 2, "b_b_1" : 2, "b_b" : 2, "c_c" : 2, "c_c_1" : 2, "c_c_2" : 2}, { "a_a" : "C", "a_a_1" : 3, "b_b_1" : 3, "b_b" : 3, "c_c" : 3, "c_c_1" : 3, "c_c_2" : 3} ]"""

val expectedMetadataField1 =
"""{"original_name":"a:a","maxLength":10}"""

val expectedMetadataField2 =
"""{"original_name":"a a"}"""

val df = List(("A", 1, 1, 1, 1, 1, 1), ("B", 2, 2, 2, 2, 2, 2), ("C", 3, 3, 3, 3, 3, 3)).toDF("a:a", "a a", "b<b", "b_b", "c_c", "c c", "c:c")

val field1 = df.schema.fields.head
val field1WithMetadata = field1.copy(metadata = new MetadataBuilder().putLong("maxLength", 10).build())
val newSchema = StructType(field1WithMetadata +: df.schema.fields.drop(1))
val df1 = spark.createDataFrame(df.rdd, newSchema)

val actualDf = sanitizeDfColumns(df1, " :<>")

val actual = convertDataFrameToPrettyJSON(actualDf).stripMargin.linesIterator.mkString("").trim

assert(stripLineEndings(actual) == stripLineEndings(expected))

val metadataField1 = actualDf.schema.fields.head.metadata.json
val metadataField2 = actualDf.schema.fields(1).metadata.json

assert(metadataField1 == expectedMetadataField1)
assert(metadataField2 == expectedMetadataField2)

}

"rename columns that start with .tbl" in {
val expected =
"""[ { "a_a" : "A", "b" : 1}, { "a_a" : "B", "b" : 2}, { "a_a" : "C", "b" : 3} ]"""
@@ -178,6 +210,47 @@ class SparkUtilsSuite extends AnyWordSpec with SparkTestBase with TempDirFixture
assert(stripLineEndings(actual) == stripLineEndings(expected))
}

"retain metadata and add the original_name metadata key" in {
val expected =
"""[ { "a_a" : "A", "b" : 1}, { "a_a" : "B", "b" : 2}, { "a_a" : "C", "b" : 3} ]"""
val expectedMetadataField1 =
"""{"original_name":"tbl.a a","maxLength":10}"""

val expectedMetadataField2 =
"""{"original_name":"tbl.b","comment":"Test"}"""

val df = List(("A", 1), ("B", 2), ("C", 3)).toDF("tbl.a a", "tbl.b")
val field1 = df.schema.fields.head
val field2 = df.schema.fields(1)
val field1WithMetadata = field1.copy(metadata = new MetadataBuilder().putLong("maxLength", 10).build())
val field2WithMetadata = field2.copy(metadata = new MetadataBuilder().putString("comment", "Test").build())
val newSchema = StructType(Seq(field1WithMetadata, field2WithMetadata))
val df1 = spark.createDataFrame(df.rdd, newSchema)

val actualDf = sanitizeDfColumns(df1, " ")
val metadataField1 = actualDf.schema.fields.head.metadata.json
val metadataField2 = actualDf.schema.fields(1).metadata.json

val actual = convertDataFrameToPrettyJSON(actualDf).stripMargin.linesIterator.mkString("").trim

assert(stripLineEndings(actual) == stripLineEndings(expected))
assert(metadataField1 == expectedMetadataField1)
assert(metadataField2 == expectedMetadataField2)
}

"handle case-insensitive name collisions" in {
val df = List(("A", 1, 2)).toDF("Test_Column", "test column", "TEST:COLUMN")

val actualDf = sanitizeDfColumns(df, " :")

// All three should produce unique names despite case differences
val colNames = actualDf.schema.fields.map(_.name)
assert(colNames.distinct.length == 3, "All columns should have unique names")
assert(colNames.head == "Test_Column")
assert(colNames(1) == "test_column_1")
assert(colNames(2) == "TEST_COLUMN_2")
}

"convert schema from Spark to Json and back should produce the same schema" in {
val testCaseSchema = StructType(
Array(
@@ -557,6 +630,24 @@ class SparkUtilsSuite extends AnyWordSpec with SparkTestBase with TempDirFixture
assert(len.contains(10))
}

"return length for string type 2" in {
val metadata = new MetadataBuilder
metadata.putString(CHAR_VARCHAR_METADATA_KEY, "varchar(11)")

val len = SparkUtils.getLengthFromMetadata(metadata.build())

assert(len.contains(11))
}

"return length for string type 3" in {
val metadata = new MetadataBuilder
metadata.putString(CHAR_VARCHAR_METADATA_KEY, "CHAR(12)")

val len = SparkUtils.getLengthFromMetadata(metadata.build())

assert(len.contains(12))
}

"return None for wrong type" in {
val metadata = new MetadataBuilder
metadata.putString(MAX_LENGTH_METADATA_KEY, "abc")
@@ -584,6 +675,108 @@ class SparkUtilsSuite extends AnyWordSpec with SparkTestBase with TempDirFixture
}
}

"getStringTypeFromMetadata" should {
"return varchar type for long metadata type" in {
val metadata = new MetadataBuilder
metadata.putLong(MAX_LENGTH_METADATA_KEY, 10L)

val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build())

assert(stringType == VarcharType(10))
}

"return varchar type for string metadata type" in {
val metadata = new MetadataBuilder
metadata.putString(MAX_LENGTH_METADATA_KEY, "10")

val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build())

assert(stringType == VarcharType(10))
}

"return varchar type from CHAR_VARCHAR_METADATA_KEY with varchar" in {
val metadata = new MetadataBuilder
metadata.putString(CHAR_VARCHAR_METADATA_KEY, "varchar(11)")

val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build())

assert(stringType == VarcharType(11))
}

"return char type from CHAR_VARCHAR_METADATA_KEY with char" in {
val metadata = new MetadataBuilder
metadata.putString(CHAR_VARCHAR_METADATA_KEY, "CHAR(12)")

val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build())

assert(stringType == CharType(12))
}

"return varchar type and ignore case in CHAR_VARCHAR_METADATA_KEY" in {
val metadata = new MetadataBuilder
metadata.putString(CHAR_VARCHAR_METADATA_KEY, "VARCHAR(15)")

val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build())

assert(stringType == VarcharType(15))
}

"return None for wrong type in MAX_LENGTH_METADATA_KEY" in {
val metadata = new MetadataBuilder
metadata.putString(MAX_LENGTH_METADATA_KEY, "abc")

val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build())

assert(stringType == StringType)
}

"return None for double type in MAX_LENGTH_METADATA_KEY" in {
val metadata = new MetadataBuilder
metadata.putDouble(MAX_LENGTH_METADATA_KEY, 12.25)

val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build())

assert(stringType == StringType)
}

"return None if metadata not specified" in {
val metadata = new MetadataBuilder

val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build())

assert(stringType == StringType)
}

"prioritize CHAR_VARCHAR_METADATA_KEY over MAX_LENGTH_METADATA_KEY" in {
val metadata = new MetadataBuilder
metadata.putLong(MAX_LENGTH_METADATA_KEY, 10L)
metadata.putString(CHAR_VARCHAR_METADATA_KEY, "varchar(20)")

val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build())

assert(stringType == VarcharType(20))
}

"handle invalid CHAR_VARCHAR_METADATA_KEY" in {
val metadata = new MetadataBuilder
metadata.putString(CHAR_VARCHAR_METADATA_KEY, "varchar(10.1)")

val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build())

assert(stringType == StringType)
}

"handle malformed CHAR_VARCHAR_METADATA_KEY and fallback to MAX_LENGTH_METADATA_KEY" in {
val metadata = new MetadataBuilder
metadata.putLong(MAX_LENGTH_METADATA_KEY, 10L)
metadata.putString(CHAR_VARCHAR_METADATA_KEY, "invalid_format")

val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build())

assert(stringType == VarcharType(10))
}
}

"removeNestedMetadata" should {
"remove metadata, but only from nested fields" in {
val metadata1 = new MetadataBuilder().putLong("maxLength", 5).putString("comment", "Employee name").build()