Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,13 @@ interface WithEncoding<Session, Stmt, ResultRow> {
val startingStatementIndex: StartingIndex get() = StartingIndex.Zero // default for JDBC is 1 so this needs to be overrideable
val startingResultRowIndex: StartingIndex get() = StartingIndex.Zero // default for JDBC is 1 so this needs to be overrideable

/**
* Strictly for error reporting purposes. Some databases have types that are relevant to encoding/decoding that the encoder should show when an error occurs.
*/
fun dbTypeIsRelevant(): Boolean = true

fun createEncodingContext(session: Session, stmt: Stmt) =
EncodingContext(session, stmt, encodingConfig.timezone)
EncodingContext(session, stmt, encodingConfig.timezone, startingStatementIndex, dbTypeIsRelevant())
fun createDecodingContext(session: Session, row: ResultRow, debugInfo: QueryDebugInfo?) =
DecodingContext(session, row, encodingConfig.timezone, startingResultRowIndex, catchRethrowColumnInfoExtractError { extractColumnInfo(row) }, debugInfo)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
package io.exoquery.controller

class ControllerError(message: String, cause: Throwable? = null) : Exception(message, cause)
open class ControllerError(message: String, cause: Throwable? = null) : Exception(message, cause) {
class DecodingError(message: String, cause: Throwable? = null) : ControllerError(message, cause)
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,19 @@ open class DecoderAny<T: Any, Session, Row>(
): SqlDecoder<Session, Row, T>() {
override fun isNullable(): Boolean = false
override fun decode(ctx: DecodingContext<Session, Row>, index: Int): T {
val value = f(ctx, index)
val value =
try {
f(ctx, index)
} catch (ex: Exception) {
val msg =
"Error decoding column at index $index for type ${type.simpleName}" +
(ctx.columnInfoSafe(index)?.let { " (${it.name}:${it.type})" } ?: "")
throw ControllerError.DecodingError(msg, ex)
}
if (value == null && !isNullable()) {
val msg =
"Got null value for non-nullable column of type ${type.simpleName} at index $index" +
(ctx.columnInfo(index-1)?.let { " (${it.name}:${it.type})" } ?: "")
(ctx.columnInfoSafe(index)?.let { " (${it.name}:${it.type})" } ?: "")

throw NullPointerException(msg)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ open class EncoderAny<T: Any, TypeId: Any, Session, Stmt>(
else
setNull(index, ctx.stmt, jdbcType)
} catch (e: Throwable) {
throw EncodingException("Error encoding ${type} value: $value at index: $index (whose jdbc-type: ${jdbcType})", e)
val jdbcTypeInfo = if (ctx.dbTypeIsRelevant) " (whose database-type: ${jdbcType})" else ""
throw EncodingException("Error encoding ${type} value: $value at (${ctx.startingIndex.description}) index: $index${jdbcTypeInfo}", e)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,19 @@ import kotlinx.datetime.TimeZone

data class QueryDebugInfo(val query: String)

open class EncodingContext<Session, Stmt>(open val session: Session, open val stmt: Stmt, open val timeZone: TimeZone)
open class EncodingContext<Session, Stmt>(
open val session: Session,
open val stmt: Stmt,
open val timeZone: TimeZone,
open val startingIndex: StartingIndex,
open val dbTypeIsRelevant: Boolean
)
open class DecodingContext<Session, Row>(
open val session: Session,
open val row: Row,
open val timeZone: TimeZone,
open val startingIndex: StartingIndex,
val columnInfos: List<ColumnInfo>?,
open val columnInfos: List<ColumnInfo>?,
open val debugInfo: QueryDebugInfo?
) {
/**
Expand All @@ -19,4 +25,11 @@ open class DecodingContext<Session, Row>(
*/
fun columnInfo(index: Int): ColumnInfo? =
columnInfos?.get(index-startingIndex.value)

fun columnInfoSafe(index: Int): ColumnInfo? =
try {
columnInfo(index)
} catch (ex: IndexOutOfBoundsException) {
null
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class PreparedStatementElementEncoder<Session, Stmt>(
serializer.serialize(this, value)
else
(serializer as? KSerializer<T>)?.nullable?.serialize(this, value)
?: throw IllegalArgumentException("Cannot encode null value at index ${index} with the descriptor ${desc}. The serializer ${serializer} could not be converted into a KSerializer.")
?: throw IllegalArgumentException("cannot encode null value at (${ctx.startingIndex.value}) index ${index} with the descriptor ${desc}. The serializer ${serializer} could not be converted into a KSerializer.")
}

else -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,10 @@ fun SerialDescriptor.verifyColumns(columns: List<ColumnInfo>): Unit {

sealed interface StartingIndex {
val value: Int
val description: String

object Zero: StartingIndex { override val value: Int = 0 }
object One: StartingIndex { override val value: Int = 1 }
object Zero: StartingIndex { override val value: Int = 0; override val description: String = "zero-based" }
object One: StartingIndex { override val value: Int = 1; override val description: String = "one-based" }
}

sealed interface RowDecoderType {
Expand Down Expand Up @@ -120,15 +121,15 @@ class RowDecoder<Session, Row> private constructor(
RowDecoder(ctx, this.serializersModule, initialRowIndex, api, decoders, type, json, debugMode, endCallback)

// helper to get column names
fun colName(index: Int) = ctx.columnInfos?.get(index)?.name ?: "<UNKNOWN>"
fun colName(index: Int) = ctx.columnInfoSafe(index)?.name ?: "<UNKNOWN>"

var rowIndex: Int = initialRowIndex
var classIndex: Int = 0

fun nextRowIndex(desc: SerialDescriptor, descIndex: Int, note: String = ""): Int {
val curr = rowIndex
if (debugMode) {
println("[RowDecoder] Get Row ${ctx.columnInfo(rowIndex)}, Index: ${curr} - (${descIndex}) ${desc.getElementDescriptor(descIndex)} - (Preview:${api.preview(rowIndex, ctx.row)})" + (if (note != "") " - ${note}" else ""))
println("[RowDecoder] Get Row ${ctx.columnInfoSafe(rowIndex)}, Index: ${curr} - (${descIndex}) ${desc.getElementDescriptor(descIndex)} - (Preview:${api.preview(rowIndex, ctx.row)})" + (if (note != "") " - ${note}" else ""))
}
rowIndex += 1
return curr
Expand All @@ -137,7 +138,7 @@ class RowDecoder<Session, Row> private constructor(
fun nextRowIndex(note: String = ""): Int {
val curr = rowIndex
if (debugMode) {
println("[RowDecoder] Get Next Row Index ${ctx.columnInfo(rowIndex)?.name} - (Preview:${api.preview(rowIndex, ctx.row)})" + (if (note != "") " - ${note}" else ""))
println("[RowDecoder] Get Next Row Index ${ctx.columnInfoSafe(rowIndex)?.name} - (Preview:${api.preview(rowIndex, ctx.row)})" + (if (note != "") " - ${note}" else ""))
}
rowIndex += 1
return curr
Expand Down Expand Up @@ -390,7 +391,7 @@ class RowDecoder<Session, Row> private constructor(
}

else ->
throw IllegalArgumentException("Unsupported kind: `${desc.kind}` at index: ${index} (info:${ctx.columnInfos?.get(index)})")
throw IllegalArgumentException("Unsupported kind: `${desc.kind}` at (${ctx.startingIndex.description}) index: ${index} (info:${ctx.columnInfos?.get(index)})")
}
}

Expand Down Expand Up @@ -418,7 +419,7 @@ class RowDecoder<Session, Row> private constructor(
element != null -> element
//now: element == null must be true
descriptor.getElementDescriptor(index).isNullable -> null as T
else -> throw IllegalArgumentException("Error at column ${ctx.columnInfos?.get(index)}. Found null element at index ${index} of descriptor ${descriptor.getElementDescriptor(index)} (of ${descriptor}) where null values are not allowed.")
else -> throw IllegalArgumentException("Error at column ${ctx.columnInfos?.get(index)}. Found null element at (${ctx.startingIndex.description}) index ${index} of descriptor ${descriptor.getElementDescriptor(index)} (of ${descriptor}) where null values are not allowed.")
}
}

Expand Down
1 change: 1 addition & 0 deletions controller-r2dbc/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ kotlin {
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.8.1")
// R2DBC SPI only (no specific driver)
api("io.r2dbc:r2dbc-spi:1.0.0.RELEASE")
// Need to pull in Postgres driver to use it's Json object for wrapping
compileOnly("org.postgresql:r2dbc-postgresql:1.0.5.RELEASE")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.reactive.asFlow
import kotlinx.coroutines.reactive.awaitFirstOrNull
import kotlinx.coroutines.reactive.collect

abstract class R2dbcController(
override val encodingConfig: R2dbcEncodingConfig = R2dbcEncodingConfig.Default(),
Expand All @@ -25,12 +24,13 @@ abstract class R2dbcController(
HasTransactionalityR2dbc
{
override fun DefaultOpts(): R2dbcExecutionOptions = R2dbcExecutionOptions.Default()
override fun dbTypeIsRelevant(): Boolean = false

override val encodingApi: R2dbcSqlEncoding =
object: JavaSqlEncoding<Connection, Statement, Row>,
BasicEncoding<Connection, Statement, Row> by R2dbcBasicEncoding,
JavaTimeEncoding<Connection, Statement, Row> by R2dbcTimeEncoding,
JavaUuidEncoding<Connection, Statement, Row> by R2dbcUuidEncoding {}
JavaUuidEncoding<Connection, Statement, Row> by R2dbcUuidEncodingNative {}

override val allEncoders: Set<SqlEncoder<Connection, Statement, out Any>> by lazy { encodingApi.computeEncoders() + encodingConfig.additionalEncoders }
override val allDecoders: Set<SqlDecoder<Connection, Row, out Any>> by lazy { encodingApi.computeDecoders() + encodingConfig.additionalDecoders }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ import io.exoquery.controller.BasicEncoding
import io.exoquery.controller.JavaSqlEncoding
import io.exoquery.controller.JavaTimeEncoding
import io.exoquery.controller.JavaUuidEncoding
import io.exoquery.controller.SqlDecoder
import io.exoquery.controller.SqlEncoder
import io.exoquery.controller.StartingIndex
import io.r2dbc.spi.Connection
import io.r2dbc.spi.ConnectionFactory
import io.r2dbc.spi.Row
Expand All @@ -27,53 +26,78 @@ object R2dbcControllers {
object: JavaSqlEncoding<Connection, Statement, Row>,
BasicEncoding<Connection, Statement, Row> by R2dbcBasicEncoding,
JavaTimeEncoding<Connection, Statement, Row> by R2dbcTimeEncoding,
JavaUuidEncoding<Connection, Statement, Row> by R2dbcUuidEncoding {}

override protected fun changePlaceholders(sql: String): String {
// Postgres R2DBC uses $1, $2... for placeholders.
// Most other R2DBC drivers (e.g. MSSQL) use '?', so do not rewrite for them.
val sb = StringBuilder()
var paramIndex = 1
var i = 0
while (i < sql.length) {
val c = sql[i]
if (c == '?') {
sb.append('$').append(paramIndex)
paramIndex++
i++
} else {
sb.append(c)
i++
}
}
return sb.toString()
}
JavaUuidEncoding<Connection, Statement, Row> by R2dbcUuidEncodingNative {}

override protected fun changePlaceholders(sql: String): String =
changePlaceholdersIn(sql) { index -> "$${index + 1}" }
}

class SqlServer(
encodingConfig: R2dbcEncodingConfig = R2dbcEncodingConfig.Default(),
override val connectionFactory: ConnectionFactory
): R2dbcController(encodingConfig,connectionFactory) {
override protected fun changePlaceholders(sql: String): String {
// MSSQL R2DBC uses @1, @2... for placeholders.
// Most other R2DBC drivers (e.g. MSSQL) use '?', so do not rewrite for them.
val sb = StringBuilder()
var paramIndex = 0
var i = 0
while (i < sql.length) {
val c = sql[i]
if (c == '?') {
// Params are named like @Param0, @Param1, ... parameter
// binding is indexed based. SqlServer R2DBC supports this.
sb.append("@Param${paramIndex}")
paramIndex++
i++
} else {
sb.append(c)
i++
}
}
return sb.toString()
}

override val encodingApi: R2dbcSqlEncoding =
object: JavaSqlEncoding<Connection, Statement, Row>,
BasicEncoding<Connection, Statement, Row> by R2dbcBasicEncoding,
JavaTimeEncoding<Connection, Statement, Row> by R2dbcTimeEncodingSqlServer,
JavaUuidEncoding<Connection, Statement, Row> by R2dbcUuidEncodingString {}

/** Change the names of the variable params so they can be used by the SQL Server R2DBC driver
* The SQL Server R2DBC driver supports named-parameter binding i.e. row.bind("@firstName", value)
* as well as positional binding i.e. row.bind(0, value). When positional binding is done, the names
* of ther parameters in the SQL string are ignored. Since we are using positional binding,
* we can use any names we want so we want to choose names that are user friendly to debug.
* Therefore we choose @ParamX where X is the index-kind that the context actually uses.
*/
override protected fun changePlaceholders(sql: String): String =
changePlaceholdersIn(sql) { index -> "@Param${index + startingStatementIndex.value}" }
}

class Mysql(
encodingConfig: R2dbcEncodingConfig = R2dbcEncodingConfig.Default(),
override val connectionFactory: ConnectionFactory
): R2dbcController(encodingConfig, connectionFactory) {

override val encodingApi: R2dbcSqlEncoding =
object: JavaSqlEncoding<Connection, Statement, Row>,
BasicEncoding<Connection, Statement, Row> by R2dbcBasicEncoding,
JavaTimeEncoding<Connection, Statement, Row> by R2dbcTimeEncoding,
JavaUuidEncoding<Connection, Statement, Row> by R2dbcUuidEncodingString {}

// MySQL R2DBC uses '?' positional parameters, so no change
override fun changePlaceholders(sql: String): String = sql
}

class H2(
encodingConfig: R2dbcEncodingConfig = R2dbcEncodingConfig.Default(),
override val connectionFactory: ConnectionFactory
): R2dbcController(encodingConfig, connectionFactory) {

override val startingResultRowIndex: StartingIndex get() = StartingIndex.Zero

override val encodingApi: R2dbcSqlEncoding =
object: JavaSqlEncoding<Connection, Statement, Row>,
BasicEncoding<Connection, Statement, Row> by R2dbcBasicEncodingH2, // Need to override Int encoders with Long
JavaTimeEncoding<Connection, Statement, Row> by R2dbcTimeEncodingH2,
JavaUuidEncoding<Connection, Statement, Row> by R2dbcUuidEncodingNative {}

override protected fun changePlaceholders(sql: String): String =
changePlaceholdersIn(sql) { index -> "$${index + 1}" }
}

class Oracle(
encodingConfig: R2dbcEncodingConfig = R2dbcEncodingConfig.Default(),
override val connectionFactory: ConnectionFactory
): R2dbcController(encodingConfig, connectionFactory) {

override val encodingApi: R2dbcSqlEncoding =
object: JavaSqlEncoding<Connection, Statement, Row>,
BasicEncoding<Connection, Statement, Row> by R2dbcBasicEncodingOracle,
JavaTimeEncoding<Connection, Statement, Row> by R2dbcTimeEncodingOracle,
JavaUuidEncoding<Connection, Statement, Row> by R2dbcUuidEncodingString {}

override protected fun changePlaceholders(sql: String): String =
changePlaceholdersIn(sql) { index -> ":${index + 1}" }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,6 @@ import io.exoquery.controller.DecoderAny
import io.exoquery.controller.SqlDecoder
import io.r2dbc.spi.Connection
import io.r2dbc.spi.Row
import kotlinx.datetime.toKotlinInstant
import kotlinx.datetime.toKotlinLocalDate
import kotlinx.datetime.toKotlinLocalDateTime
import kotlinx.datetime.toKotlinLocalTime
import java.time.*
import java.util.*
import kotlin.reflect.KClass

class R2dbcDecoderAny<T: Any>(
Expand Down Expand Up @@ -51,7 +45,7 @@ object R2dbcDecoders {
R2dbcTimeEncoding.JOffsetTimeDecoder,
R2dbcTimeEncoding.JOffsetDateTimeDecoder,
R2dbcTimeEncoding.JDateDecoder,
R2dbcUuidEncoding.JUuidDecoder,
R2dbcUuidEncodingNative.JUuidDecoder,

R2dbcAdditionalEncoding.BigDecimalDecoder
)
Expand Down
Loading