Module aws_lambda_powertools.event_handler.middlewares.openapi_validation

Classes

class OpenAPIValidationMiddleware (validation_serializer: Callable[[Any], str] | None = None)

OpenAPIValidationMiddleware is a middleware that validates the request against the OpenAPI schema defined by the Lambda handler. It also validates the response against the OpenAPI schema defined by the Lambda handler. It should not be used directly, but rather through the enable_validation parameter of the ApiGatewayResolver.

Examples

from pydantic import BaseModel

from aws_lambda_powertools.event_handler.api_gateway import (
    APIGatewayRestResolver,
)

class Todo(BaseModel):
  name: str

app = APIGatewayRestResolver(enable_validation=True)

@app.get("/todos")
def get_todos() -> list[Todo]:
  return [Todo(name="hello world")]

Initialize the OpenAPIValidationMiddleware.

Parameters

validation_serializer : Callable, optional
Optional serializer to use when serializing the response for validation. Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.
Expand source code
class OpenAPIValidationMiddleware(BaseMiddlewareHandler):
    """
    OpenAPIValidationMiddleware is a middleware that validates the request against the OpenAPI schema defined by the
    Lambda handler. It also validates the response against the OpenAPI schema defined by the Lambda handler. It
    should not be used directly, but rather through the `enable_validation` parameter of the `ApiGatewayResolver`.

    Examples
    --------

    ```python
    from pydantic import BaseModel

    from aws_lambda_powertools.event_handler.api_gateway import (
        APIGatewayRestResolver,
    )

    class Todo(BaseModel):
      name: str

    app = APIGatewayRestResolver(enable_validation=True)

    @app.get("/todos")
    def get_todos() -> list[Todo]:
      return [Todo(name="hello world")]
    ```
    """

    def __init__(self, validation_serializer: Callable[[Any], str] | None = None):
        """
        Initialize the OpenAPIValidationMiddleware.

        Parameters
        ----------
        validation_serializer : Callable, optional
            Optional serializer to use when serializing the response for validation.
            Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.
        """
        self._validation_serializer = validation_serializer

    def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
        """
        Validate the incoming request, call the next middleware, then validate and serialize its response.

        Path, query, header and (when declared) body parameters are validated against the
        route's dependant. The validated values replace ``app.context["_route_args"]`` before
        the next middleware is invoked.

        Parameters
        ----------
        app : EventHandlerInstance
            The resolver instance; its ``context`` must hold ``_route`` and ``_route_args``.
        next_middleware : NextMiddleware
            Callable that invokes the remainder of the middleware chain.

        Returns
        -------
        Response
            The response returned by the chain, with its body validated/serialized when it is JSON.

        Raises
        ------
        RequestValidationError
            When any path, query, header or body parameter fails validation.
        """
        logger.debug("OpenAPIValidationMiddleware handler")

        route: Route = app.context["_route"]

        values: dict[str, Any] = {}
        errors: list[Any] = []

        # Process path values, which can be found on the route_args
        path_values, path_errors = _request_params_to_args(
            route.dependant.path_params,
            app.context["_route_args"],
        )

        # Normalize query values before validating them
        query_string = _normalize_multi_query_string_with_param(
            app.current_event.resolved_query_string_parameters,
            route.dependant.query_params,
        )

        # Process query values
        query_values, query_errors = _request_params_to_args(
            route.dependant.query_params,
            query_string,
        )

        # Normalize header values before validating them
        headers = _normalize_multi_header_values_with_param(
            app.current_event.resolved_headers_field,
            route.dependant.header_params,
        )

        # Process header values
        header_values, header_errors = _request_params_to_args(
            route.dependant.header_params,
            headers,
        )

        values.update(path_values)
        values.update(query_values)
        values.update(header_values)
        errors += path_errors + query_errors + header_errors

        # Process the request body, if it exists
        if route.dependant.body_params:
            (body_values, body_errors) = _request_body_to_args(
                required_params=route.dependant.body_params,
                received_body=self._get_body(app),
            )
            values.update(body_values)
            errors.extend(body_errors)

        if errors:
            # Raise the validation errors
            raise RequestValidationError(_normalize_errors(errors))

        # Re-write the route_args with the validated values, and call the next middleware
        app.context["_route_args"] = values

        # Call the handler by calling the next middleware
        response = next_middleware(app)

        # Process the response
        return self._handle_response(route=route, response=response)

    def _handle_response(self, *, route: Route, response: Response) -> Response:
        """
        Validate and serialize the response body in place when it is non-empty JSON.

        Responses with an empty body, or whose ``is_json()`` is false, are returned untouched.
        """
        if response.body and response.is_json():
            response.body = self._serialize_response(
                field=route.dependant.return_param,
                response_content=response.body,
            )

        return response

    def _serialize_response(
        self,
        *,
        field: ModelField | None = None,
        response_content: Any,
        include: IncEx | None = None,
        exclude: IncEx | None = None,
        by_alias: bool = True,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
    ) -> Any:
        """
        Serialize the response content according to the field type.

        When ``field`` is given, the content is validated against it first; otherwise the
        content is passed straight to ``jsonable_encoder``.

        Raises
        ------
        RequestValidationError
            When the response content does not validate against ``field``.
        """
        if not field:
            # Just serialize the response content returned from the handler
            return jsonable_encoder(response_content, custom_serializer=self._validation_serializer)

        errors: list[dict[str, Any]] = []
        # MAINTENANCE: remove this when we drop pydantic v1
        if not hasattr(field, "serializable"):
            response_content = self._prepare_response_content(
                response_content,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
            )

        value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
        if errors:
            raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)

        # Prefer the field's own serializer when it provides one
        if hasattr(field, "serialize"):
            return field.serialize(
                value,
                include=include,
                exclude=exclude,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
            )

        return jsonable_encoder(
            value,
            include=include,
            exclude=exclude,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            custom_serializer=self._validation_serializer,
        )

    def _prepare_response_content(
        self,
        res: Any,
        *,
        exclude_unset: bool,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
    ) -> Any:
        """
        Prepares the response content for serialization.

        Models are dumped to plain data, lists and dicts are processed recursively, dataclasses
        are converted with ``dataclasses.asdict``, and anything else is returned unchanged.
        The ``exclude_*`` options are forwarded to every nested model, so models inside
        lists and dicts are dumped with the same settings as a top-level model.
        """
        if isinstance(res, BaseModel):
            return _model_dump(
                res,
                by_alias=True,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
            )
        elif isinstance(res, list):
            return [
                self._prepare_response_content(
                    item,
                    exclude_unset=exclude_unset,
                    exclude_defaults=exclude_defaults,
                    exclude_none=exclude_none,
                )
                for item in res
            ]
        elif isinstance(res, dict):
            return {
                k: self._prepare_response_content(
                    v,
                    exclude_unset=exclude_unset,
                    exclude_defaults=exclude_defaults,
                    exclude_none=exclude_none,
                )
                for k, v in res.items()
            }
        elif dataclasses.is_dataclass(res):
            return dataclasses.asdict(res)  # type: ignore[arg-type]
        return res

    def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]:
        """
        Get the request body from the event, and parse it as JSON.

        The body is treated as JSON when the ``content-type`` header is absent/empty or
        starts with ``application/json``.

        Raises
        ------
        RequestValidationError
            When the body is not valid JSON.
        NotImplementedError
            When the ``content-type`` header names anything other than JSON.
        """

        content_type = app.current_event.headers.get("content-type")
        if not content_type or content_type.strip().startswith("application/json"):
            try:
                return app.current_event.json_body
            except json.JSONDecodeError as e:
                # Surface the decode failure as a validation error located at the failing offset
                raise RequestValidationError(
                    [
                        {
                            "type": "json_invalid",
                            "loc": ("body", e.pos),
                            "msg": "JSON decode error",
                            "input": {},
                            "ctx": {"error": e.msg},
                        },
                    ],
                    body=e.doc,
                ) from e
        else:
            raise NotImplementedError("Only JSON body is supported")

Ancestors

Inherited members