diff --git a/build.sbt b/build.sbt index 1d1902f..d818d7e 100644 --- a/build.sbt +++ b/build.sbt @@ -1,12 +1,15 @@ import com.trueaccord.scalapb.compiler.Version.{grpcJavaVersion, scalapbVersion} organization in ThisBuild := "beyondthelines" -version in ThisBuild := "0.0.8" +version in ThisBuild := "0.0.10-SNAPSHOT" licenses in ThisBuild := ("MIT", url("http://opensource.org/licenses/MIT")) :: Nil bintrayOrganization in ThisBuild := Some("beyondthelines") bintrayPackageLabels in ThisBuild := Seq("scala", "protobuf", "grpc") scalaVersion in ThisBuild := "2.12.4" +val googleapisVersion = "0.0.3" +val scalatestVersion = "3.0.4" + lazy val runtime = (project in file("runtime")) .settings( crossScalaVersions := Seq("2.12.4", "2.11.11"), @@ -16,8 +19,9 @@ lazy val runtime = (project in file("runtime")) "com.trueaccord.scalapb" %% "scalapb-runtime-grpc" % scalapbVersion, "com.trueaccord.scalapb" %% "scalapb-json4s" % "0.3.3", "io.grpc" % "grpc-netty" % grpcJavaVersion, + "org.scalatest" %% "scalatest" % scalatestVersion % Test, "org.webjars" % "swagger-ui" % "3.5.0", - "com.google.api.grpc" % "googleapis-common-protos" % "0.0.3" % "protobuf" + "com.google.api.grpc" % "googleapis-common-protos" % googleapisVersion % "protobuf" ), PB.protoSources in Compile += target.value / "protobuf_external", includeFilter in PB.generate := new SimpleFilter( @@ -35,9 +39,10 @@ lazy val generator = (project in file("generator")) crossScalaVersions := Seq("2.12.4", "2.10.6"), name := "GrpcGatewayGenerator", libraryDependencies ++= Seq( - "com.trueaccord.scalapb" %% "compilerplugin" % scalapbVersion, - "com.trueaccord.scalapb" %% "scalapb-runtime-grpc" % scalapbVersion, - "com.google.api.grpc" % "googleapis-common-protos" % "0.0.3" % "protobuf" + "com.trueaccord.scalapb" %% "compilerplugin" % scalapbVersion, + "com.trueaccord.scalapb" %% "scalapb-runtime-grpc" % scalapbVersion, + "com.google.api.grpc" % "googleapis-common-protos" % googleapisVersion % "protobuf", + "org.scalatest" %% "scalatest" % scalatestVersion % Test ), PB.protoSources in Compile += target.value / "protobuf_external", includeFilter in PB.generate := new SimpleFilter( diff --git a/generator/src/main/scala/grpcgateway/generators/GatewayGenerator.scala b/generator/src/main/scala/grpcgateway/generators/GatewayGenerator.scala index d01ac13..40f1026 100644 --- a/generator/src/main/scala/grpcgateway/generators/GatewayGenerator.scala +++ b/generator/src/main/scala/grpcgateway/generators/GatewayGenerator.scala @@ -1,6 +1,6 @@ package grpcgateway.generators -import com.google.api.AnnotationsProto +import com.google.api.{AnnotationsProto, HttpRule} import com.google.api.HttpRule.PatternCase import com.google.protobuf.Descriptors.FieldDescriptor.JavaType import com.google.protobuf.Descriptors._ @@ -10,10 +10,11 @@ import com.trueaccord.scalapb.compiler.FunctionalPrinter.PrinterEndo import com.trueaccord.scalapb.compiler.{DescriptorPimps, FunctionalPrinter} import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scalapbshade.v0_6_7.com.trueaccord.scalapb.Scalapb object GatewayGenerator extends protocbridge.ProtocCodeGenerator with DescriptorPimps { - override val params = com.trueaccord.scalapb.compiler.GeneratorParams() override def run(requestBytes: Array[Byte]): Array[Byte] = { @@ -23,7 +24,6 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor val b = CodeGeneratorResponse.newBuilder val request = CodeGeneratorRequest.parseFrom(requestBytes, registry) - val fileDescByName: Map[String, FileDescriptor] = request.getProtoFileList.asScala.foldLeft[Map[String, FileDescriptor]](Map.empty) { case (acc, fp) => @@ -50,16 +50,16 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor .add( "import _root_.com.trueaccord.scalapb.GeneratedMessage", "import _root_.com.trueaccord.scalapb.json.JsonFormat", - "import _root_.grpcgateway.handlers._", - "import _root_.io.grpc._", - "import _root_.io.netty.handler.codec.http.{HttpMethod, QueryStringDecoder}" + "import _root_.grpcgateway.handlers.GrpcGatewayHandler", + "import _root_.grpcgateway.handlers.jsonException2GatewayExceptionPF", + "import _root_.io.grpc.ManagedChannel", + "import _root_.io.netty.handler.codec.http.HttpMethod" ) .newline .add( - "import scala.collection.JavaConverters._", "import scala.concurrent.{ExecutionContext, Future}", - "import com.trueaccord.scalapb.json.JsonFormatException", - "import scala.util._" + "import grpcgateway.util.{RestfulUrl, UrlTemplate}", + "import scala.util.Try" ) .newline .print(fileDesc.getServices.asScala) { case (p, s) => generateService(s)(p) } @@ -73,13 +73,14 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor _.add(s"class ${service.getName}Handler(channel: ManagedChannel)(implicit ec: ExecutionContext)").indent .add( "extends GrpcGatewayHandler(channel)(ec) {", + "// a function that takes a RestfulUrl and produces a function that takes a request body and returns a response message", + "type RestfulHandler = RestfulUrl => (String) => Future[GeneratedMessage]", + "", s"""override val name: String = "${service.getName}"""", s"private val stub = ${service.getName}Grpc.stub(channel)" ) .newline - .call(generateSupportsCall(service)) - .newline - .call(generateUnaryCall(service)) + .call(generateCallSeqsByVerb(getUnaryCallsWithHttpExtension(service))) .outdent .add("}") .newline @@ -91,61 +92,12 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor } } - private def generateUnaryCall(service: ServiceDescriptor): PrinterEndo = { printer => - val methods = getUnaryCallsWithHttpExtension(service) - printer - .add(s"override def unaryCall(method: HttpMethod, uri: String, body: String): Future[GeneratedMessage] = {") - .indent - .add( - "val queryString = new QueryStringDecoder(uri)", - "(method.name, queryString.path) match {" - ) - .indent - .print(methods) { case (p, m) => generateMethodHandlerCase(m)(p) } - .add("case (methodName, path) => ") - .addIndented("""Future.failed(InvalidArgument(s"No route defined for $methodName($path)"))""") - .outdent - .add("}") - .outdent - .add("}") - } - - private def generateSupportsCall(service: ServiceDescriptor): PrinterEndo = { printer => - val methods = getUnaryCallsWithHttpExtension(service) - printer - .add(s"override def supportsCall(method: HttpMethod, uri: String): Boolean = {") - .indent - .add( - "val queryString = new QueryStringDecoder(uri)", - "(method.name, queryString.path) match {" - ) - .indent - .print(methods) { case (p, m) => generateMethodCase(m)(p) } - .add("case _ => false") - .outdent - .add("}") - .outdent - .add("}") - } - - private def generateMethodCase(method: MethodDescriptor): PrinterEndo = { printer => - val http = method.getOptions.getExtension(AnnotationsProto.http) - http.getPatternCase match { - case PatternCase.GET => printer.add(s"""case ("GET", "${http.getGet}") => true""") - case PatternCase.POST => printer.add(s"""case ("POST", "${http.getPost}") => true""") - case PatternCase.PUT => printer.add(s"""case ("PUT", "${http.getPut}") => true""") - case PatternCase.DELETE => printer.add(s"""case ("DELETE", "${http.getDelete}") => true""") - case _ => printer - } - } - private def generateMethodHandlerCase(method: MethodDescriptor): PrinterEndo = { printer => val http = method.getOptions.getExtension(AnnotationsProto.http) val methodName = method.getName.charAt(0).toLower + method.getName.substring(1) http.getPatternCase match { case PatternCase.GET => printer - .add(s"""case ("GET", "${http.getGet}") => """) .indent .add("val input = Try {") .indent @@ -156,7 +108,6 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor .outdent case PatternCase.POST => printer - .add(s"""case ("POST", "${http.getPost}") => """) .add("for {") .addIndented( s"""msg <- Future.fromTry(Try(JsonFormat.fromJsonString[${method.getInputType.getName}](body)).recoverWith(jsonException2GatewayExceptionPF))""", @@ -165,7 +116,6 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor .add("} yield res") case PatternCase.PUT => printer - .add(s"""case ("PUT", "${http.getPut}") => """) .add("for {") .addIndented( s"""msg <- Future.fromTry(Try(JsonFormat.fromJsonString[${method.getInputType.getName}](body)).recoverWith(jsonException2GatewayExceptionPF))""", @@ -174,7 +124,6 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor .add("} yield res") case PatternCase.DELETE => printer - .add(s"""case ("DELETE", "${http.getDelete}") => """) .indent .add("val input = Try {") .indent @@ -203,37 +152,37 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor case JavaType.ENUM => p.add(s"val ${inputName(f, prefix)} = ") .addIndented( - s"""${f.getName}.valueOf(queryString.parameters().get("$prefix${f.getJsonName}").asScala.head)""" + s"""${f.getName}.valueOf(url.parameter("$prefix${f.getJsonName}"))""" ) case JavaType.BOOLEAN => p.add(s"val ${inputName(f, prefix)} = ") .addIndented( - s"""queryString.parameters().get("$prefix${f.getJsonName}").asScala.head.toBoolean""" + s"""url.parameter("$prefix${f.getJsonName}").toBoolean""" ) case JavaType.DOUBLE => p.add(s"val ${inputName(f, prefix)} = ") .addIndented( - s"""queryString.parameters().get("$prefix${f.getJsonName}").asScala.head.toDouble""" + s"""url.parameter("$prefix${f.getJsonName}").toDouble""" ) case JavaType.FLOAT => p.add(s"val ${inputName(f, prefix)} = ") .addIndented( - s"""queryString.parameters().get("$prefix${f.getJsonName}").asScala.head.toFloat""" + s"""url.parameter("$prefix${f.getJsonName}").toFloat""" ) case JavaType.INT => p.add(s"val ${inputName(f, prefix)} = ") .addIndented( - s"""queryString.parameters().get("$prefix${f.getJsonName}").asScala.head.toInt""" + s"""url.parameter("$prefix${f.getJsonName}").toInt""" ) case JavaType.LONG => p.add(s"val ${inputName(f, prefix)} = ") .addIndented( - s"""queryString.parameters().get("$prefix${f.getJsonName}").asScala.head.toLong""" + s"""url.parameter("$prefix${f.getJsonName}").toLong""" ) case JavaType.STRING => p.add(s"val ${inputName(f, prefix)} = ") .addIndented( - s"""queryString.parameters().get("$prefix${f.getJsonName}").asScala.head""" + s"""url.parameter("$prefix${f.getJsonName}")""" ) case jt => throw new Exception(s"Unknown java type: $jt") } @@ -246,4 +195,100 @@ object GatewayGenerator extends protocbridge.ProtocCodeGenerator with Descriptor name.charAt(0).toLower + name.substring(1) } + private def generateCallSeqsByVerb(descritors: mutable.Seq[MethodDescriptor]): PrinterEndo = { printer => + val verbToMethods: mutable.Map[PatternCase, Seq[RestfulMethod]] = MethodDescriptors.methodsByVerb(descritors) + printer + .call(generateCallSeqsByVerb(verbToMethods)) + .call(generateSupportsCall(verbToMethods.keySet)) + } + + private def generateCallSeqsByVerb(verbToMethods: mutable.Map[PatternCase, Seq[RestfulMethod]]): PrinterEndo = { printer => + printer. + print(verbToMethods) { case (p, (pattern,methods)) => generateCallSeq(pattern, methods)(p) } + } + + private def generateCallSeq(verb: PatternCase, methods: Seq[RestfulMethod]): PrinterEndo = { printer => + printer + .add(s"private val ${verb.name().toLowerCase}Calls: Seq[(UrlTemplate, RestfulHandler)] = Seq(") + .indent + .print(methods) { case (p, method) => generateCall(method)(p) } + .outdent + .add(")") // Seq + .newline + } + + private def generateCall(method: RestfulMethod): PrinterEndo = { printer => + printer + .add("(") // pair + .add( + s"""UrlTemplate("${method.urlTemplate}"),""", + "(url: RestfulUrl) => (body: String) => {" // function + ) + .indent + .call(generateMethodHandlerCase(method.method)) + .outdent + .add("}") // function + .add("),") // pair + } + + private def generateSupportsCall(verbs: collection.Set[PatternCase]): PrinterEndo = { printer => + printer + .add(s"override def supportsCall(method: HttpMethod, uri: String): Option[UnaryCall] = {") + .indent + .add("method.name match {") + .indent + .print(verbs) { case (p, verb) => generateVerbCase(verb)(p) } + .add("case _ => None") + .outdent + .add("}") // match + .outdent + .add("}") // def + } + + private def generateVerbCase(verb: PatternCase): PrinterEndo = { printer => + printer + .add(s"""case "${verb.name().toUpperCase}" =>""") + .indent + .add(s"for ((restful, handler) <- ${verb.name().toLowerCase}Calls) {") + .indent + .add("val mayBe = restful.matchUri(uri).map((url: RestfulUrl) => handler(url))") + .add("if (mayBe.isDefined) {") + .indent + .add("return mayBe") + .outdent + .add("}") //if + .outdent + .add("}") // for + .newline + .add("None") // def + .newline + .outdent // case body + } + +} + +private case class RestfulMethod(urlTemplate: String, method: MethodDescriptor) + +private object MethodDescriptors { + def methodsByVerb(descriptors: mutable.Seq[MethodDescriptor]) : mutable.Map[PatternCase, Seq[RestfulMethod]] = { + val map = mutable.Map[PatternCase, ArrayBuffer[RestfulMethod]]() + + descriptors.foreach((md: MethodDescriptor) => { + val http = md.getOptions.getExtension(AnnotationsProto.http) + val seq = map.getOrElseUpdate(http.getPatternCase, ArrayBuffer()) + seq += RestfulMethod(urlTemplate(http), md) + }) + + map.asInstanceOf[mutable.Map[PatternCase, Seq[RestfulMethod]]] // todo how to do it with "A <:" ? + } + + private def urlTemplate(http: HttpRule): String = { + http.getPatternCase match { + case PatternCase.GET => http.getGet + case PatternCase.POST => http.getPost + case PatternCase.PUT => http.getPut + case PatternCase.DELETE => http.getDelete + case _ => throw new IllegalArgumentException(s"Unsupported pattern: ${http.getPatternCase}") + } + } } diff --git a/generator/src/test/resources/objectstore_proto.bin b/generator/src/test/resources/objectstore_proto.bin new file mode 100644 index 0000000..c556f81 Binary files /dev/null and b/generator/src/test/resources/objectstore_proto.bin differ diff --git a/generator/src/test/scala/grpcgateway/generators/GatewayGeneratorTest.scala b/generator/src/test/scala/grpcgateway/generators/GatewayGeneratorTest.scala new file mode 100644 index 0000000..e3102f9 --- /dev/null +++ b/generator/src/test/scala/grpcgateway/generators/GatewayGeneratorTest.scala @@ -0,0 +1,24 @@ +package grpcgateway.generators + +import java.nio.file.{Files, Paths} + +import com.google.protobuf.compiler.PluginProtos.{CodeGeneratorRequest, CodeGeneratorResponse} +import org.scalatest.{Assertions, FlatSpec} +import protocbridge.frontend.PluginFrontend + +class GatewayGeneratorTest extends FlatSpec with Assertions { + private val DIR = "generator/target/scala-2.12/test-classes/" + + it should "generate" in { + val requestProtoStream = Files.newInputStream(Paths.get(DIR + "objectstore_proto.bin")) + val request = CodeGeneratorRequest.parseFrom(requestProtoStream) + + val responseBytes: Array[Byte] = PluginFrontend.runWithBytes(GatewayGenerator, request.toByteArray) + val generatedResponse = CodeGeneratorResponse.parseFrom(responseBytes) + + for (i <- 0 until generatedResponse.getFileCount) { + val file = generatedResponse.getFile(i) + Files.write(Paths.get(DIR + file.getName.substring(file.getName.lastIndexOf("/") + 1)), file.getContent.getBytes()) + } + } +} diff --git a/runtime/src/main/scala/grpcgateway/handlers/GrpcGatewayHandler.scala b/runtime/src/main/scala/grpcgateway/handlers/GrpcGatewayHandler.scala index 3598dd4..c0af37b 100644 --- a/runtime/src/main/scala/grpcgateway/handlers/GrpcGatewayHandler.scala +++ b/runtime/src/main/scala/grpcgateway/handlers/GrpcGatewayHandler.scala @@ -6,32 +6,40 @@ import com.trueaccord.scalapb.GeneratedMessage import com.trueaccord.scalapb.json.JsonFormat import io.grpc.ManagedChannel import io.netty.channel.ChannelHandler.Sharable -import io.netty.channel.{ ChannelFutureListener, ChannelHandlerContext, ChannelInboundHandlerAdapter } +import io.netty.channel.{ChannelFutureListener, ChannelHandlerContext, ChannelInboundHandlerAdapter} import io.netty.handler.codec.http._ -import scala.concurrent.{ ExecutionContext, Future } +import scala.concurrent.{ExecutionContext, Future} @Sharable abstract class GrpcGatewayHandler(channel: ManagedChannel)(implicit ec: ExecutionContext) extends ChannelInboundHandlerAdapter { + /** a function that takes a request body and returns a response message */ + type UnaryCall = (String) => Future[GeneratedMessage] def name: String def shutdown(): Unit = if (!channel.isShutdown) channel.shutdown() - def supportsCall(method: HttpMethod, uri: String): Boolean - def unaryCall(method: HttpMethod, uri: String, body: String): Future[GeneratedMessage] + /** + * @param method HTTP verb + * @param uri request path + * @return response message + */ + def supportsCall(method: HttpMethod, uri: String): Option[UnaryCall] override def channelRead(ctx: ChannelHandlerContext, msg: scala.Any): Unit = { msg match { case req: FullHttpRequest => - if (supportsCall(req.method(), req.uri())) { + val mayBeCall: Option[UnaryCall] = supportsCall(req.method(), req.uri()) + if (mayBeCall.isDefined) { + val unaryCall = mayBeCall.get val body = req.content().toString(StandardCharsets.UTF_8) - unaryCall(req.method(), req.uri(), body) + unaryCall(body) .map(JsonFormat.toJsonString) .map(json => { buildFullHttpResponse( @@ -43,24 +51,25 @@ abstract class GrpcGatewayHandler(channel: ManagedChannel)(implicit ec: Executio }) .recover({ case err => - val (body, status) = err match { - case e: GatewayException => e.details -> GRPC_HTTP_CODE_MAP.getOrElse(e.code, HttpResponseStatus.INTERNAL_SERVER_ERROR) - case _ => "Internal error" -> HttpResponseStatus.INTERNAL_SERVER_ERROR - } + val (body, status) = err match { + case e: GatewayException => e.details -> GRPC_HTTP_CODE_MAP.getOrElse(e.code, HttpResponseStatus.INTERNAL_SERVER_ERROR) + case _ => "Internal error" -> HttpResponseStatus.INTERNAL_SERVER_ERROR + } - buildFullHttpResponse( - requestMsg = req, - responseBody = body, - responseStatus = status, - responseContentType = "application/text" - ) - }).foreach(resp => { - ctx.writeAndFlush(resp).addListener(ChannelFutureListener.CLOSE) - }) + buildFullHttpResponse( + requestMsg = req, + responseBody = body, + responseStatus = status, + responseContentType = "application/text" + ) + }).foreach(resp => { + ctx.writeAndFlush(resp).addListener(ChannelFutureListener.CLOSE) + }) } else { super.channelRead(ctx, msg) } + case _ => super.channelRead(ctx, msg) } } diff --git a/runtime/src/main/scala/grpcgateway/util/PathMatcher.scala b/runtime/src/main/scala/grpcgateway/util/PathMatcher.scala new file mode 100644 index 0000000..f18c6bb --- /dev/null +++ b/runtime/src/main/scala/grpcgateway/util/PathMatcher.scala @@ -0,0 +1,51 @@ +package grpcgateway.util + +import scala.collection.mutable + +private[util] trait PathMatcher { + /** + * Assume a sequence of matchers was created by sequentially scanning a URL pattern such as "/get/{template}". + * One by one apply all matchers in the original order + * @param str URL path string + * @param from position in the path to start matching from + * @param templateParams (name,value) pairs of URL parameters extracted from named slots + * @return the position in string the next matcher should continue from + */ + def matchString(str: String, from: Int, templateParams: mutable.Map[String, String]) : Int +} + +private[util] final class TextMatcher(prefix: String) extends PathMatcher { + override def matchString(str: String, from: Int, templateParams: mutable.Map[String, String]): Int = { + val to = from + prefix.length + + if ((to <= str.length) && (str.substring(from, to) == prefix)) { + to + } else { + PathMatcher.NO_MATCH + } + } + + override def toString: String = prefix +} + +private[util] final class TemplateMatcher(name: String) extends PathMatcher { + private val PATH_DELIMITER = '/' + + override def matchString(str: String, from: Int, templateParams: mutable.Map[String, String]): Int = { + var index = from + while ((index < str.length) && (str(index) != PATH_DELIMITER)) { + index += 1 + } + + templateParams.put(name, str.substring(from, index)) + + index + } + + override def toString: String = s"[$name]" +} + +object PathMatcher { + /** The "string position" value to return if a string does not match this matcher */ + val NO_MATCH: Int = -1 +} \ No newline at end of file diff --git a/runtime/src/main/scala/grpcgateway/util/PathParser.scala b/runtime/src/main/scala/grpcgateway/util/PathParser.scala new file mode 100644 index 0000000..4984396 --- /dev/null +++ b/runtime/src/main/scala/grpcgateway/util/PathParser.scala @@ -0,0 +1,82 @@ +package grpcgateway.util + +import scala.collection.mutable.ArrayBuffer + +private final class PathParser(path: String) { + import PathParser.{LCURLY, RCURLY} + + private val matchers = ArrayBuffer[PathMatcher]() + private var index = 0 + + private def matchChar(ch: Char): Unit = { + if (path(index) == ch) { + index += 1 + } else { + throw new IllegalArgumentException(s"Unexpected character ${path(index)} at $index in $path") + } + } + + /** Assume the index points to a named slot. Remember the extracted slot name. */ + private def matchTemplate(): TemplateMatcher = { + matchChar(LCURLY) + + val from = index + while ((index < path.length) && (path(index) != RCURLY)) { + if (path(index) == LCURLY) { + throw new IllegalArgumentException(s"Detected curly braces mismatch at ${from-1} and $index in $path") + } + index += 1 + } + val name = path.substring(from, index) + + matchChar(RCURLY) + + new TemplateMatcher(name) + } + + /** + * Assume the index points to somewhere in-between named slots. + * Collect everything up to the next named slot (or the end of the input) + */ + private def matchStaticText(): TextMatcher = { + val from = index + while ((index < path.length) && (path(index) != LCURLY)) { + if (path(index) == RCURLY) { + throw new IllegalArgumentException(s"Detected curly braces mismatch at ${from} and $index in $path") + } + + index += 1 + } + + new TextMatcher(path.substring(from, index)) + } + + private def parse() : Seq[PathMatcher] = { + while (index < path.length) { + path(index) match { + case LCURLY => + val matcher = matchTemplate() + matchers += matcher + + case _ => + val matcher = matchStaticText() + matchers += matcher + } + } + + matchers + } + + override def toString: String = matchers.mkString +} + +private[util] object PathParser { + private val LCURLY = '{' + private val RCURLY = '}' + + /** @return true if the path contains at least one parameter template such as "/{slot}" */ + def hasTemplates(path: String) : Boolean = path.contains(PathParser.LCURLY) + + /** Sequentially scan a URL template string. Split it into segments representing names slots and everything else. */ + def apply(path: String) : Seq[PathMatcher] = new PathParser(path).parse() +} diff --git a/runtime/src/main/scala/grpcgateway/util/RestfulUrl.scala b/runtime/src/main/scala/grpcgateway/util/RestfulUrl.scala new file mode 100644 index 0000000..6eb2826 --- /dev/null +++ b/runtime/src/main/scala/grpcgateway/util/RestfulUrl.scala @@ -0,0 +1,40 @@ +package grpcgateway.util + +import java.util + +import RestfulUrl._ + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +/** A container of extracted URI properties */ +trait RestfulUrl { + /** + * A uniform way to access URL parameters extracted from named slots (e.g. "/{slot}/") and ordinary parameters (e.g. "?k=v") + * @return named URL parameter extracted from a query uri with a UrlTemplate */ + def parameter(name: String): String +} + +private final class PlainRestfulUrl(parameters: PathParams) extends RestfulUrl { + override def parameter(name: String): String = parameters.get(name).asScala.head +} + +private final class MergedRestfulUrl(templateParams: TemplateParams, pathParams: PathParams) extends RestfulUrl { + override def parameter(name: String): String = { + if (templateParams.contains(name)) { + templateParams(name) + } else if (pathParams.containsKey(name)) { + pathParams.get(name).asScala.head + } else { + null //throw new IllegalArgumentException(s"Property not found: $name") + } + } +} + +private object RestfulUrl { + /** parameters extracted from named slots such as "/{slot}/" */ + type PathParams = util.Map[String, util.List[String]] + + /** ordinary parameters such as "?k=v" */ + type TemplateParams = mutable.Map[String, String] +} diff --git a/runtime/src/main/scala/grpcgateway/util/UrlTemplate.scala b/runtime/src/main/scala/grpcgateway/util/UrlTemplate.scala new file mode 100644 index 0000000..b650976 --- /dev/null +++ b/runtime/src/main/scala/grpcgateway/util/UrlTemplate.scala @@ -0,0 +1,78 @@ +package grpcgateway.util + +import io.netty.handler.codec.http.QueryStringDecoder + +import scala.collection.mutable + +/** A means of parsing request URLs with support for "URL parameter templates" configured in a protobuf descriptor */ +trait UrlTemplate { + /** + * Match incoming URI against this URL template generated from a protobuf RESTful service descriptor + * + * Netty's FullHttpRequest.uri() happens to return a path so we currently assume the "protocol/host" prefix + * is stripped before this method is called. + * + * @return URL properties extracted from the URI if it matched this template + */ + def matchUri(uri: String) : Option[RestfulUrl] +} + +/** Fast path for URL patterns with no templates */ +private final class PlainUrlTemplate(path: String) extends UrlTemplate { + override def matchUri(uri: String): Option[RestfulUrl] = { + //println(s"Matching \'$uri\' to $path") + + val decoder = new QueryStringDecoder(uri) + if (decoder.path() == path) { + Some(new PlainRestfulUrl(decoder.parameters())) + } else { + None + } + } +} + +/** Remember a sequence of matchers to apply (in the same order) to an incoming URI. While matching, optimistically + * collect values at positions corresponding to names slots in the original URL template */ +private final class MatchingUrlTemplate(matchers: Seq[PathMatcher]) extends UrlTemplate { + override def matchUri(uri: String): Option[RestfulUrl] = { + val decoder = new QueryStringDecoder(uri) + val path = decoder.path() + + //println(s"Matching \'$path\' with ${matchers.mkString}") + + var pathIndex = 0 + var matcherIndex = 0 + + val templateParams = mutable.Map[String, String]() + while (pathIndex < path.length) { + val from = pathIndex + + val matcher = matchers(matcherIndex) + pathIndex = matcher.matchString(path, pathIndex, templateParams) + if (pathIndex == PathMatcher.NO_MATCH) { + return None + } + + //println(s"Matched \'${path.substring(from, pathIndex)}\' with ${matcher.toString} remains [${path.substring(pathIndex)}]") + + matcherIndex += 1 + } + + if (matcherIndex == matchers.size) { + Some(new MergedRestfulUrl(templateParams, decoder.parameters())) + } else { + None + } + } +} + +object UrlTemplate { + /** Parse a URL template into a URL matcher. Use a fast path for templates with no named slots. */ + def apply(path: String) : UrlTemplate = { + if (PathParser.hasTemplates(path)) { + new MatchingUrlTemplate(PathParser(path)) + } else { + new PlainUrlTemplate(path) + } + } +} diff --git a/runtime/src/test/scala/grpcgateway/util/UrlTemplateTest.scala b/runtime/src/test/scala/grpcgateway/util/UrlTemplateTest.scala new file mode 100644 index 0000000..edf1654 --- /dev/null +++ b/runtime/src/test/scala/grpcgateway/util/UrlTemplateTest.scala @@ -0,0 +1,103 @@ +package grpcgateway.util + +import org.scalatest.{Assertions, FlatSpec, Matchers} + +class UrlTemplateTest extends FlatSpec with Matchers with Assertions { + private val KEY = "k" + private val VALUE = "v" + private val PARAM1 = "T123" + private val PARAM2 = "Param456" + + it should "preserve default semantics for fixed requests" in { + val template = "/tree/trunk/branch/leaf/get" + + val url = UrlTemplate(template) + val restful = url.matchUri(template).get + assert(restful != null) + + val kvurl = UrlTemplate(template) + val kvrestful = kvurl.matchUri(s"$template?$KEY=$VALUE").get + assert(kvrestful.parameter(KEY) == VALUE) + } + + it should "support multiple URL parameter templates" in { + assertTwoParams( + UrlTemplate("/tree/trunk/branch/leaf/get/{template}/padding/{param}/"), + s"/tree/trunk/branch/leaf/get/$PARAM1/padding/$PARAM2/") + + assertTwoParams( + UrlTemplate("/tree/trunk/branch/leaf/get/{template}/{param}/suffix"), + s"/tree/trunk/branch/leaf/get/$PARAM1/$PARAM2/suffix") + + assertTwoParams( + UrlTemplate("/tree/trunk/branch/leaf/get/{template}/{param}"), + s"/tree/trunk/branch/leaf/get/$PARAM1/$PARAM2") + + assertTwoParams( + UrlTemplate("/tree/trunk/branch/leaf/get/{template}/{param}/"), + s"/tree/trunk/branch/leaf/get/$PARAM1/$PARAM2/") + } + + it should "merge template and ordinary URL parameters" in { + assertMixedParams( + UrlTemplate("/tree/trunk/branch/leaf/get/{template}/padding/{param}/"), + s"/tree/trunk/branch/leaf/get/$PARAM1/padding/$PARAM2/?$KEY=$VALUE") + + assertMixedParams( + UrlTemplate("/tree/trunk/branch/leaf/get/{template}/{param}/suffix"), + s"/tree/trunk/branch/leaf/get/$PARAM1/$PARAM2/suffix?$KEY=$VALUE") + + assertMixedParams( + UrlTemplate("/tree/trunk/branch/leaf/get/{template}/{param}"), + s"/tree/trunk/branch/leaf/get/$PARAM1/$PARAM2?$KEY=$VALUE") + + assertMixedParams( + UrlTemplate("/tree/trunk/branch/leaf/get/{template}/{param}/"), + s"/tree/trunk/branch/leaf/get/$PARAM1/$PARAM2/?$KEY=$VALUE") + } + + it should "no match is found in mismatched URI" in { + assertNoMatch( + UrlTemplate("/tree/trunk/branch/leaf/get/{template}/padding/{param}/"), + s"/tree/trunk/branch/leaf/get/$PARAM1/padding") + + assertNoMatch( + UrlTemplate("/tree/trunk/branch/leaf/get/{template}"), + s"/tree/trunk/branch/leaf/get") + + assertNoMatch( + UrlTemplate("/tree/trunk/branch/leaf/get/padding"), + s"/tree/trunk/branch/leaf/get") + } + + an [IllegalArgumentException] should be thrownBy { + assertNoMatch( + UrlTemplate("/tree/trunk/branch/leaf/get/{template/padding/{param}/"), + s"IGNORED") + } + + an [IllegalArgumentException] should be thrownBy { + assertNoMatch( + UrlTemplate("/tree/trunk/branch/leaf/get/template}/padding/{param}/"), + s"IGNORED") + } + + private def assertTwoParams(template: UrlTemplate, uri: String): Unit = { + val restful = template.matchUri(uri).get + assert(restful.parameter("template") == PARAM1) + assert(restful.parameter("param") == PARAM2) + assert(restful.parameter(KEY) == null) + } + + private def assertMixedParams(template: UrlTemplate, uri: String): Unit = { + val restful = template.matchUri(uri).get + assert(restful.parameter("template") == PARAM1) + assert(restful.parameter("param") == PARAM2) + assert(restful.parameter(KEY) == VALUE) + assert(restful.parameter("") == null) + } + + private def assertNoMatch(template: UrlTemplate, uri: String): Unit = { + assert(template.matchUri(uri).isEmpty) + } +} \ No newline at end of file