Module aws_lambda_powertools.event_handler.middlewares.openapi_validation
Classes
class OpenAPIValidationMiddleware (validation_serializer: Callable[[Any], str] | None = None)
-
OpenAPIValidationMiddleware is a middleware that validates the request against the OpenAPI schema defined by the Lambda handler. It also validates the response against the OpenAPI schema defined by the Lambda handler. It should not be used directly, but rather through the
enable_validation
parameter of the ApiGatewayResolver.
Examples
from pydantic import BaseModel

from aws_lambda_powertools.event_handler.api_gateway import (
    APIGatewayRestResolver,
)

class Todo(BaseModel):
    name: str

app = APIGatewayRestResolver(enable_validation=True)

@app.get("/todos")
def get_todos() -> list[Todo]:
    return [Todo(name="hello world")]
Initialize the OpenAPIValidationMiddleware.
Parameters
validation_serializer : Callable, optional
    Optional serializer to use when serializing the response for validation. Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.
Expand source code
class OpenAPIValidationMiddleware(BaseMiddlewareHandler):
    """
    OpenAPIValidationMiddleware is a middleware that validates the request against the OpenAPI schema defined by the
    Lambda handler. It also validates the response against the OpenAPI schema defined by the Lambda handler. It
    should not be used directly, but rather through the `enable_validation` parameter of the `ApiGatewayResolver`.

    Examples
    --------

    ```python
    from pydantic import BaseModel

    from aws_lambda_powertools.event_handler.api_gateway import (
        APIGatewayRestResolver,
    )

    class Todo(BaseModel):
        name: str

    app = APIGatewayRestResolver(enable_validation=True)

    @app.get("/todos")
    def get_todos() -> list[Todo]:
        return [Todo(name="hello world")]
    ```
    """

    def __init__(self, validation_serializer: Callable[[Any], str] | None = None):
        """
        Initialize the OpenAPIValidationMiddleware.

        Parameters
        ----------
        validation_serializer : Callable, optional
            Optional serializer to use when serializing the response for validation.
            Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.
        """
        self._validation_serializer = validation_serializer

    def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
        """
        Validate the request, call the next middleware, then validate/serialize its response.

        Path, query, header and (when the route declares body params) body values are validated
        against the route's dependant. On success, ``app.context["_route_args"]`` is replaced with
        the validated values before the next middleware runs.

        Parameters
        ----------
        app : EventHandlerInstance
            The resolver instance; its ``context`` must hold ``_route`` and ``_route_args``.
        next_middleware : NextMiddleware
            The next middleware (or the route handler) in the chain.

        Returns
        -------
        Response
            The response from the next middleware, with a JSON body validated and serialized.

        Raises
        ------
        RequestValidationError
            If any request value fails validation (all errors are collected and normalized first).
        """
        logger.debug("OpenAPIValidationMiddleware handler")

        route: Route = app.context["_route"]

        values: dict[str, Any] = {}
        errors: list[Any] = []

        # Process path values, which can be found on the route_args
        path_values, path_errors = _request_params_to_args(
            route.dependant.path_params,
            app.context["_route_args"],
        )

        # Normalize query values before validating them (multi-value query strings)
        query_string = _normalize_multi_query_string_with_param(
            app.current_event.resolved_query_string_parameters,
            route.dependant.query_params,
        )

        # Process query values
        query_values, query_errors = _request_params_to_args(
            route.dependant.query_params,
            query_string,
        )

        # Normalize header values before validating them (multi-value headers)
        headers = _normalize_multi_header_values_with_param(
            app.current_event.resolved_headers_field,
            route.dependant.header_params,
        )

        # Process header values
        header_values, header_errors = _request_params_to_args(
            route.dependant.header_params,
            headers,
        )

        values.update(path_values)
        values.update(query_values)
        values.update(header_values)
        errors += path_errors + query_errors + header_errors

        # Process the request body, only if the route declares body params
        if route.dependant.body_params:
            (body_values, body_errors) = _request_body_to_args(
                required_params=route.dependant.body_params,
                received_body=self._get_body(app),
            )
            values.update(body_values)
            errors.extend(body_errors)

        if errors:
            # Report every collected validation error at once
            raise RequestValidationError(_normalize_errors(errors))

        # Re-write the route_args with the validated values, so the handler receives them
        app.context["_route_args"] = values

        # Call the handler by calling the next middleware
        response = next_middleware(app)

        # Process the response
        return self._handle_response(route=route, response=response)

    def _handle_response(self, *, route: Route, response: Response) -> Response:
        """
        Validate and serialize a JSON response body against the route's return type.

        Responses with a falsy body, or whose ``is_json()`` is false, are returned unchanged.
        The response object is mutated in place (``response.body`` is replaced) and returned.
        """
        # Process the response body if it exists
        if response.body:
            # Validate and serialize the response, if it's JSON
            if response.is_json():
                response.body = self._serialize_response(
                    field=route.dependant.return_param,
                    response_content=response.body,
                )

        return response

    def _serialize_response(
        self,
        *,
        field: ModelField | None = None,
        response_content: Any,
        include: IncEx | None = None,
        exclude: IncEx | None = None,
        by_alias: bool = True,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
    ) -> Any:
        """
        Serialize the response content according to the field type.

        When ``field`` is given, the content is validated against it (errors located under
        ``("response",)``) and then serialized with ``field.serialize`` if available, otherwise
        with ``jsonable_encoder``. Without a field, the content is only passed through
        ``jsonable_encoder``.

        Raises
        ------
        RequestValidationError
            If the content does not validate against ``field``.
        """
        if field:
            errors: list[dict[str, Any]] = []
            # MAINTENANCE: remove this when we drop pydantic v1
            # (fields without `serializable` take the pydantic v1 path and need pre-dumped content)
            if not hasattr(field, "serializable"):
                response_content = self._prepare_response_content(
                    response_content,
                    exclude_unset=exclude_unset,
                    exclude_defaults=exclude_defaults,
                    exclude_none=exclude_none,
                )

            value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
            if errors:
                raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)

            if hasattr(field, "serialize"):
                return field.serialize(
                    value,
                    include=include,
                    exclude=exclude,
                    by_alias=by_alias,
                    exclude_unset=exclude_unset,
                    exclude_defaults=exclude_defaults,
                    exclude_none=exclude_none,
                )

            return jsonable_encoder(
                value,
                include=include,
                exclude=exclude,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
                custom_serializer=self._validation_serializer,
            )
        else:
            # Just serialize the response content returned from the handler
            return jsonable_encoder(response_content, custom_serializer=self._validation_serializer)

    def _prepare_response_content(
        self,
        res: Any,
        *,
        exclude_unset: bool,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
    ) -> Any:
        """
        Prepares the response content for serialization.

        Pydantic models are dumped (by alias) to plain data; lists and dict values are processed
        recursively with the same exclusion flags; dataclasses are converted with
        ``dataclasses.asdict``; anything else is returned unchanged.
        """
        if isinstance(res, BaseModel):
            return _model_dump(
                res,
                by_alias=True,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
            )
        elif isinstance(res, list):
            # Propagate every exclusion flag, including exclude_none, to nested items
            return [
                self._prepare_response_content(
                    item,
                    exclude_unset=exclude_unset,
                    exclude_defaults=exclude_defaults,
                    exclude_none=exclude_none,
                )
                for item in res
            ]
        elif isinstance(res, dict):
            # Propagate every exclusion flag, including exclude_none, to nested values
            return {
                k: self._prepare_response_content(
                    v,
                    exclude_unset=exclude_unset,
                    exclude_defaults=exclude_defaults,
                    exclude_none=exclude_none,
                )
                for k, v in res.items()
            }
        elif dataclasses.is_dataclass(res):
            return dataclasses.asdict(res)  # type: ignore[arg-type]
        return res

    def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]:
        """
        Get the request body from the event, and parse it as JSON.

        A missing ``content-type`` header is treated as JSON.

        Raises
        ------
        RequestValidationError
            If the body is not valid JSON (reported as a ``json_invalid`` error at the decode position).
        NotImplementedError
            If the content type is present and is not ``application/json``.
        """
        content_type = app.current_event.headers.get("content-type")
        if not content_type or content_type.strip().startswith("application/json"):
            try:
                return app.current_event.json_body
            except json.JSONDecodeError as e:
                raise RequestValidationError(
                    [
                        {
                            "type": "json_invalid",
                            "loc": ("body", e.pos),
                            "msg": "JSON decode error",
                            "input": {},
                            "ctx": {"error": e.msg},
                        },
                    ],
                    body=e.doc,
                ) from e
        else:
            raise NotImplementedError("Only JSON body is supported")
Ancestors
- BaseMiddlewareHandler
- typing.Generic
- abc.ABC
Inherited members