diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/reader/TableReaderJdbc.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/reader/TableReaderJdbc.scala index 1eafde63..b57af0fd 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/reader/TableReaderJdbc.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/reader/TableReaderJdbc.scala @@ -162,18 +162,18 @@ class TableReaderJdbc(jdbcReaderConfig: TableReaderJdbcConfig, .load() if (jdbcReaderConfig.correctDecimalsInSchema || jdbcReaderConfig.correctDecimalsFixPrecision) { - if (isDataQuery) { - df = SparkUtils.sanitizeDfColumns(df, jdbcReaderConfig.specialCharacters) - } - - JdbcSparkUtils.getCorrectedDecimalsSchema(df, jdbcReaderConfig.correctDecimalsFixPrecision).foreach(schema => + JdbcSparkUtils.getCorrectedDecimalsSchema(df, jdbcReaderConfig.correctDecimalsFixPrecision).foreach { schema => df = spark .read .format("jdbc") .options(connectionOptions) .option("customSchema", schema) .load() - ) + } + + if (isDataQuery) { + df = SparkUtils.sanitizeDfColumns(df, jdbcReaderConfig.specialCharacters) + } } if (jdbcReaderConfig.saveTimestampsAsDates) { diff --git a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/SparkUtils.scala b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/SparkUtils.scala index 99f3a45d..de5b3e21 100644 --- a/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/SparkUtils.scala +++ b/pramen/core/src/main/scala/za/co/absa/pramen/core/utils/SparkUtils.scala @@ -32,15 +32,21 @@ import za.co.absa.pramen.core.utils.SparkMaster.Databricks import java.io.ByteArrayOutputStream import java.time.format.DateTimeFormatter import java.time.{Instant, LocalDate} +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.reflect.runtime.universe._ +import scala.util.control.NonFatal import scala.util.{Failure, Success, Try} object SparkUtils { private val log = LoggerFactory.getLogger(this.getClass) + private val charVarcharTypePattern = """(?i)\s*(char|varchar)\((\d+)\)\s*""".r + private val charVarcharLengthPattern = """(?:char|varchar)\((\d+)\)""".r val MAX_LENGTH_METADATA_KEY = "maxLength" + val CHAR_VARCHAR_METADATA_KEY = "__CHAR_VARCHAR_TYPE_STRING" val COMMENT_METADATA_KEY = "comment" + val ORIGINAL_NAME_METADATA_KEY = "original_name" // This seems to be limitation for multiple catalogs, like Glue and Hive. val MAX_COMMENT_LENGTH = 255 @@ -103,6 +109,9 @@ object SparkUtils { * @param characters A set of characters considered special */ def sanitizeDfColumns(df: DataFrame, characters: String): DataFrame = { + val namesLowercase = new mutable.HashSet[String] + namesLowercase ++= df.schema.fields.map(_.name.toLowerCase) + def replaceSpecialChars(s: String): String = { s.map(c => if (characters.contains(c)) '_' else c) } @@ -121,14 +130,39 @@ object SparkUtils { }.distinct.length == 1 } + def getUniqueName(baseName: String): String = { + var uniqueName = baseName + var counter = 1 + while (namesLowercase.contains(uniqueName.toLowerCase)) { + uniqueName = s"${baseName}_$counter" + counter += 1 + } + uniqueName + } + val hasTablePrefix = hasUniformTablePrefix(df.schema.fields) val fieldsToSelect = df.schema.fields.map(field => { val srcName = field.name - val trgName = replaceSpecialChars(if(hasTablePrefix) removeTablePrefix(srcName.trim) else srcName.trim) + val trgName = replaceSpecialChars(if (hasTablePrefix) removeTablePrefix(srcName.trim) else srcName.trim) if (srcName != trgName) { - log.info(s"Renamed column: '$srcName' -> '$trgName''") - col(s"`$srcName`").as(trgName) + val uniqueName = if (namesLowercase.contains(trgName.toLowerCase)) { + val newName = getUniqueName(trgName) + namesLowercase.remove(srcName.toLowerCase) + namesLowercase.add(newName.toLowerCase) + newName + } else { + namesLowercase.remove(srcName.toLowerCase) + namesLowercase.add(trgName.toLowerCase) + trgName + } + + log.info(s"Renamed column: '$srcName' -> '$uniqueName'") + val newMetadata = new MetadataBuilder() + .withMetadata(field.metadata) + .putString(ORIGINAL_NAME_METADATA_KEY, srcName) + .build() + col(s"`$srcName`").as(uniqueName, newMetadata) } else { col(s"`$srcName`") } @@ -325,10 +359,7 @@ object SparkUtils { def transformPrimitive(dataType: DataType, field: StructField): DataType = { dataType match { case _: StringType => - getLengthFromMetadata(field.metadata) match { - case Some(n) => VarcharType(n) - case None => StringType - } + getStringTypeFromMetadata(field.metadata) case _ => dataType } @@ -371,11 +402,46 @@ object SparkUtils { try1 } try2.getOrElse(None) + } else if (metadata.contains(CHAR_VARCHAR_METADATA_KEY)) { + val typeString = metadata.getString(CHAR_VARCHAR_METADATA_KEY).toLowerCase + try { + charVarcharLengthPattern.findFirstMatchIn(typeString).map(_.group(1).toInt) + } catch { + case NonFatal(_) => None + } } else { None } } + /** + * Extracts a string-based data type (CharType, VarcharType, or StringType) from field metadata. + * + * First checks for the presence of character/varchar type metadata key. If found, parses the metadata + * value using a pattern to determine if it represents a CHAR or VARCHAR type with a specific length. + * If the metadata key is not present, attempts to extract a length value from metadata and creates + * a VarcharType if successful. Falls back to StringType if no specific type information is found + * or if the metadata cannot be parsed. + * + * @param metadata the metadata object containing type information for a field + * @return the resolved string-based DataType, which can be CharType, VarcharType, or StringType + */ + def getStringTypeFromMetadata(metadata: Metadata): DataType = { + if (metadata.contains(CHAR_VARCHAR_METADATA_KEY)) { + metadata.getString(CHAR_VARCHAR_METADATA_KEY) match { + case charVarcharTypePattern(kind, len) if kind.equalsIgnoreCase("char") => + CharType(len.toInt) + case charVarcharTypePattern(_, len) => + VarcharType(len.toInt) + case _ => + getLengthFromMetadata(metadata).map(VarcharType.apply).getOrElse(StringType) + } + } else { + getLengthFromMetadata(metadata).map(VarcharType.apply).getOrElse(StringType) + } + } + + /** * Sanitizes a comment for Hive DDL. Ideally this should be done by Spark, but because there are meny versions * of Hive and other catalogs, it is sometimes hard to have an general solution. diff --git a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/SparkUtilsSuite.scala b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/SparkUtilsSuite.scala index 9f28a67a..22a23bb3 100644 --- a/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/SparkUtilsSuite.scala +++ b/pramen/core/src/test/scala/za/co/absa/pramen/core/tests/utils/SparkUtilsSuite.scala @@ -152,6 +152,38 @@ class SparkUtilsSuite extends AnyWordSpec with SparkTestBase with TempDirFixture assert(stripLineEndings(actual) == stripLineEndings(expected)) } + "add suffixes when there is a name collision" in { + val expected = + """[ { "a_a" : "A", "a_a_1" : 1, "b_b_1" : 1, "b_b" : 1, "c_c" : 1, "c_c_1" : 1, "c_c_2" : 1}, { "a_a" : "B", "a_a_1" : 2, "b_b_1" : 2, "b_b" : 2, "c_c" : 2, "c_c_1" : 2, "c_c_2" : 2}, { "a_a" : "C", "a_a_1" : 3, "b_b_1" : 3, "b_b" : 3, "c_c" : 3, "c_c_1" : 3, "c_c_2" : 3} ]""" + + val expectedMetadataField1 = + """{"original_name":"a:a","maxLength":10}""" + + val expectedMetadataField2 = + """{"original_name":"a a"}""" + + val df = List(("A", 1, 1, 1, 1, 1, 1), ("B", 2, 2, 2, 2, 2, 2), ("C", 3, 3, 3, 3, 3, 3)).toDF("a:a", "a a", "b") + + val actual = convertDataFrameToPrettyJSON(actualDf).stripMargin.linesIterator.mkString("").trim + + assert(stripLineEndings(actual) == stripLineEndings(expected)) + + val metadataField1 = actualDf.schema.fields.head.metadata.json + val metadataField2 = actualDf.schema.fields(1).metadata.json + + assert(stripLineEndings(actual) == stripLineEndings(expected)) + assert(metadataField1 == expectedMetadataField1) + assert(metadataField2 == expectedMetadataField2) + + } + "rename columns that start with .tbl" in { val expected = """[ { "a_a" : "A", "b" : 1}, { "a_a" : "B", "b" : 2}, { "a_a" : "C", "b" : 3} ]""" @@ -178,6 +210,47 @@ class SparkUtilsSuite extends AnyWordSpec with SparkTestBase with TempDirFixture assert(stripLineEndings(actual) == stripLineEndings(expected)) } + "retain metadata and add the original_name metadata key" in { + val expected = + """[ { "a_a" : "A", "b" : 1}, { "a_a" : "B", "b" : 2}, { "a_a" : "C", "b" : 3} ]""" + val expectedMetadataField1 = + """{"original_name":"tbl.a a","maxLength":10}""" + + val expectedMetadataField2 = + """{"original_name":"tbl.b","comment":"Test"}""" + + val df = List(("A", 1), ("B", 2), ("C", 3)).toDF("tbl.a a", "tbl.b") + val field1 = df.schema.fields.head + val field2 = df.schema.fields(1) + val field1WIthMetadata = field1.copy(metadata = new MetadataBuilder().putLong("maxLength", 10).build()) + val field2WIthMetadata = field2.copy(metadata = new MetadataBuilder().putString("comment", "Test").build()) + val newSchema = StructType(Seq(field1WIthMetadata, field2WIthMetadata)) + val df1 = spark.createDataFrame(df.rdd, newSchema) + + val actualDf = sanitizeDfColumns(df1, " ") + val metadataField1 = actualDf.schema.fields.head.metadata.json + val metadataField2 = actualDf.schema.fields(1).metadata.json + + val actual = convertDataFrameToPrettyJSON(actualDf).stripMargin.linesIterator.mkString("").trim + + assert(stripLineEndings(actual) == stripLineEndings(expected)) + assert(metadataField1 == expectedMetadataField1) + assert(metadataField2 == expectedMetadataField2) + } + + "handle case-insensitive name collisions" in { + val df = List(("A", 1, 2)).toDF("Test_Column", "test column", "TEST:COLUMN") + + val actualDf = sanitizeDfColumns(df, " :") + + // All three should produce unique names despite case differences + val colNames = actualDf.schema.fields.map(_.name) + assert(colNames.distinct.length == 3, "All columns should have unique names") + assert(colNames.head == "Test_Column") + assert(colNames(1) == "test_column_1") + assert(colNames(2) == "TEST_COLUMN_2") + } + "convert schema from Spark to Json and back should produce the same schema" in { val testCaseSchema = StructType( Array( @@ -557,6 +630,24 @@ class SparkUtilsSuite extends AnyWordSpec with SparkTestBase with TempDirFixture assert(len.contains(10)) } + "return length for string type 2" in { + val metadata = new MetadataBuilder + metadata.putString(CHAR_VARCHAR_METADATA_KEY, "varchar(11)") + + val len = SparkUtils.getLengthFromMetadata(metadata.build()) + + assert(len.contains(11)) + } + + "return length for string type 3" in { + val metadata = new MetadataBuilder + metadata.putString(CHAR_VARCHAR_METADATA_KEY, "CHAR(12)") + + val len = SparkUtils.getLengthFromMetadata(metadata.build()) + + assert(len.contains(12)) + } + "return None for wrong type" in { val metadata = new MetadataBuilder metadata.putString(MAX_LENGTH_METADATA_KEY, "abc") @@ -584,6 +675,108 @@ class SparkUtilsSuite extends AnyWordSpec with SparkTestBase with TempDirFixture } } + "getStringTypeFromMetadata" should { + "return varchar type for long metadata type" in { + val metadata = new MetadataBuilder + metadata.putLong(MAX_LENGTH_METADATA_KEY, 10L) + + val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build()) + + assert(stringType == VarcharType(10)) + } + + "return varchar type for string metadata type" in { + val metadata = new MetadataBuilder + metadata.putString(MAX_LENGTH_METADATA_KEY, "10") + + val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build()) + + assert(stringType == VarcharType(10)) + } + + "return varchar type from CHAR_VARCHAR_METADATA_KEY with varchar" in { + val metadata = new MetadataBuilder + metadata.putString(CHAR_VARCHAR_METADATA_KEY, "varchar(11)") + + val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build()) + + assert(stringType == VarcharType(11)) + } + + "return char type from CHAR_VARCHAR_METADATA_KEY with char" in { + val metadata = new MetadataBuilder + metadata.putString(CHAR_VARCHAR_METADATA_KEY, "CHAR(12)") + + val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build()) + + assert(stringType == CharType(12)) + } + + "return varchar type and ignore case in CHAR_VARCHAR_METADATA_KEY" in { + val metadata = new MetadataBuilder + metadata.putString(CHAR_VARCHAR_METADATA_KEY, "VARCHAR(15)") + + val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build()) + + assert(stringType == VarcharType(15)) + } + + "return None for wrong type in MAX_LENGTH_METADATA_KEY" in { + val metadata = new MetadataBuilder + metadata.putString(MAX_LENGTH_METADATA_KEY, "abc") + + val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build()) + + assert(stringType == StringType) + } + + "return None for double type in MAX_LENGTH_METADATA_KEY" in { + val metadata = new MetadataBuilder + metadata.putDouble(MAX_LENGTH_METADATA_KEY, 12.25) + + val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build()) + + assert(stringType == StringType) + } + + "return None if metadata not specified" in { + val metadata = new MetadataBuilder + + val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build()) + + assert(stringType == StringType) + } + + "prioritize CHAR_VARCHAR_METADATA_KEY over MAX_LENGTH_METADATA_KEY" in { + val metadata = new MetadataBuilder + metadata.putLong(MAX_LENGTH_METADATA_KEY, 10L) + metadata.putString(CHAR_VARCHAR_METADATA_KEY, "varchar(20)") + + val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build()) + + assert(stringType == VarcharType(20)) + } + + "handle invalid CHAR_VARCHAR_METADATA_KEY" in { + val metadata = new MetadataBuilder + metadata.putString(CHAR_VARCHAR_METADATA_KEY, "varchar(10.1)") + + val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build()) + + assert(stringType == StringType) + } + + "handle malformed CHAR_VARCHAR_METADATA_KEY and fallback to MAX_LENGTH_METADATA_KEY" in { + val metadata = new MetadataBuilder + metadata.putLong(MAX_LENGTH_METADATA_KEY, 10L) + metadata.putString(CHAR_VARCHAR_METADATA_KEY, "invalid_format") + + val stringType = SparkUtils.getStringTypeFromMetadata(metadata.build()) + + assert(stringType == VarcharType(10)) + } + } + "removeNestedMetadata" should { "remove metadata, but only from nested fields" in { val metadata1 = new MetadataBuilder().putLong("maxLength", 5).putString("comment", "Employee name").build()