diff --git a/.github/workflows/erlang.yml b/.github/workflows/erlang.yml index c142f59..6b174f3 100644 --- a/.github/workflows/erlang.yml +++ b/.github/workflows/erlang.yml @@ -34,3 +34,5 @@ jobs: run: ./rebar3 do xref, dialyzer - name: Run eunit run: ./rebar3 as gha do eunit + - name: Check format + run: ./rebar3 fmt --check diff --git a/README.md b/README.md index 4f71c37..2beec6f 100644 --- a/README.md +++ b/README.md @@ -2,40 +2,11 @@ ![Riak API OpenRiak Status](https://github.com/OpenRiak/riak_api/actions/workflows/erlang.yml/badge.svg?branch=openriak-4.0) -This OTP application encapsulates services for presenting Riak's -public-facing interfaces. Currently this means a generic interface for -exposing Protocol Buffers-based services; HTTP services via Webmachine -will be moved here at a later time. +This OTP application encapsulates services for presenting Riak's public-facing interfaces. -## Contributing +There two APIs: -We encourage contributions to `riak_api` from the community. +- An API using protocol buffers, with a codec defined in [riak_pb](https://github.com/OpenRiak/riak_pb), with the handling of messages managed using `riak_kv_pb_*` modules within Riak KV. +- A HTTP REST-based API (code-named Silver Machine), with the handling of requests defined using `riak_kv_ag_*` modules that implement the callbacks defined in the `riak_api_web_handler` behaviour. -1. Fork the [`riak_api`](https://github.com/basho/riak_api) repository - on Github. -2. Clone your fork or add the remote if you already have a clone of - the repository. - - ``` - git clone git@github.com:yourusername/riak_api.git - # or - git remote add mine git@github.com:yourusername/riak_api.git - ``` - -3. Create a topic branch for your change. - - ``` - git checkout -b some-topic-branch - ``` - -4. Make your change and commit. Use a clear and descriptive commit - message, spanning multiple lines if detailed explanation is needed. -5. Push to your fork of the repository and then send a pull-request - through Github. - - ``` - git push mine some-topic-branch - ``` - -6. A Basho engineer or community maintainer will review your patch and - merge it into the main repository or send you feedback. +For further information on using [Sliver Machine see the provided document](/docs/silverMachine.md). diff --git a/docs/silverMachine.md b/docs/silverMachine.md new file mode 100644 index 0000000..3483e8e --- /dev/null +++ b/docs/silverMachine.md @@ -0,0 +1,178 @@ +# Silver Machine + +## Overview + +Silver Machine is a HTTP/REST request handler. It is designed to be simpler and more performant than Webmachine/Mochiweb, with the trade-off that it provides less complete compliance with standards within the framework: + +- "simpler" means reduced volume of code within the framework (less than half), better use of dialyzer specs to clarify safe usage, and a behaviour module with a smaller and fixed number callbacks. +- "performant" means less CPU overhead when handling Riak requests, especially those carrying a large volume of information via HTTP request headers. + +Silver Machine took direct inspiration from the [Elli HTTP server](https://github.com/elli-lib/elli), using it as a source of ideas for improving performance. + +Silver Machine is not intended to be used outside of Riak. It is a framework developed specifically for the Riak use-case, and may have breaking changes within the framework at any time if such a change is required to support efficiency in Riak. + +Using Silver Machine requires three actions: + +- configuration to start [listeners](#listeners); +- the [loading of routes](#adding-routes), a prioritised list of modules that will provide endpoints via the listener; +- the definition of those modules to handle requests, implemented following the `riak_api_web_handler` [behaviour](#the-riak_api_web_handler-behaviour). + +### Listeners + +A listener is started using `riak_api_web_socket:start_link/1`, where the function takes as its argument a list of options: + +```erlang +-type option() :: + {acceptor_pool_start_size, pos_integer()} + | {acceptor_pool_max_size, pos_integer()} + | {ssl, boolean()} + | {ssl_opts, [ssl:tls_server_option()]} + | {ip, inet:ip_address()} + | {port, inet:port_number()} + | {name, server_name()}. +``` + +Within Riak the `riak_api_sup` sueprvisor is used to discover the bindings (IP and Port pairs) from the configuration, and start a listener for each binding. + +In addition to the passed-in options, three further options cna be set using environment variables: + +- `riak_api/web_kernel_buffer` - which will set the TCP `buffer`; +- `riak_api/web_receive_buffer` - which will set the TCP `recbuf`; +- `riak_api/web_send_buffer` - which will set the TCP `sndbuf`. + +If no environment variables are set, then the `recbuf` will be changed from its default setting to `131072`, and this will automatically [change the `buffer` setting](https://github.com/erlang/otp/issues/9355). + +Each listener is a socket (SSL or TCP), with a pool of acceptors. The acceptors will listen on the socket, and when new connections are made the listen results in the connection being managed by an available acceptor. The acceptor will live for the duration of the connection, but only for the duration of the connection. When an acceptor is assigned a connection (at the start), a new acceptor is started and added to the pool to replace the busy acceptor. There should always be a pool of acceptors ready; however due to the potential timing delays in the assignment of connections to acceptors the `backlog` on the socket (i.e. the backlog of unhandled connections) is configured to 128 (normal OTP default is 5) to avoid unnecessary connection resets. + +The acceptor pool maximum and starting size is defined at startup via the passed in options. If no such options are passed for that listener the defaults are taken from environment variables `riak_api/web_acceptor_pool_start_size` and `riak_api/web_acceptor_pool_max_size`. + +### Adding routes + +There are multiple routing tables - one for `default` routes, and one for each Port. Routes are lists of {1..100, module()} tuples. When a request is processed each module in the routing table will be matched against the request (using the `match_route/3` callback function) until a match is found. If port-specific routes are provided these will be used, default routes will only be used if no port-specific routes have been added. + +Routes can be added using `riak_api_web:add_routes/1`, `riak_api_web:add_routes/2`. + +Not when implementing the `match_route/3` callback of a module, care needs to be taken when different modules support the same path but with different methods. The routes will be checked until the first `ok` match or `method_not_allowed` match, and then no further routes will be checked. Other, lower priority, routes will only be checked when `nomatch` is returned. + +In the current `riak_kv` implementation only `default` routes are set, so all HTTP/HTTPS listeners have the same functionality. + +### The `riak_api_web_handler` behaviour + +The acceptor has a standard workflow of functions for handling a request (`riak_api_web_acceptor:handle_request/5`), and included in that workflow are the calls to the six callbacks required in the `riak_api_web_handler` behaviour: + +- [`match_route/3`](#match_route); + - Passed path information, to be potentially matched against the routing needs for the module. +- [`check_permissions/5`](#check_permissions); + - Passed credential information (request headers and peer details), to be potentially screen requests based on authentication and authorisation needs. +- [`parse_query_params/2`](#parse_query_params) + - Passed any parsed query parameters included in the request for validation. +- [`parse_request_headers/2`](#parse_request_headers) + - Passed all request headers included in the request for validation. +- [`process_request/2`](#process_request) + - Passed a `riak_api_web_body:req_body/0` object (or `none`) so that the value may be fetched, and the request processed and the response returned (either as a binary or a streaming function that will incrementally generate the binary). +- [`record_request/3`](#record_request) + - Passed timing information about the request to be recorded as required. + +At each callback a `context` object is required to be returned. The format of this object is opaque to the acceptor, but the object will be forwarded as an attribute as-is to the next callback in the list. So as the request is parsed and validated through its callback functions, the module should update the `context` with any information that might be relevant to its own callback functions later in the handling of the request. + +For `check_permissions/5`, `parse_query_params/2`, `parse_request_headers/2` and `process_request` the workflow can be terminated by returning a `riak_api_web_acceptor:halt_response()` rather than a positive response. This will prompt the workflow to be immediately terminated, and a response returned with the information contained within the `halt_response()` (e.g. response code, headers and message body). + +When a `halt_response()` is returned the connection will be closed, even where a `keepalive` request has been made. + +#### match_route + +This callback function should attempt to match the module to the path, and either return: + +- an `ok` process with the size limits for the module (count of headers, maximum byte-size of an individual header, and the maximum size of the body of the request), and an initial context object for the request. +- a `nomatch` response indicating the module does not support that path (and so the next module in the route priority list should be tried). +- a `method_not_allowed` response to indicate that the path matched but the method is not in the supported list of methods for this module. + +The `match_route` callback will receive 'Method' (an atom representing the HTTP request method), 'Path' (the full path as a binary string) and `Split Path` (the full path split into a list of individual elements separated by "/"). + +e.g. `GET /types/T/buckets/B/keys/K?returnbody=true HTTP/1.1` will lead to call to: + +```erlang +match_route('GET', <<"/types/T/buckets/B/keys/K">>, [<<"types">>, <<"T">>, <<"buckets">>, <<"B">>, <<"keys">>, <<"K">>]) +``` + +The split path (list) is trimmed of any leading or trailing empty elements e.g. "/stats/" and "/stats" will be equivalent. The URL will be normalised and unquoted before calling `match_route/3` - e.g. handling any "\..\"-style directory traversal and % encoding of non-standard characters. + +All modules are tried until either an `ok` response or a `method_not_allowed` response is returned. The `method_not_allowed` response will trigger a [HTTP 405](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Status/405) error response. + +#### check_permissions + +The check_permissions callback function will be passed: + +- All request headers as a `riak_api_web_headers:headers()` object which can be managed via the `riak_api_web_headers` module. +- The scheme for the listener (e.g. http or https). +- The IP address of the peer making the request. +- The client certificate used in any TLS negotiation (or `undefined` if no certificate used). +- The context object returned from the `match_route` function call. The context object should have been initiated with any details necessary to make a permission check from the path (e.g. in Riak the object Bucket). + +The check_permission should return either an `ok` response with a potentially updated context or a `halt_response()`. + +Within Riak, most check_permissions implementation should use the `riak_kv_web_common:check_permissions/5` function, to standardise the application of security controls. + +#### parse_query_params + +The query parameters will be passed as list of `{Key, Value}` tuples with the Key and Value both being binaries as they were presented in the URI (following percent decoding). If a key is provided as a parameter within the query parameters without value, the value will be the atom `true`. + +As with other callbacks, valid responses are either an `ok` with updated context object, or a `halt_response()` (for example if an invalid query parameter has been provided). + +#### parse_request_headers + +The request headers will be passed as `riak_api_web_headers:headers()` object which can be managed via the `riak_api_web_headers` module. Note this will be the same information as passed into the `check_permissions` callback. It is recommended to defer parsing non-security request headers until this stage (when permissions have already been checked), to reduce the workload undertaken on unverified requests. + +As with other callbacks, valid responses are either an `ok` with updated context object, or a `halt_response()` (for example if an invalid request header has been provided). + +The `riak_api_web_headers` module requires knowledge if the header has an `atom()` or a `binary()` as a key. The module has a `standard_header_key()` type which list all header keys which will be atoms and not binaries. For binary keys, as well as fetching individual headers by key, it is also possible to fold to return all headers with a given prefix. When fetching binary header keys, it can be specified that the header key being request has already been lower-cased (using `string:casefold/1`) so that lower-casing does not need to be repeated within the function. + +Note that headers may have single values or multiple values, check the function spec and ensure both cases are handled if required. Multiple values will occur either because the header value is a comma-separated list, or because multiple header values have been provided under a repeated header key. + +Only the 'Content-Length' and 'Transfer-Encoding' headers are parsed within the framework - to obtain a static content length, or prepare for a chunked request body. Only the transfer-encoding of `chunked` is managed within the framework. + +There is no handling of information in other request headers within Silver Machine, all other headers are only handled within the callback functions. So all headers that are expected to have meaning must be parsed and have appropriate details added to the context for downstream consideration in the `process_request/2` callback function (e.g. handling conditional headers such as 'If-None-Match', matching 'Accept' header to content-types provided, or validating `Referer' details). + +#### process_request + +The process_request callback function will be passed a `riak_api_web_body:req_body/0` object, or the atom `none` if and only if it had been stipulated in the size limits returned from the `match_route/3` callback that only a 0-length body is supported. Before providing a `none` request body, the buffer is checked by the acceptor to confirm no body has been provided (and a ['413 Content Too Large'](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Status/413) response is returned if there is a body present). + +At this stage the request body, if present, has not been read from the TCP buffer, and so sending of a large body will be suspended at the client (if the TCP window is full). The acceptor (and the service it supports) is protected from a memory perspective until the body is fetched by the process_request callback function. Fetching the body is managed through `riak_api_web_body:get_body/3` function, and the body may be fetched entirely, or partially up to a size limit. Selecting the body in slices may be used if the intention is to slice and store large inbound requests without reading the whole request into memory. There is no relationship between slices and chunks - slice sizes are defined on the server side, and chunk sizes are defined on the client side. + +The `process_request/2` callback function should not return a positive response unless the entirety of the body has been read. If the reading of the body is curtailed then a `halt_response()` must be returned as otherwise the handling of further requests in a keepalive connection may be corrupted. + +A positive response must contain a response tuple as well as `ok` and the updated context object. This tuple consists of: + +```erlang +{ + riak_api_web_acceptor:response_code(), + riak_api_web_headers:header_list(), + response_body(), + boolean(), + riak_api_web_body:req_body()|none +}. +``` + +- response_code; the HTTP response code to be returned, this may be an error code as well as a positive code. If supporting pipelined requests, it may be preferable to return `404` errors as a positive response rather than as a `halt_response()` that would cause the connection to be terminated. +- response header_list; a list of Key/Value tuples representing the headers to be added to the response. Keys should be atom() if it is a standard_key, and otherwise a binary in the case it is intended to be presented. The only headers added by SilverMachine will be a 'Date' Header, a 'Server' header, a 'Connection' header and either a 'Content-Length' header or 'Transfer-Encoding` header as appropriate. Any user-provided headers that overlap with these default headers will override the defaults. +- The response body; either a binary() (in which case the response will be sent immediately with a fixed `Content-Length`), or a stream function in which case every binary returned from the stream function will be returned as a 'chunk' in a chink-encoded response (until the function returns the atom `done`). +- A keepalive supported boolean; may be switched from true to false if there is a requirement to close this connection rather than allow further requests to be received. +- the `req_body` remainder, i.e. the final `req_body` object returned from the call to `riak_api_web_body:get_body/3`. In fetching the body, when supporting pipelined requests some of a subsequent request may be read into the buffer, and returning the final req_body object ensures that this buffer is available to the acceptor to process that request. The atom `none` should be returned if the atom `none` was received as the request body. + +For an example stream function to return the body, see the `riak_kv_ag_index` module. Note that when calling `riak_api_web_body:get_body/3` the req_body object tracks the volume of data received versus the configured size limit - and may return `{error, content_too_large}` if the size is exceeded. + +#### record_request + +The record_request callback function is passed timing information from the handling of the request, as well as the request context object. This is intended to be used for any statistics or logging activity required by the module. + + +## Limitations + +Silver Machine is designed to support a subset of the HTTP protocol, the restrictions include: + +- Limited to only support HTTP 1.0 and HTTP 1.1 connections; + - HTTP 1.1 request pipelining is supported, but currently subject to limited testing. Adding multiplexed requests (i.e. HTTP 2.0) will require a significant change. +- Supported methods are limited to 'OPTIONS', 'GET', 'HEAD', 'POST', 'PUT', 'DELETE' and 'TRACE' - but all functionality must exist within the callback functions of the handler modules. The framework is unaware of what method is being used (and so may return a body to a HEAD request for example). +- Only 'Content-Length' and 'Transfer-Encoding' request headers are understood by the framework, and only chunked (rather than compressed) encoding is handled automatically. Only 'Server', 'Date' and 'Connection' response headers are added by the framework, if not present in the output from the callback function. +- There is no control in the ordering of response HTTP headers, headers in the response on the wire by be returned in a different order to headers in the response returned by a callback function. +- TLS support is limited by that offered in the OTP deployment. diff --git a/rebar.config b/rebar.config index d3f77f6..a4a67e6 100644 --- a/rebar.config +++ b/rebar.config @@ -4,20 +4,51 @@ {erl_opts, [warnings_as_errors]}. +{erlfmt, [ + write, + {print_width, 80}, + {files, [ + "src/riak_api_web_acceptor.erl", + "src/riak_api_web_body.erl", + "src/riak_api_web_headers.erl", + "src/riak_api_web_security.erl", + "src/riak_api_web_socket.erl", + "src/riak_api_web.erl", + "src/riak_api_web_handler.erl" + "test/riak_api_web_get_random.erl", + "test/riak_api_web_ets_store.erl", + "test/riak_api_web_trigger.erl", + "rebar.config" + ]}, + {exclude_files, []} +]}. + +{project_plugins, [ + {erlfmt, {git, "https://github.com/OpenRiak/erlfmt.git", {branch, "main"}}} +]}. + {eunit_opts, [verbose]}. {deps, [ - {riak_pb, {git, "https://github.com/OpenRiak/riak_pb.git", {branch, "openriak-3.4"}}}, - {webmachine, {git, "https://github.com/OpenRiak/webmachine.git", {branch, "openriak-3.4"}}}, - {mochiweb, {git, "https://github.com/OpenRiak/mochiweb.git", {branch, "openriak-3.4"}}}, - {riak_core, {git, "https://github.com/OpenRiak/riak_core.git", {branch, "openriak-4.0"}}} - ]}. + {riak_pb, + {git, "https://github.com/OpenRiak/riak_pb.git", + {branch, "openriak-3.4"}}}, + {riak_core, + {git, "https://github.com/OpenRiak/riak_core.git", + {branch, "openriak-4.0"}}} +]}. {profiles, [ - {test, [{deps, [{meck, {git, "https://github.com/OpenRiak/meck.git", {branch, "openriak-3.4"}}}]}]}, + {test, [ + {deps, [ + {meck, + {git, "https://github.com/OpenRiak/meck.git", + {branch, "openriak-3.4"}}} + ]} + ]}, {gha, [{erl_opts, [{d, 'GITHUBEXCLUDE'}]}]} ]}. {dialyzer, [{plt_apps, all_deps}]}. -{xref_checks,[undefined_function_calls,undefined_functions,locals_not_used]}. +{xref_checks, [undefined_function_calls, undefined_functions, locals_not_used]}. diff --git a/src/riak_api.app.src b/src/riak_api.app.src index 5ce5af7..69700be 100644 --- a/src/riak_api.app.src +++ b/src/riak_api.app.src @@ -9,9 +9,7 @@ stdlib, ssl, riak_core, - riak_pb, - webmachine, - mochiweb + riak_pb ]}, {registered, [riak_api_sup, riak_api_pb_sup]}, diff --git a/src/riak_api_web.erl b/src/riak_api_web.erl index 8b609f6..fed15f6 100644 --- a/src/riak_api_web.erl +++ b/src/riak_api_web.erl @@ -24,76 +24,333 @@ %% of Riak. -module(riak_api_web). +-export( + [ + get_listeners/0, + binding_config/2, + add_routes/1, + add_routes/2, + get_route/4, + spec_name/3, + rfc1123_date/1, + rfc1123_date/2, + rfc1123_date_now/0, + cache_today/0 + ] +). --export([get_listeners/0, - binding_config/2]). +-type binding() :: {inet:ip_address(), inet:port_number()}. +-type route() :: {1..100, module()}. --include_lib("kernel/include/logger.hrl"). +%%%============================================================================ +%%% Routing +%%%============================================================================ + +-spec add_routes(list(route())) -> ok. +add_routes(Routes) -> + add_routes(default, Routes). + +-spec add_routes( + inet:port_number() | default, + list(route()) +) -> + ok. +add_routes(ServerName, Routes) -> + CurrentRoutes = persistent_term:get({?MODULE, ServerName}, []), + NewRoutes = lists:keysort(1, CurrentRoutes ++ Routes), + persistent_term:put({?MODULE, ServerName}, NewRoutes). + +-spec get_route( + inet:port_number(), + riak_api_web_acceptor:method(), + unicode:chardata(), + list(unicode:chardata()) +) -> + { + ok, + module(), + {pos_integer(), pos_integer(), non_neg_integer()}, + any() + } + | riak_api_web_acceptor:halt_response(). +get_route(Port, Method, Path, SplitPath) -> + CurrentRoutes = + persistent_term:get( + {?MODULE, Port}, + persistent_term:get({?MODULE, default}, []) + ), + select_route(CurrentRoutes, Method, Path, SplitPath). + +select_route([], _Method, _Path, _SP) -> + {halt, 404, [], <<>>, []}; +select_route([{_P, CallbackMod} | Rest], Method, Path, SplitPath) -> + case CallbackMod:match_route(Method, Path, SplitPath) of + nomatch -> + select_route(Rest, Method, Path, SplitPath); + {method_not_allowed, AllowedMethods} -> + AllowHdrVal = + iolist_to_binary( + lists:join( + <<", ">>, + lists:map(fun atom_to_binary/1, AllowedMethods) + ) + ), + {halt, 405, [{'Allow', AllowHdrVal}], <<>>, []}; + {ok, {MaxHdrCount, MaxHdrSize, MaxBodySize}, Context} when + MaxHdrCount > 0, MaxHdrSize > 0, MaxBodySize >= 0 + -> + {ok, CallbackMod, {MaxHdrCount, MaxHdrSize, MaxBodySize}, Context} + end. + +%%%============================================================================ +%%% Configure and Initiate Listeners +%%%============================================================================ get_listeners() -> get_listeners(http) ++ get_listeners(https). +-spec get_listeners(http | https) -> list({https | https, binding()}). get_listeners(Scheme) -> - Listeners = case app_helper:try_envs([{riak_api, Scheme}, - {riak_core, Scheme}], []) of - {riak_api, Scheme, List} when is_list(List) -> - List; - {riak_core, Scheme, List} when is_list(List) -> - ?LOG_WARNING("Setting riak_core/~s is deprecated, please use riak_api/~s", [Scheme, Scheme]), - List; - _ -> - [] - end, - lists:usort([ {Scheme, Binding} || Binding <- Listeners ]). + Listeners = application:get_env(riak_api, Scheme, []), + lists:usort([{Scheme, Binding} || Binding <- Listeners]). binding_config(Scheme, Binding) -> {Ip, Port} = Binding, Name = spec_name(Scheme, Ip, Port), Config = spec_from_binding(Scheme, Name, Binding), - {Name, - {webmachine_mochiweb, start, [Config]}, - permanent, 5000, worker, [mochiweb_socket_server]}. + { + Name, + {riak_api_web_socket, start_link, [Config]}, + permanent, + 5000, + worker, + [riak_api_web_socket] + }. spec_from_binding(http, Name, {Ip, Port}) -> - Options = - lists:flatten([{name, Name}, - {ip, Ip}, - {port, Port}, - {nodelay, true}], - common_config()), - add_recbuf(Options); + lists:flatten( + [ + {name, Name}, + {ip, Ip}, + {port, Port}, + {nodelay, true} + ], + common_config() + ); spec_from_binding(https, Name, {Ip, Port}) -> - Options = - lists:flatten([{name, Name}, - {ip, Ip}, - {port, Port}, - {ssl, true}, - {ssl_opts, riak_api_ssl:options()}, - {nodelay, true}], - common_config()), - add_recbuf(Options). - -add_recbuf(Options) -> - case application:get_env(webmachine, recbuf) of - {ok, RecBuf} -> - [{recbuf, RecBuf}|Options]; - _ -> - Options - end. + lists:flatten( + [ + {name, Name}, + {ip, Ip}, + {port, Port}, + {ssl, true}, + {ssl_opts, riak_api_ssl:options()}, + {nodelay, true} + ], + common_config() + ). spec_name(Scheme, Ip, Port) -> - FormattedIP = if is_tuple(Ip); tuple_size(Ip) == 4 -> - inet_parse:ntoa(Ip); - is_tuple(Ip); tuple_size(Ip) == 8 -> - [$[, inet_parse:ntoa(Ip), $]]; - true -> Ip - end, - lists:flatten(io_lib:format("~s://~s:~p", [Scheme, FormattedIP, Port])). + FormattedIP = + if + is_tuple(Ip); tuple_size(Ip) == 4 -> + inet_parse:ntoa(Ip); + is_tuple(Ip); tuple_size(Ip) == 8 -> + [$[, inet_parse:ntoa(Ip), $]]; + true -> + Ip + end, + iolist_to_binary(io_lib:format("~s://~s:~p", [Scheme, FormattedIP, Port])). common_config() -> - [{log_dir, app_helper:get_env(riak_api, http_logdir, - app_helper:get_env(riak_core, platform_log_dir, "log"))}, - {backlog, 128}, - {dispatch, [{[], riak_api_wm_urlmap, []} - ]}]. + [ + {log_dir, + app_helper:get_env( + riak_api, + http_logdir, + app_helper:get_env(riak_core, platform_log_dir, "log") + )}, + {backlog, 128} + ]. + +%%%============================================================================ +%%% RFC1123 Clock Management +%%%============================================================================ + +-spec cache_today() -> ok. +cache_today() -> + {Date, Time} = calendar:now_to_universal_time(os:timestamp()), + case persistent_term:get({?MODULE, cache_today}, undefined) of + {Date, _DateBin} -> + ok; + _ -> + <> = rfc1123_date(Date, Time), + persistent_term:put({?MODULE, cache_today}, {Date, DateBin}) + end. + +-spec rfc1123_date_now() -> binary(). +rfc1123_date_now() -> + {Date, Time} = calendar:now_to_universal_time(os:timestamp()), + case persistent_term:get({?MODULE, cache_today}, undefined) of + {CachedDate, DateBin} when CachedDate == Date -> + rfc1123_date(DateBin, Time); + _ -> + spawn(fun cache_today/0), + rfc1123_date(Date, Time) + end. + +-spec rfc1123_date(erlang:timestamp()) -> binary(). +rfc1123_date(TS) -> + {Date, Time} = calendar:now_to_universal_time(TS), + rfc1123_date(Date, Time). + +rfc1123_date({YYYY, MM, DD}, {Hr, Mn, Sc}) -> + DateBin = + << + (day_bin(calendar:day_of_the_week({YYYY, MM, DD})))/binary, + (i2_bin(DD))/binary, + (mon_bin(MM))/binary, + (integer_to_binary(YYYY))/binary, + <<" ">>/binary + >>, + rfc1123_date(DateBin, {Hr, Mn, Sc}); +rfc1123_date(DateBin, {Hr, Mn, Sc}) when is_binary(DateBin) -> + << + DateBin/binary, + (i2_bin(Hr))/binary, + $:, + (i2_bin(Mn))/binary, + $:, + (i2_bin(Sc))/binary, + <<" GMT">>/binary + >>. + +i2_bin(I) when I < 10 -> + <<$0, (integer_to_binary(I))/binary>>; +i2_bin(I) -> + integer_to_binary(I). + +day_bin(1) -> + <<"Mon, ">>; +day_bin(2) -> + <<"Tue, ">>; +day_bin(3) -> + <<"Wed, ">>; +day_bin(4) -> + <<"Thu, ">>; +day_bin(5) -> + <<"Fri, ">>; +day_bin(6) -> + <<"Sat, ">>; +day_bin(7) -> + <<"Sun, ">>. + +mon_bin(1) -> + <<" Jan ">>; +mon_bin(2) -> + <<" Feb ">>; +mon_bin(3) -> + <<" Mar ">>; +mon_bin(4) -> + <<" Apr ">>; +mon_bin(5) -> + <<" May ">>; +mon_bin(6) -> + <<" Jun ">>; +mon_bin(7) -> + <<" Jul ">>; +mon_bin(8) -> + <<" Aug ">>; +mon_bin(9) -> + <<" Sep ">>; +mon_bin(10) -> + <<" Oct ">>; +mon_bin(11) -> + <<" Nov ">>; +mon_bin(12) -> + <<" Dec ">>. + +%%%============================================================================ +%%% Eunit tests +%%%============================================================================ + +-ifdef(TEST). +-include_lib("eunit/include/eunit.hrl"). + +wm_rfc1123_date(TS) -> + {{YYYY, MM, DD}, {Hour, Min, Sec}} = calendar:now_to_universal_time(TS), + DayNumber = calendar:day_of_the_week({YYYY, MM, DD}), + iolist_to_binary( + lists:flatten( + io_lib:format( + "~s, ~2.2.0w ~3.s ~4.4.0w ~2.2.0w:~2.2.0w:~2.2.0w GMT", + [ + httpd_util:day(DayNumber), + DD, + httpd_util:month(MM), + YYYY, + Hour, + Min, + Sec + ] + ) + ) + ). + +date_speed_test() -> + {_, S, MicroS} = os:timestamp(), + Dates = lists:map(fun(I) -> {775 + I, S, MicroS} end, lists:seq(1, 1000)), + {TC1, DL1} = + timer:tc( + fun() -> lists:map(fun(TS) -> rfc1123_date(TS) end, Dates) end + ), + {TC2, DL2} = + timer:tc( + fun() -> lists:map(fun(TS) -> wm_rfc1123_date(TS) end, Dates) end + ), + io:format(user, "Timing for ours ~w vs wm ~w~n", [TC1, TC2]), + ?assert(DL1 == DL2), + + PreCalcDates = lists:map(fun(<>) -> D end, DL1), + NewInputs = lists:zip(PreCalcDates, Dates), + {TC3, DL3} = + timer:tc( + fun() -> + lists:map( + fun({CachedDate, TS}) -> + rfc1123_date( + CachedDate, + element(2, calendar:now_to_universal_time(TS)) + ) + end, + NewInputs + ) + end + ), + io:format(user, "With pre-cached dates ~w~n", [TC3]), + ?assert(DL1 == DL3). + +check_date_is_autocached_test() -> + persistent_term:erase({?MODULE, cache_today}), + rfc1123_date_now(), + true = + lists:foldl( + fun(I, Acc) -> + case Acc of + true -> + true; + false -> + timer:sleep(I), + not_cached =/= + persistent_term:get( + {?MODULE, cache_today}, not_cached + ) + end + end, + false, + lists:seq(1, 100) + ), + rfc1123_date_now(). + +-endif. diff --git a/src/riak_api_web_acceptor.erl b/src/riak_api_web_acceptor.erl new file mode 100644 index 0000000..4f4cd79 --- /dev/null +++ b/src/riak_api_web_acceptor.erl @@ -0,0 +1,984 @@ +%% ------------------------------------------------------------------- +%% +%% Copyright (c) 2026 Martin Sumner +%% +%% This file is provided to you under the Apache License, +%% Version 2.0 (the "License"); you may not use this file +%% except in compliance with the License. You may obtain +%% a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% +%% ------------------------------------------------------------------- +%% @doc Handler for a HTTP connection, where the connection will be associated +%% With a module implementing the riak_api_web_rest behaviour + +-module(riak_api_web_acceptor). + +-if(?OTP_RELEASE == 26). +-feature(maybe_expr, enable). +-endif. + +-export([start_link/2, init/3]). + +-export([extend_buffer/4, compile_detectors/0]). + +-include_lib("kernel/include/logger.hrl"). + +-define(ACCEPT_TIMEOUT, 10000). +-define(RECEIVE_TIMEOUT, 60000). +-define(CONTINUE_RESPONSE, <<"HTTP 1.1 100 Continue">>). + +-type response_code() :: + 200..204 + | 206 + | 300..304 + | 400 + | 401..406 + | 408..418 + | 421..429 + | 431 + | 451 + | 500..508. + +-type method() :: + 'OPTIONS' | 'GET' | 'HEAD' | 'POST' | 'PUT' | 'DELETE' | 'TRACE'. + +-type http_version() :: + {1, 0} | {1, 1}. + +-type halt_response() :: + { + halt, + response_code(), + riak_api_web_headers:header_list(), + binary(), + list(term()) + }. + +-type halt_result() :: + { + halt, + response_code(), + riak_api_web_headers:header_list(), + binary(), + riak_api_web_socket:socket() + }. +-type good_result() :: + { + finish, + boolean(), + response_code(), + riak_api_web_headers:headers(), + {stream, stream_fun()} | binary(), + {module(), any()}, + riak_api_web_socket:socket(), + binary(), + pos_integer() + }. + +-type stream_fun() :: fun(() -> {ok, binary()} | done). +-type send_fun() :: fun((binary()) -> ok | {error, any()}). + +-export_type([halt_response/0, method/0, response_code/0]). + +%%%============================================================================ +%%% API +%%%============================================================================ + +-spec start_link(riak_api_web_socket:socket(), inet:port_number()) -> pid(). +start_link(Socket, Port) -> + spawn_link(?MODULE, init, [self(), Socket, Port]). + +-spec init(pid(), riak_api_web_socket:socket(), inet:port_number()) -> ok. +init(Server, Listener, Port) -> + case riak_api_web_socket:accept(Listener, ?ACCEPT_TIMEOUT) of + {ok, Socket} -> + ok = riak_api_web_socket:acceptor_accepted(Server), + {ok, PeerIP, Cert} = riak_api_web_socket:get_peer(Socket), + loop(Socket, <<>>, PeerIP, Cert, Port); + {error, timeout} -> + init(Server, Listener, Port); + {error, {tls_alert, Alert}} -> + ?LOG_WARNING("TLS Alert received ~0p", [Alert]), + init(Server, Listener, Port); + {error, closed} -> + ok; + {error, Other} -> + exit({error, Other}) + end. + +%%%============================================================================ +%%% Primary Loop +%%%============================================================================ + +-spec loop( + riak_api_web_socket:socket(), + binary(), + inet:ip_address(), + public_key:cert() | undefined, + inet:port_number() +) -> + ok. +loop(Socket, InitBuffer, PeerIP, Cert, Port) -> + %% In the keepalive loop, the send buffer is assumed to be empty + %% An so pipelining of requests (in parallel) is explicitly not supported + case handle_request(Socket, InitBuffer, PeerIP, Cert, Port) of + {KeepAlive, Buffer} when KeepAlive == true -> + loop(Socket, Buffer, PeerIP, Cert, Port); + _Close -> + riak_api_web_socket:close(Socket), + ok + end. + +-spec handle_request( + riak_api_web_socket:socket(), + binary(), + inet:ip_address(), + public_key:cert() | undefined, + inet:port_number() +) -> + {boolean(), binary()} | close. +handle_request(Socket, InitBuffer, PeerIP, Cert, Port) -> + StartTime = os:system_time(microsecond), + reset_version(), + RequestResult = + maybe + {ok, {Method, RawPath, Version, HdrBuffer}} ?= + get_request_line(Socket, InitBuffer), + set_version(Version), + {ok, {Path, SplitPath, QueryParams}} ?= split_path(RawPath), + { + ok, + CallbackMod, + {MaxHdrCount, MaxHdrSize, MaxBodySize}, + InitModCtx + } ?= + riak_api_web:get_route(Port, Method, Path, SplitPath), + {ok, ReqHeaders, BdyBuffer} ?= + get_request_headers( + HdrBuffer, + Socket, + {MaxHdrCount, MaxHdrSize} + ), + {ok, ModCtx1} ?= + CallbackMod:check_permissions( + ReqHeaders, + element(1, Socket), + PeerIP, + Cert, + InitModCtx + ), + {ok, ModCtx2} ?= + CallbackMod:parse_query_params(QueryParams, ModCtx1), + {ok, ModCtx3} ?= + CallbackMod:parse_request_headers(ReqHeaders, ModCtx2), + {ok, {CLorChunk, UseGzip}} ?= expect_body(ReqHeaders), + {ok, InitReqBody} ?= + riak_api_web_body:initiate_body( + extend_buffer_fun(Socket), + BdyBuffer, + CLorChunk, + UseGzip, + MaxBodySize + ), + ok ?= send_continue(Socket, ReqHeaders), + {ok, NextReqBody, CallbackReqBody} ?= + case MaxBodySize of + N when N == 0 -> + case riak_api_web_body:confirm_empty(InitReqBody) of + {ok, RemBody} -> + {ok, RemBody, none}; + {error, content_too_large} -> + { + halt, + 413, + [{'Content-Type', <<"text/plain">>}], + <<>>, + [] + } + end; + _N -> + {ok, none, InitReqBody} + end, + {ok, {Code, RspHeaders, RspBody, KeepAliveOK, RetBody}, ModCtx4} ?= + CallbackMod:process_request( + CallbackReqBody, + ModCtx3 + ), + {ok, BufferNext} ?= + case {NextReqBody, RetBody} of + {NextReqBody, none} when NextReqBody =/= none -> + {ok, riak_api_web_body:get_buffer(NextReqBody)}; + {none, RetBody} when RetBody =/= none -> + {ok, riak_api_web_body:get_buffer(RetBody)}; + _ -> + WarnText = + "Incorrect handling of request body buffer in" + " callback module ~w", + ?LOG_WARNING(WarnText, [CallbackMod]), + { + halt, + 500, + [{'Content-Type', <<"text/plain">>}], + <<"Error handling request body">>, + [] + } + end, + Keepalive = + request_prefers_keepalive(Version, ReqHeaders) andalso + KeepAliveOK, + MergedRspHeaders = + riak_api_web_headers:enter_from_list( + RspHeaders, + default_response_headers(Keepalive) + ), + { + finish, + Keepalive, + Code, + MergedRspHeaders, + RspBody, + {CallbackMod, ModCtx4}, + Socket, + BufferNext, + StartTime + } + else + {halt, HaltRspCode, HaltRspHeaders, HaltRspText, HaltRspSubs} -> + HaltRspBody = generate_error_body(HaltRspText, HaltRspSubs), + {halt, HaltRspCode, HaltRspHeaders, HaltRspBody, Socket} + end, + handle_response(RequestResult). + +%%%============================================================================ +%%% Manage Version on Process dictionary +%%%============================================================================ + +-define(VERSION_KEY, {?MODULE, http_version}). + +-spec set_version(http_version()) -> ok. +set_version(Version) when Version == {1, 0}; Version == {1, 1} -> + put(?VERSION_KEY, Version). + +-spec get_version() -> http_version(). +get_version() -> + case get(?VERSION_KEY) of + undefined -> + {1, 0}; + Tag -> + Tag + end. + +-spec reset_version() -> ok. +reset_version() -> + put(?VERSION_KEY, undefined). + +%%%============================================================================ +%%% Internal request handling functions +%%%============================================================================ + +-spec bad_request(binary(), list()) -> halt_response(). +bad_request(Error, Subs) -> + {halt, 400, [], Error, Subs}. + +%% @doc %% @doc Call this function when initialising API +-spec compile_detectors() -> ok. +compile_detectors() -> + CP = binary:compile_pattern([<<"%">>, <<".">>]), + persistent_term:put({?MODULE, compile_patterns}, CP). + +-spec normalise_path(binary()) -> uri_string:uri_map() | uri_string:error(). +normalise_path(URI) -> + CP = persistent_term:get({?MODULE, compile_patterns}), + case binary:match(URI, CP) of + nomatch -> + % There is no percent-encoded content, or no path reversing, and + % so it is safe to parse rather than normalise + uri_string:parse(URI); + _ -> + case uri_string:normalize(URI, [return_map]) of + URIMap when is_map(URIMap) -> + uri_string:percent_decode(URIMap); + {error, Type, Detail} -> + {error, Type, Detail} + end + end. + +-spec split_path( + binary() +) -> + { + ok, + { + unicode:chardata(), + list(unicode:chardata()), + [{unicode:chardata(), unicode:chardata() | true}] + } + } + | halt_response(). +split_path(URIPath) -> + case normalise_path(URIPath) of + URIMap when is_map(URIMap) -> + PathN = maps:get(path, URIMap, <<>>), + QueryParamsN = maps:get(query, URIMap, <<>>), + SplitPath = binary:split(PathN, <<"/">>, [global, trim_all]), + case uri_string:dissect_query(QueryParamsN) of + QueryParams when is_list(QueryParams) -> + {ok, {PathN, SplitPath, QueryParams}}; + {error, QTerm, QReason} -> + bad_request( + <<"Query parameters not parsed ~w - ~0p">>, + [QTerm, QReason] + ) + end; + {error, NTerm, NReason} -> + bad_request( + <<"Path cannot be normalized ~w - ~0p">>, + [NTerm, NReason] + ) + end. + +-spec extend_buffer( + riak_api_web_socket:socket(), + binary(), + non_neg_integer() | line, + pos_integer() | undefined +) -> + binary(). +extend_buffer(Socket, Buffer, Needed, Timeout) when is_integer(Needed) -> + case riak_api_web_socket:recv(Socket, Needed, get_timeout(Timeout)) of + {ok, Data} when is_binary(Data) -> + <>; + {error, closed} -> + riak_api_web_socket:close(Socket), + exit(normal); + {error, Reason} -> + log_unexpected_recv(Socket, Reason), + exit(normal) + end; +extend_buffer(Socket, Buffer, line, Timeout) -> + case riak_api_web_socket:recv_line(Socket, get_timeout(Timeout)) of + {ok, Data} when is_binary(Data) -> + <>; + {error, Reason} -> + log_unexpected_recv(Socket, Reason), + exit(normal) + end. + +-spec log_unexpected_recv( + riak_api_web_socket:socket(), + term() +) -> + ok | {error, term()}. +log_unexpected_recv(Socket, Reason) -> + LogText = "Unexpected failure to read data from client ~w for socket ~0p", + ?LOG_WARNING(LogText, [Reason, Socket]), + riak_api_web_socket:close(Socket). + +-spec extend_buffer_fun( + riak_api_web_socket:socket() +) -> + riak_api_web_body:buffer_fun(). +extend_buffer_fun(Socket) -> + fun(Buffer, Needed, Timeout) -> + extend_buffer(Socket, Buffer, Needed, Timeout) + end. + +-spec expect_body( + riak_api_web_headers:headers() +) -> + {ok, {non_neg_integer() | chunked, boolean()}} | halt_response(). +expect_body(Headers) -> + ContentLengthH = + riak_api_web_headers:get_unique_value('Content-Length', Headers), + Encoding = + case riak_api_web_headers:get_value('Transfer-Encoding', Headers) of + MultipleValues when is_list(MultipleValues) -> + lists:sort(MultipleValues); + SingleValue -> + SingleValue + end, + case {ContentLengthH, Encoding} of + {ValBin, Encoding} when is_binary(ValBin) -> + try + ContentLength = binary_to_integer(ValBin), + case {ContentLength, Encoding} of + {CL, undefined} when CL >= 0 -> + {ok, {CL, false}}; + {CL, <<"gzip">>} -> + {ok, {CL, true}}; + {_CL, UnsupportedEncoding} -> + bad_request( + << + "Content length provided with unsupported " + "transfer encoding ~0p" + >>, + [UnsupportedEncoding] + ) + end + catch + _:_ -> + bad_request(<<"Non-integer content length ~0p">>, [ValBin]) + end; + {undefined, <<"chunked">>} -> + {ok, {chunked, false}}; + {undefined, [<<"chunked">>, <<"gzip">>]} -> + {ok, {chunked, true}}; + {undefined, undefined} -> + % Assume no content - and set content-length to 0 + {ok, {0, false}}; + {undefined, UnexpectedEncoding} -> + UEWarn = <<"Received encoding ~0p without content length">>, + bad_request(UEWarn, [UnexpectedEncoding]); + {{error, multiple_values}, _} -> + bad_request(<<"Content has non-unique length">>, []) + end. + +-spec generate_error_body(binary(), list(any())) -> binary(). +generate_error_body(ErrorText, Subs) -> + iolist_to_binary( + io_lib:format(ErrorText, Subs) + ). + +-spec get_request_line( + riak_api_web_socket:socket(), + binary() +) -> + {ok, {method(), binary(), http_version(), binary()}} + | halt_response(). +get_request_line(Socket, Buffer) -> + case erlang:decode_packet(http_bin, Buffer, []) of + {more, _} -> + get_request_line( + Socket, + extend_buffer(Socket, Buffer, 0, undefined) + ); + {ok, {http_request, Method, {abs_path, Path}, Version}, Rest} when + is_binary(Path), is_atom(Method) + -> + case Version of + SV when SV == {1, 0}; SV == {1, 1} -> + {ok, {Method, Path, SV, Rest}}; + _USV -> + USVError = <<"Only HTTP 1.0 and 1.1 supported">>, + {halt, 505, [], USVError, []} + end; + {ok, {http_request, Method, _, _}, _Rest} when is_atom(Method) -> + bad_request(<<"Absolute path required not full or relative">>, []); + {ok, {http_error, Error}, _} -> + bad_request(<<"HTTP error on inbound request ~0p">>, [Error]); + {ok, _Unexpected, _} -> + bad_request( + <<"Unexpected request line ~0p">>, + [Buffer] + ) + end. + +-spec get_request_headers( + binary(), + riak_api_web_socket:socket(), + {pos_integer(), pos_integer()} +) -> + {ok, riak_api_web_headers:headers(), binary()} + | riak_api_web_acceptor:halt_response(). +get_request_headers(Buffer, Socket, {MaxCount, MaxSize}) -> + riak_api_web_headers:parse_request_block( + Buffer, + fun(Prev) when is_binary(Prev) -> + extend_buffer(Socket, Prev, 0, ?RECEIVE_TIMEOUT) + end, + {MaxCount, MaxSize} + ). + +-spec request_prefers_keepalive( + http_version(), + riak_api_web_headers:headers() +) -> + boolean(). +request_prefers_keepalive({1, 0}, ReqHeaders) -> + %% https://www.rfc-editor.org/rfc/rfc7230#section-6.1 + %% Note that connection options are case insensitive + case riak_api_web_headers:get_value('Connection', ReqHeaders) of + ConnectionOption when is_binary(ConnectionOption) -> + case string:casefold(ConnectionOption) of + <<"keep-alive">> -> + true; + _ -> + false + end; + _ -> + false + end; +request_prefers_keepalive({1, 1}, ReqHeaders) -> + case riak_api_web_headers:get_value('Connection', ReqHeaders) of + ConnectionOption when is_binary(ConnectionOption) -> + case string:casefold(ConnectionOption) of + <<"close">> -> + false; + _ -> + true + end; + _ -> + true + end. + +-spec get_timeout( + undefined | infinity | non_neg_integer() +) -> + non_neg_integer() | infinity. +get_timeout(undefined) -> + ?RECEIVE_TIMEOUT; +get_timeout(infinity) -> + infinity; +get_timeout(Timeout) when is_integer(Timeout), Timeout >= 0 -> + Timeout. + +%%%============================================================================ +%%% Internal response handling functions +%%%============================================================================ + +-spec handle_response( + good_result() | halt_result() +) -> + {boolean(), binary()} | close. +handle_response( + { + finish, + Keepalive, + RspCode, + RspHeaders, + {stream, StreamFun}, + {CallbackMod, Context}, + Socket, + BufferIn, + StartTime + } +) -> + RequestCompleteTime = os:system_time(microsecond), + stream_response( + RspCode, + RspHeaders, + StreamFun, + fun(B) -> riak_api_web_socket:send(Socket, B) end + ), + ResponseCompleteTime = os:system_time(microsecond), + ok = + CallbackMod:record_request( + {StartTime, RequestCompleteTime, ResponseCompleteTime}, + stream_complete, + Context + ), + {Keepalive, BufferIn}; +handle_response( + { + finish, + Keepalive, + RspCode, + RspHeaders, + RspBody, + {CallbackMod, Context}, + Socket, + BufferIn, + StartTime + } +) when is_binary(RspBody) -> + RequestCompleteTime = os:system_time(microsecond), + send_response(RspCode, RspHeaders, RspBody, Socket), + ResponseCompleteTime = os:system_time(microsecond), + ok = + CallbackMod:record_request( + {StartTime, RequestCompleteTime, ResponseCompleteTime}, + send_complete, + Context + ), + {Keepalive, BufferIn}; +handle_response({halt, RspCode, RspHeaders, RspBody, Socket}) -> + MergedRspHeaders = + riak_api_web_headers:enter_from_list( + RspHeaders, + default_response_headers(false) + ), + send_response(RspCode, MergedRspHeaders, RspBody, Socket), + close. + +-spec send_continue( + riak_api_web_socket:socket(), + riak_api_web_headers:headers() +) -> + ok | {error, term()}. +send_continue(Socket, ReqHeaders) -> + case riak_api_web_headers:lookup(<<"expect">>, ReqHeaders, true) of + {_Key, [<<"100-continue">>]} -> + riak_api_web_socket:send(Socket, ?CONTINUE_RESPONSE); + _Other -> + ok + end. + +-spec stream_response( + response_code(), + riak_api_web_headers:headers(), + stream_fun(), + send_fun() +) -> + ok. +stream_response(RspCode, RspHeaders, StreamFun, SendFun) -> + RspLine = get_response_line(get_version(), RspCode), + FinalHeaders = + riak_api_web_headers:enter( + 'Transfer-Encoding', + <<"chunked">>, + RspHeaders + ), + Metadata = riak_api_web_headers:output_response_block(FinalHeaders), + ok = + SendFun( + << + RspLine/binary, + Metadata/binary, + <<"\r\n">>/binary + >> + ), + stream_response(StreamFun, SendFun). + +stream_response(StreamFun, SendFun) -> + case StreamFun() of + {<<>>, NextFun} -> + stream_response(NextFun, SendFun); + done -> + SendFun(<<"0\r\n\r\n">>); + {Bin, NextFun} when is_binary(Bin) -> + BS = integer_to_binary(byte_size(Bin), 16), + ok = + SendFun( + << + BS/binary, + <<"\r\n">>/binary, + Bin/binary, + <<"\r\n">>/binary + >> + ), + stream_response(NextFun, SendFun) + end. + +-spec send_response( + response_code(), + riak_api_web_headers:headers(), + binary(), + riak_api_web_socket:socket() +) -> + ok | {error, any()}. +send_response(RspCode, RspHeaders, RspBody, Socket) -> + riak_api_web_socket:send( + Socket, + generate_binary_response(RspCode, RspHeaders, RspBody) + ). + +-spec generate_binary_response( + response_code(), + riak_api_web_headers:headers(), + binary() +) -> + binary(). +generate_binary_response(RspCode, RspHeaders, RspBody) -> + RspLine = get_response_line(get_version(), RspCode), + FinalHeaders = + riak_api_web_headers:enter( + 'Content-Length', + integer_to_binary(byte_size(RspBody)), + RspHeaders + ), + Metadata = riak_api_web_headers:output_response_block(FinalHeaders), + << + RspLine/binary, + Metadata/binary, + <<"\r\n">>/binary, + RspBody/binary + >>. + +%% @doc +%% For performance reasons pre-create the whole line for the most common +%% scenarios +-spec get_response_line(http_version(), response_code()) -> binary(). +get_response_line({1, 0}, 200) -> + <<"HTTP/1.0 200 OK\r\n">>; +get_response_line({1, 0}, 201) -> + <<"HTTP/1.0 201 Created\r\n">>; +get_response_line({1, 0}, 204) -> + <<"HTTP/1.0 204 No Content\r\n">>; +get_response_line({1, 1}, 200) -> + <<"HTTP/1.1 200 OK\r\n">>; +get_response_line({1, 1}, 201) -> + <<"HTTP/1.1 201 Created\r\n">>; +get_response_line({1, 1}, 204) -> + <<"HTTP/1.1 204 No Content\r\n">>; +get_response_line({1, 0}, Code) -> + iolist_to_binary( + [ + <<"HTTP/1.0 ">>, + integer_to_binary(Code), + <<" ">>, + reason_phrase(Code), + <<"\r\n">> + ] + ); +get_response_line({1, 1}, Code) -> + iolist_to_binary( + [ + <<"HTTP/1.1 ">>, + integer_to_binary(Code), + <<" ">>, + reason_phrase(Code), + <<"\r\n">> + ] + ). + +-spec default_response_headers( + boolean() +) -> + riak_api_web_headers:headers(). +default_response_headers(KeepAlive) -> + DateHeader = {'Date', riak_api_web:rfc1123_date_now()}, + ServerHeader = {'Server', <<"RiakAPI/4.0 SilverMachine">>}, + ConnectionHeader = + case KeepAlive of + true -> + {'Connection', <<"keep-alive">>}; + false -> + {'Connection', <<"close">>} + end, + riak_api_web_headers:make_rsp_header( + [ServerHeader, DateHeader, ConnectionHeader] + ). + +%% @doc +%% The http_util:reason_phrase/1 returns Object Not Found not Not Found +%% these are taken direct from RFC 2616. Likewise "Request Entity Too Large" +%% rather than the more common "Content Too Large" +-spec reason_phrase(response_code()) -> binary(). +reason_phrase(404) -> <<"Not Found">>; +reason_phrase(413) -> <<"Content Too Large">>; +reason_phrase(431) -> <<"Request Header Fields Too Large">>; +reason_phrase(N) -> httpd_util:reason_phrase(N). + +%%%============================================================================ +%%% Eunit tests +%%%============================================================================ + +-ifdef(TEST). +-include_lib("eunit/include/eunit.hrl"). +-include_lib("stdlib/include/assert.hrl"). + +request_line_decode_test() -> + ?assertMatch( + {halt, 400, [], <<"Absolute path required not full or relative">>, []}, + get_request_line( + test_socket, + <<"GET no-leading-slash/relative HTTP/1.1\r\n">> + ) + ), + ?assertMatch( + {halt, 400, [], <<"Absolute path required not full or relative">>, []}, + get_request_line( + test_socket, + <<"GET http://localhost:8000/full-path HTTP/1.1\r\n">> + ) + ), + ?assertMatch( + {halt, 400, [], <<"Absolute path required not full or relative">>, []}, + get_request_line( + test_socket, + <<"GET @ref HTTP/1.1\r\n">> + ) + ), + ?assertMatch( + { + halt, + 400, + [], + <<"HTTP error on inbound request ~0p">>, + [<<"GET @ref HTP/1.1\r\n">>] + }, + get_request_line( + test_socket, + <<"GET @ref HTP/1.1\r\n">> + ) + ), + ?assertMatch( + {halt, 505, [], <<"Only HTTP 1.0 and 1.1 supported">>, []}, + get_request_line(test_socket, <<"GET /stats HTTP/2.0\r\n">>) + ), + % If the method is not supported at all, then give general error - as it is + % not possible to know what methods are allowed on the URL - this can only + % be determined when matching routes + ?assertMatch( + { + halt, + 400, + [], + <<"Unexpected request line ~0p">>, + [<<"PATCH /stats HTTP/1.0\r\n">>] + }, + get_request_line(test_socket, <<"PATCH /stats HTTP/1.0\r\n">>) + ). + +simple_response_test() -> + ok = riak_api_web:cache_today(), + set_version({1, 1}), + FullResponse = + generate_binary_response( + 200, + default_response_headers(false), + <<"OutputOK">> + ), + Date = list_to_binary(httpd_util:rfc1123_date()), + ExpectedResponse = + << + <<"HTTP/1.1 200 OK\r\n">>/binary, + <<"Connection: close\r\n">>/binary, + <<"Date: ">>/binary, + Date/binary, + <<"\r\n">>/binary, + <<"Server: RiakAPI/4.0 SilverMachine\r\n">>/binary, + <<"Content-Length: 8\r\n">>/binary, + <<"\r\n">>/binary, + <<"OutputOK">>/binary + >>, + ?assertMatch(ExpectedResponse, FullResponse). + +simple_stream_test() -> + ok = riak_api_web:cache_today(), + SendFun = + fun(Bin) when is_binary(Bin) -> + case get({?MODULE, ?TEST, send_buffer}) of + AccBin when is_binary(AccBin) -> + put( + {?MODULE, ?TEST, send_buffer}, + <> + ); + undefined -> + put({?MODULE, ?TEST, send_buffer}, Bin) + end, + ok + end, + put({?MODULE, ?TEST, send_buffer}, undefined), + Me = self(), + spawn( + fun() -> + Me ! <<"Wiki">>, + Me ! <<"Pedia ">>, + Me ! <<"in chunks!">>, + Me ! done + end + ), + Date = list_to_binary(httpd_util:rfc1123_date()), + stream_response( + 200, + default_response_headers(true), + stream_fun(), + SendFun + ), + Response = get({?MODULE, ?TEST, send_buffer}), + ExpectedResponse = + << + <<"HTTP/1.1 200 OK\r\n">>/binary, + <<"Connection: keep-alive\r\n">>/binary, + <<"Date: ">>/binary, + Date/binary, + <<"\r\n">>/binary, + <<"Transfer-Encoding: chunked\r\n">>/binary, + <<"Server: RiakAPI/4.0 SilverMachine\r\n">>/binary, + <<"\r\n">>/binary, + << + "4\r\nWiki\r\n6\r\nPedia " + "\r\nA\r\nin chunks!\r\n0\r\n\r\n" + >>/binary + >>, + ?assertMatch(ExpectedResponse, Response). + +stream_fun() -> + fun() -> + receive + Bin when is_binary(Bin) -> + {Bin, stream_fun()}; + done -> + done + end + end. + +normalise_path_test() -> + compile_detectors(), + URI1 = <<"types/BT/buckets/B/keys/K?return_terms">>, + URI2 = <<"types/BT/buckets/../buckets/B/key%73/K?return_term%73">>, + {ok, Output1} = split_path(URI1), + {ok, Output2} = split_path(URI2), + ?assertMatch(Output1, Output2), + URI3 = <<"types/T/buckets/Swedes/keys/%C3%85berg?return_terms">>, + {ok, {_, SP, _}} = split_path(URI3), + [<<"types">>, <<"T">>, <<"buckets">>, <<"Swedes">>, <<"keys">>, Name] = SP, + ?assertMatch(<<"Ã…berg"/utf8>>, Name). + +expect_test() -> + FixedLength = + riak_api_web_headers:make( + [ + {'Content-Length', <<"1024">>} + ] + ), + ?assertMatch({ok, {1024, false}}, expect_body(FixedLength)), + Empty = riak_api_web_headers:make([]), + % e.g. just curl GET from command line - no encoding or content-length + ?assertMatch({ok, {0, false}}, expect_body(Empty)), + FixedLengthGZ = + riak_api_web_headers:make( + [ + {'Content-Length', <<"1024">>}, + {'Transfer-Encoding', <<"gzip">>} + ] + ), + ?assertMatch({ok, {1024, true}}, expect_body(FixedLengthGZ)), + UnsupportedCompress = + riak_api_web_headers:make( + [ + {'Content-Length', <<"1024">>}, + {'Transfer-Encoding', <<"deflate">>} + ] + ), + {halt, 400, [], Error1, _} = expect_body(UnsupportedCompress), + ?assertNotMatch( + nomatch, + string:find(Error1, <<"unsupported transfer encoding">>) + ), + NoLength = + riak_api_web_headers:make( + [ + {'Transfer-Encoding', <<"gzip">>} + ] + ), + {halt, 400, [], Error2, _} = expect_body(NoLength), + ?assertNotMatch( + nomatch, + string:find(Error2, <<"without content length">>) + ), + ContentSmuggle = + riak_api_web_headers:make( + [ + {'Content-Length', <<"1024">>}, + {'Transfer-Encoding', <<"gzip">>}, + {'Content-Length', <<"262144">>} + ] + ), + {halt, 400, [], Error3, _} = expect_body(ContentSmuggle), + ?assertNotMatch( + nomatch, + string:find(Error3, <<"non-unique length">>) + ). + +-endif. diff --git a/src/riak_api_web_body.erl b/src/riak_api_web_body.erl new file mode 100644 index 0000000..6406f11 --- /dev/null +++ b/src/riak_api_web_body.erl @@ -0,0 +1,560 @@ +%% ------------------------------------------------------------------- +%% +%% Copyright (c) 2007-2009 Basho Technologies +%% Copyright (c) 2026 Martin Sumner +%% +%% This file is provided to you under the Apache License, +%% Version 2.0 (the "License"); you may not use this file +%% except in compliance with the License. You may obtain +%% a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% +%% ------------------------------------------------------------------- +%% @doc Handling functions for receiving and sending object bodies over HTTP +%% +%% Handling of chunked requests, and some other parts inspired by webmachine. +%% +%% It is possible to accept the inbound request in slices. If there is a fixed +%% content length this will read off the receiver buffer a slice of data at a +%% time. If the transfer encoding is chunked it will buffer the greater of the +%% slice length and the chunk length - i.e. sending chunks > than the slice +%% length will require more memory. + +-module(riak_api_web_body). + +-export( + [ + get_buffer/1, + initiate_body/5, + get_body/3, + confirm_empty/1, + is_gzip/1 + ] +). + +-record(req_body, { + buffer :: binary(), + content_length :: non_neg_integer() | chunked, + gzip :: boolean(), + acc_size = 0 :: non_neg_integer(), + max_size :: pos_integer(), + buffer_fun :: buffer_fun(), + chunk_buff = <<>> :: binary(), + % Receive buffer used when chunk encoding + % if slice length is all, all body is accumulated here, and if slice + % length is an integer a slice will be extracted if the function is + % called when the chunk_buff is greater than or equal to the slice + % length + transfer_complete = false :: boolean(), + + spoof_socket = false :: boolean(), + test_packets = [] :: list(binary()) + % only used in tests +}). + +-type req_body() :: #req_body{}. +-type fetch_req() :: pos_integer() | line. + +-type buffer_fun() :: + fun((binary(), fetch_req(), non_neg_integer() | undefined) -> binary()). + +-export_type([req_body/0, buffer_fun/0]). + +%%%============================================================================ +%%% API +%%%============================================================================ + +-spec is_gzip(req_body()) -> boolean(). +is_gzip(ReqBody) -> + ReqBody#req_body.gzip. + +-spec get_buffer(req_body()) -> binary(). +get_buffer(ReqBody) -> + ReqBody#req_body.buffer. + +-spec initiate_body( + buffer_fun(), + binary(), + chunked | non_neg_integer(), + boolean(), + non_neg_integer() +) -> + {ok, req_body()}. +initiate_body(BufferFun, BdyBuffer, CLorChunk, UseGzip, MaxBodySize) -> + { + ok, + #req_body{ + buffer = BdyBuffer, + content_length = CLorChunk, + gzip = UseGzip, + max_size = MaxBodySize, + buffer_fun = BufferFun + } + }. + +-spec confirm_empty( + riak_api_web_body:req_body() +) -> + {ok, riak_api_web_body:req_body()} | {error, content_too_large}. +confirm_empty(ReqBody) -> + case riak_api_web_body:get_body(ReqBody, all, 10000) of + {done, UpdBody} -> + {ok, UpdBody}; + {<<>>, UpdBody} -> + confirm_empty(UpdBody); + {error, content_too_large} -> + {error, content_too_large} + end. + +-spec get_body( + req_body(), all | pos_integer(), pos_integer() | undefined +) -> + {binary() | done, req_body()} | {error, content_too_large}. +get_body(#req_body{content_length = CL, max_size = MS}, _SL, _TO) when + is_integer(CL), CL > MS +-> + {error, content_too_large}; +get_body(#req_body{content_length = CL, acc_size = AS} = RqBdy, _SL, _TO) when + is_integer(CL), CL == AS +-> + {done, RqBdy}; +get_body( + #req_body{content_length = CL, transfer_complete = TC} = RqBdy, + _SL, + _TO +) when CL == chunked, TC -> + {done, RqBdy}; +get_body( + #req_body{content_length = CL, acc_size = AccSize, buffer = Bin} = RqBdy, + all, + TO +) when is_integer(CL) -> + case byte_size(Bin) + AccSize of + AccSize0 when AccSize0 >= CL -> + <> = Bin, + {ReqBody, RqBdy#req_body{buffer = Rest, acc_size = CL}}; + AccSize0 -> + get_body(extend_buffer(RqBdy, CL - AccSize0, TO), all, TO) + end; +get_body( + #req_body{content_length = CL, acc_size = AccSize, buffer = Bin} = RqBdy, + SL, + TO +) when is_integer(CL), is_integer(SL) -> + case CL - AccSize of + Remaining when Remaining =< SL -> + case byte_size(Bin) of + BS when BS >= Remaining -> + <> = Bin, + {SliceBody, RqBdy#req_body{buffer = Rest, acc_size = CL}}; + BS -> + get_body(extend_buffer(RqBdy, Remaining - BS, TO), SL, TO) + end; + _Remaining -> + case byte_size(Bin) of + BS when BS >= SL -> + <> = Bin, + { + SliceBody, + RqBdy#req_body{ + buffer = Rest, + acc_size = RqBdy#req_body.acc_size + SL + } + }; + BS -> + get_body(extend_buffer(RqBdy, SL - BS, TO), SL, TO) + end + end; +get_body( + #req_body{content_length = CL, chunk_buff = ChunkBuff} = RqBdy, + SL, + _TO +) when CL == chunked, is_integer(SL), byte_size(ChunkBuff) >= SL -> + <> = ChunkBuff, + {Slice, RqBdy#req_body{chunk_buff = ChunkBuffRem}}; +get_body( + #req_body{content_length = CL, max_size = MS, acc_size = AS} = RqBdy, + SL, + TO +) when CL == chunked -> + case erlang:decode_packet(line, RqBdy#req_body.buffer, []) of + {ok, <<"\r\n">>, Rest} when is_binary(Rest) -> + get_body( + extend_buffer( + RqBdy#req_body{buffer = Rest}, + line, + TO + ), + SL, + TO + ); + {ok, Line, Rest} when is_binary(Line) -> + ChunkSize = get_chunk_size(Line), + RcvBuffer = RqBdy#req_body.chunk_buff, + case {ChunkSize, ChunkSize + AS} of + {0, _} -> + FinalRqBdy = + case Rest of + <<>> -> + extend_buffer( + RqBdy#req_body{buffer = <<>>}, + line, + TO + ); + Rest when is_binary(Rest) -> + RqBdy#req_body{buffer = Rest} + end, + <<"\r\n", Next/binary>> = get_buffer(FinalRqBdy), + { + RcvBuffer, + FinalRqBdy#req_body{ + buffer = Next, + chunk_buff = <<>>, + transfer_complete = true + } + }; + {N, NextSize} when N > 0, NextSize =< MS -> + case byte_size(Rest) of + BS when BS >= ChunkSize -> + <> = + Rest, + get_body( + RqBdy#req_body{ + buffer = FurtherChunks, + chunk_buff = + <>, + acc_size = AS + ChunkSize + }, + SL, + TO + ); + BS -> + Needed = ChunkSize - BS, + UpdRqBdy = + extend_buffer( + RqBdy#req_body{buffer = Rest}, + Needed, + TO + ), + Chunk = get_buffer(UpdRqBdy), + get_body( + UpdRqBdy#req_body{ + buffer = <<>>, + chunk_buff = + <>, + acc_size = AS + ChunkSize + }, + SL, + TO + ) + end; + {_N, _TooBig} -> + {error, content_too_large} + end; + {more, _} -> + get_body(extend_buffer(RqBdy, line, TO), SL, TO) + end. + +-spec get_chunk_size(binary()) -> non_neg_integer(). +get_chunk_size(Line) -> + case binary:split(string:trim(Line), <<";">>) of + [ChunkLength] -> + binary_to_integer(ChunkLength, 16); + [ChunkLength, _Ignore] -> + % There may be a chunk extension after a semi-colon for + % progress tracking. + % However, these are not expected in our use case, and + % could hide security issues - so will ignore. + binary_to_integer(ChunkLength, 16) + end. + +-spec extend_buffer( + req_body(), + pos_integer() | line, + non_neg_integer() | undefined +) -> + req_body(). +extend_buffer( + #req_body{buffer_fun = BufferFun, spoof_socket = false} = ReqBody, + Size, + Timeout +) -> + ReqBody#req_body{ + buffer = BufferFun(ReqBody#req_body.buffer, Size, Timeout) + }; +extend_buffer(#req_body{spoof_socket = true} = ReqBody, Size, _Timeout) -> + {NextBin, RestPackets} = + accrue_packets( + ReqBody#req_body.test_packets, + Size, + ReqBody#req_body.buffer + ), + ReqBody#req_body{buffer = NextBin, test_packets = RestPackets}. + +%% @doc accrue_packets for unit tests only, when #req_body{spoof_socket = true} +accrue_packets(Rest, 0, Buffer) -> + {Buffer, Rest}; +accrue_packets([], line, Buffer) -> + {Buffer, []}; +accrue_packets([NextPacket | Rest], line, Buffer) -> + case erlang:decode_packet(line, NextPacket, []) of + {ok, Line, Overhang} -> + {<>, [Overhang | Rest]}; + {more, _} -> + accrue_packets(Rest, line, <>) + end; +accrue_packets([NextPacket | Rest], Size, Buffer) when is_integer(Size) -> + case Size of + Needed when Needed < byte_size(NextPacket) -> + <> = NextPacket, + {<>, [RestPacket | Rest]}; + Needed -> + accrue_packets( + Rest, + Needed - byte_size(NextPacket), + <> + ) + end. + +%%%============================================================================ +%%% Eunit tests +%%%============================================================================ + +-ifdef(TEST). +-include_lib("eunit/include/eunit.hrl"). + +slicing_fixed_length_test() -> + %% Receive a 11KB body in 1KB packets + %% Slicing into 2 4KB portions, and 1 3KB + Body = crypto:strong_rand_bytes(11 * 1024), + Packets = packet_testbin(Body, []), + RqBdyInit = + #req_body{ + buffer = <<>>, + content_length = 11 * 1024, + max_size = 1024 * 1024, + spoof_socket = true, + test_packets = Packets + }, + {Slice1, RqBdy1} = get_body(RqBdyInit, 4 * 1024, 60 * 1000), + {Slice2, RqBdy2} = get_body(RqBdy1, 4 * 1024, 60 * 1000), + {Slice3, RqBdy3} = get_body(RqBdy2, 4 * 1024, 60 * 1000), + ?assertMatch(4096, byte_size(Slice1)), + ?assertMatch(4096, byte_size(Slice2)), + ?assertMatch(3072, byte_size(Slice3)), + CompleteResult = <>, + ?assertMatch(Body, CompleteResult), + ?assertMatch(<<>>, get_buffer(RqBdy3)), + ?assertMatch(done, element(1, get_body(RqBdy3, 4 * 1024, 60 * 1000))), + + %% Request the full content-length in one shot + {AllBin, RqBody4} = get_body(RqBdyInit, all, 60 * 1000), + ?assertMatch(AllBin, Body), + ?assertMatch(<<>>, get_buffer(RqBody4)), + + % Start with some of the first packet on the buffer, and end with + % some of a pipelined request in the buffer + [FirstPacket | RestPackets] = Packets, + <> = FirstPacket, + DummyRequest = crypto:strong_rand_bytes(64), + RqBdyAlt0 = + #req_body{ + buffer = OnBuffer, + content_length = 11 * 1024, + max_size = 1024 * 1024, + spoof_socket = true, + test_packets = [OnSocket | RestPackets] ++ [DummyRequest] + }, + {SliceAlt1, RqBdyAlt1} = get_body(RqBdyAlt0, 4 * 1024, 60 * 1000), + {SliceAlt2, RqBdyAlt2} = get_body(RqBdyAlt1, 4 * 1024, 60 * 1000), + {SliceAlt3, RqBdyAlt3} = get_body(RqBdyAlt2, 4 * 1024, 60 * 1000), + ?assertMatch(4096, byte_size(SliceAlt1)), + ?assertMatch(4096, byte_size(SliceAlt2)), + ?assertMatch(3072, byte_size(SliceAlt3)), + CompleteResult = <>, + ?assertMatch( + Body, + <> + ), + SocketBin = iolist_to_binary(RqBdyAlt3#req_body.test_packets), + Remainder = <<(RqBdyAlt3#req_body.buffer)/binary, SocketBin/binary>>, + ?assertMatch(DummyRequest, Remainder). + +all_in_buffer_test() -> + Body = crypto:strong_rand_bytes(11 * 1024), + RqBdyInit = + #req_body{ + buffer = Body, + content_length = 11 * 1024, + max_size = 1024 * 1024, + spoof_socket = true, + test_packets = [] + }, + {Slice1, RqBdy1} = get_body(RqBdyInit, 4 * 1024, 60 * 1000), + {Slice2, RqBdy2} = get_body(RqBdy1, 4 * 1024, 60 * 1000), + {Slice3, RqBdy3} = get_body(RqBdy2, 4 * 1024, 60 * 1000), + ?assertMatch(4096, byte_size(Slice1)), + ?assertMatch(4096, byte_size(Slice2)), + ?assertMatch(3072, byte_size(Slice3)), + CompleteResult = <>, + ?assertMatch(Body, CompleteResult), + ?assertMatch(<<>>, get_buffer(RqBdy3)), + ?assertMatch(done, element(1, get_body(RqBdy3, 4 * 1024, 60 * 1000))). + +get_empty_body_test() -> + RqBdyInit = + #req_body{ + buffer = <<"0\r\n\r\n">>, + content_length = chunked, + max_size = 1024 * 1024, + spoof_socket = true, + test_packets = [] + }, + {Output, RqBdyEnd} = get_body(RqBdyInit, all, 1000), + ?assertMatch(<<>>, Output), + ?assertMatch(<<>>, get_buffer(RqBdyEnd)). + +get_empty_body_with_pipelined_request_test() -> + RqBdyInit = + #req_body{ + buffer = <<"0\r\n\r\nGET /stats HTTP/1.1\r\n">>, + content_length = chunked, + max_size = 1024 * 1024, + test_packets = [] + }, + {Output, RqBdyEnd} = get_body(RqBdyInit, all, 1000), + ?assertMatch(<<>>, Output), + ?assertMatch(<<"GET /stats HTTP/1.1\r\n">>, get_buffer(RqBdyEnd)). + +get_standard_wikipedia_test() -> + Packets = + [ + <<"4\r\n">>, + <<"Wiki\r\n">>, + <<"5\r\n">>, + <<"pedia\r\n">>, + <<"e\r\n">>, + <<" in\r\n\r\nchunks.\r\n">>, + <<"0\r\n">>, + <<"\r\n">> + ], + RqBdyInit = + #req_body{ + buffer = <<"">>, + content_length = chunked, + max_size = 1024 * 1024, + spoof_socket = true, + test_packets = Packets + }, + {Output, RqBdyEnd} = get_body(RqBdyInit, all, 1000), + ?assertMatch(<<"Wikipedia in\r\n\r\nchunks.">>, Output), + ?assertMatch(<<>>, get_buffer(RqBdyEnd)). + +get_standard_wikipedia_inslices_test() -> + Packets = + [ + <<"4\r\n">>, + <<"Wiki\r\n">>, + <<"5\r\n">>, + <<"pedia\r\n">>, + <<"e\r\n">>, + <<" in\r\n\r\nchunks.\r\n">>, + <<"0\r\n">>, + <<"\r\n">> + ], + RqBdyInit = + #req_body{ + buffer = <<"">>, + content_length = chunked, + max_size = 1024 * 1024, + spoof_socket = true, + test_packets = Packets + }, + {Slice1, RqBdy1} = get_body(RqBdyInit, 5, 1000), + ?assertMatch(<<"Wikip">>, Slice1), + {Slice2, RqBdy2} = get_body(RqBdy1, 5, 1000), + ?assertMatch(<<"edia ">>, Slice2), + {Slice3, RqBdy3} = get_body(RqBdy2, 100, 1000), + ?assertMatch(<<"in\r\n\r\nchunks.">>, Slice3), + ?assertMatch({done, RqBdy3}, get_body(RqBdy3, 5, 1000)). + +get_wikipedia_from_buffer_test() -> + <<>> = dummy_extend_fun(<<>>, none, none), + {ok, RqBdyInit} = + initiate_body( + fun dummy_extend_fun/3, + <<"4\r\nWiki\r\n5\r\npedia\r\ne\r\n in\r\n\r\nchunks.\r\n">>, + chunked, + false, + 1024 * 1024 + ), + OtherPackets = [<<"0\r\n">>, <<"\r\n">>], + RqBdy = + RqBdyInit#req_body{spoof_socket = true, test_packets = OtherPackets}, + {Output, RqBdyEnd} = get_body(RqBdy, all, 1000), + ?assertMatch(<<"Wikipedia in\r\n\r\nchunks.">>, Output), + ?assertMatch(<<>>, get_buffer(RqBdyEnd)). + +dummy_extend_fun(B, _, _) when is_binary(B) -> B. + +ignore_extension_test() -> + Packets = + [ + <<"4;ext\r\n">>, + <<"Wiki\r\n">>, + <<"5;somert">>, + <<"\r\n">>, + <<"pedia\r\n">>, + <<"e\r\n">>, + <<" in\r\n\r\nchunks.\r\n">>, + <<"0;other\r\n">>, + <<"\r\n">> + ], + RqBdyInit = + #req_body{ + buffer = <<>>, + content_length = chunked, + max_size = 1024 * 1024, + spoof_socket = true, + test_packets = Packets + }, + {Output, RqBdyEnd} = get_body(RqBdyInit, all, 1000), + ?assertMatch(<<"Wikipedia in\r\n\r\nchunks.">>, Output), + ?assertMatch(<<>>, get_buffer(RqBdyEnd)). + +toobig_chunking_test() -> + Packets = + [ + <<"4\r\n">>, + <<"Wiki\r\n">>, + <<"5\r\n">>, + <<"pedia\r\n">>, + <<"e\r\n">>, + <<" in\r\n\r\nchunks.\r\n">>, + <<"0\r\n">>, + <<"\r\n">> + ], + RqBdyInit = + #req_body{ + buffer = <<"">>, + content_length = chunked, + max_size = 20, + spoof_socket = true, + test_packets = Packets + }, + ?assertMatch({error, content_too_large}, get_body(RqBdyInit, all, 1000)). + +packet_testbin(<<>>, Acc) -> + lists:reverse(Acc); +packet_testbin(<>, Acc) -> + packet_testbin(Rest, [Bin | Acc]). + +-endif. diff --git a/src/riak_api_web_handler.erl b/src/riak_api_web_handler.erl new file mode 100644 index 0000000..ba88ec1 --- /dev/null +++ b/src/riak_api_web_handler.erl @@ -0,0 +1,189 @@ +%% ------------------------------------------------------------------- +%% +%% Copyright (c) 2026 Martin Sumner +%% +%% This file is provided to you under the Apache License, +%% Version 2.0 (the "License"); you may not use this file +%% except in compliance with the License. You may obtain +%% a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% +%% ------------------------------------------------------------------- +%% @doc Behaviour definition for a web handler +%% +%% the callbacks will be called in the following order, with the context +%% returned from the previous call included in the next +%% - match_route/3 +%% - check_permissions/4 +%% - parse_query_params/2 +%% - parse_request_headers/2 +%% - process_request/2 +%% - record_request/3 + +-module(riak_api_web_handler). + +-type context() :: term(). + + +-type max_header_count() :: pos_integer(). + %% The maximum number of headers that will be parsed + %% A header split over multiple lines will be counted once for each line + %% e.g. + %% X-Riak-Index_field1_bin : value1 + %% X-Riak-Index_field1_bin : value2 + %% + %% Will count as two headers +-type max_header_size() :: pos_integer(). + %% The maximum size of a single header value. If concatenating multiple + %% values causes issues with this limit - the may be split across headers. +-type max_body_size() :: non_neg_integer(). + %% The maximum size of the body (on the wire) i.e. prior to being unzipped + %% if compression is allowed. Should be set to 0 if no request body is + %% expected +-type limits() :: {max_header_count(), max_header_size(), max_body_size()}. + +-export_type( + [ + limits/0, + peer_ip/0, + peer_cert/0, + query_params/0, + stream_fun/0, + response_body/0, + timings/0, + completion/0 + ] +). + +%% @doc match_route for the module +%% When called each route handled by this module must be checked, and either +%% `nomatch` returned should none match - or the initial context with the +%% limits for that route. +-callback match_route( + riak_api_web_acceptor:method(), + unicode:chardata(), + list(unicode:chardata()) +) -> + nomatch | + {method_not_allowed, list(riak_api_web_acceptor:method())} | + {ok, limits(), context()}. + +-type peer_ip() :: inet:ip_address(). + %% The IP address of the client device connected to the socket +-type peer_cert() :: public_key:cert() | undefined. + +%% @doc check_permissions for using this module or route +%% The context() passed will be the context() returned from match_route/2 - so +%% if route information is required for permissions checks, it should be added +%% to the context. +%% +%% On failure return a halt_response with e.g. 401 /403 response codes +-callback + check_permissions( + riak_api_web_headers:headers(), + riak_api_web_socket:scheme(), + peer_ip(), + peer_cert(), + context() + ) -> + {ok, context()}|riak_api_web_acceptor:halt_response(). + + +-type query_params() :: [{unicode:chardata(), unicode:chardata()|true}]. + +%% @doc parse and validate query params, passed as a map +%% Any parameter will have both key and value as a binary, except if the +%% parameter had no value - in which case the value will be the atom `true` +-callback + parse_query_params( + query_params(), + context() + ) -> + {ok, context()}|riak_api_web_acceptor:halt_response(). + +%% @doc parse and validate the request headers +-callback + parse_request_headers( + riak_api_web_headers:headers(), + context() + ) -> + {ok, context()}|riak_api_web_acceptor:halt_response(). + +-type stream_fun() :: fun(() -> {binary(), stream_fun()}|done). +-type response_body() :: + binary() | {stream, stream_fun()}. + +%% @doc Process the request and produce a response +%% The request may receive an object body, the request body element is a +%% riak_api_web_body:req_body() record. Calling riak_api_web_body:get_body/3 +%% will return the body, either in whole or one slice at a time (by setting a +%% slice length as the second attribute of the get_body/3 function, and +%% re submitting the req_body() returned into subsequent get_body/3 calls). +%% +%% Thw headers in the response need not contain the following header elements +%% which will be generated automatically: +%% - 'Server' +%% - 'Date' +%% - 'Connection' +%% - 'Content-Length'/'Transfer-Encoding' +%% +%% The response_body() may either be a binary to be sent with a fixed content +%% length, or a stream_fun() where calls to the stream_fun() will produce +%% either: +%% - a binary() chunk and an updated stream_fun() +%% - the atom() done +%% +%% The response object may be gzipped - the callback function should handle +%% this, or error as appropriate. the riak_api_web_body:is_gzip/1 function can +%% be checked to see if the object is gzipped. +%% +%% Each binary() returned from the stream_fun() will be sent as a chunk in the +%% response. +%% +%% The KeepAliveOK boolean() indicates if it is OK to reuse this connection. +%% Validation of the version and request headers is not required, this is +%% performed by the acceptor if the callback indicates that keepalive is +%% acceptable. +%% +%% The final req_body() must also be returned, so that any remaining data on +%% the buffer is available to the acceptor. If the size_limit on the request +%% is set to 0, then a req_body() of none will be sent and should be returned. +%% If a non-zero request body is expected the whole body should be read from +%% the buffer before returning the updated request body object. +-callback + process_request( + riak_api_web_body:req_body()|none, + context() + ) -> + { + ok, + { + riak_api_web_acceptor:response_code(), + riak_api_web_headers:header_list(), + response_body(), + boolean(), + riak_api_web_body:req_body()|none + }, + context() + } | riak_api_web_acceptor:halt_response(). + +-type timings() :: {non_neg_integer(), non_neg_integer(), non_neg_integer()}. + % The result of os:system_time(microsecond) for + % - the start of the request (after accepting a connection, but prior to + % receiving and routing the request) + % - the completion of receipt and processing the request, and calling + % process_request/2. + % - the completion of sending the response to the socket +-type completion() :: stream_complete | send_complete. + % was the output sent chunk encoded, or sent as a whole body + +%% @doc Record the output of the interaction +-callback record_request(timings(), completion(), context()) -> ok. \ No newline at end of file diff --git a/src/riak_api_web_headers.erl b/src/riak_api_web_headers.erl new file mode 100644 index 0000000..1ecc6b4 --- /dev/null +++ b/src/riak_api_web_headers.erl @@ -0,0 +1,606 @@ +%% ------------------------------------------------------------------- +%% +%% Copyright (c) 2007 Mochi Media, Inc +%% Copyright (c) 2026 Martin Sumner +%% +%% This file is provided to you under the Apache License, +%% Version 2.0 (the "License"); you may not use this file +%% except in compliance with the License. You may obtain +%% a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% +%% ------------------------------------------------------------------- +%% @doc Case preserving (but case insensitive) HTTP Header dictionary. +%% +%% The headers are stored in a map, and the header keys will be an atom if +%% in the standard list of headers decoded by Erlang/OTP - and otherwise a +%% binary(). +%% +%% The values will always be binaries, comma(-and-space)-separated for values +%% with multiple items +%% +%% The module was initially a refactoring of the mochiweb_headers module. + +-module(riak_api_web_headers). +-export([make/1, make_rsp_header/1]). +-export([enter_from_list/2, default_from_list/2, enter/3]). +-export([get_value/2, get_unique_value/2, lookup/3, prefix_fold/3]). +-export([parse_primary_header_value/1]). +-export([output_response_block/1, parse_request_block/3]). +-export([compile_separators/0]). + +-define(KV_SEPARATOR, <<": ">>). +-define(V_SEPARATOR, <<", ">>). +-define(L_SEPARATOR, <<"\r\n">>). +-define(OWS, [<<" ">>, <<"\t">>]). + +-record(headers, { + type = request :: request | response, + %% response headers do not support the lookup of non-standard + %% header keys - and hence avoid the need to lower case those + %% keys for comparison + header_map = maps:new() :: header_map() +}). + +-type standard_header_key() :: + 'Cache-Control' + | 'Connection' + | 'Date' + | 'Pragma' + | 'Transfer-Encoding' + | 'Upgrade' + | 'Via' + | 'Accept' + | 'Accept-Charset' + | 'Accept-Encoding' + | 'Accept-Language' + | 'Authorization' + | 'Proxy-Authorization' + | 'Proxy-Authenticate' + | 'Www-Authenticate' + | 'From' + | 'Host' + | 'If-Modified-Since' + | 'If-Match' + | 'If-None-Match' + | 'If-Range' + | 'If-Unmodified-Since' + | 'Max-Forwards' + | 'Range' + | 'Referer' + | 'User-Agent' + | 'Age' + | 'Location' + | 'Public' + | 'Retry-After' + | 'Server' + | 'Vary' + | 'Warning' + | 'Allow' + | 'Content-Base' + | 'Content-Encoding' + | 'Content-Language' + | 'Content-Length' + | 'Content-Location' + | 'Content-Md5' + | 'Content-Range' + | 'Content-Type' + | 'Etag' + | 'Expires' + | 'Last-Modified' + | 'Accept-Ranges' + | 'Set-Cookie' + | 'Set-Cookie2' + | 'X-Forwarded-For' + | 'Cookie' + | 'Keep-Alive' + | 'Proxy-Connection'. +% This list is controlled by Erlang/OTP - i.e. there may be further atoms +% added in the future, but it has been stable since OTP 13. +-type binary_header_key() :: unicode:chardata() | binary(). +-type header_key() :: standard_header_key() | binary_header_key(). +-type header_value() :: {binary(), list(binary())}. +-type header_map() :: #{header_key() => header_value()}. +-type header_list() :: [{header_key(), binary()}]. +-type headers() :: #headers{}. +-type buffer_fun() :: fun((binary()) -> binary()). + +-export_type([headers/0, header_list/0]). + +%%%============================================================================ +%%% API +%%%============================================================================ + +%% @doc +%% Construct a headers() from the given list of headers received in a +%% request. +-spec make([{header_key(), binary()}]) -> headers(). +make(HeaderList) when is_list(HeaderList) -> + HeaderMap = from_list(HeaderList, true), + #headers{header_map = HeaderMap}. + +%% @doc +%% Specific constructor when forming response headers. +%% With response headers it is not possible to lookup non-standard header keys, +%% An the value may be a list if elements - that will be joined into a single +%% comma-separated value before creating the response header. +-spec make_rsp_header([{header_key(), list(binary()) | binary()}]) -> + headers(). +make_rsp_header(HeaderList) -> + HeaderMap = from_list(HeaderList, false), + #headers{type = response, header_map = HeaderMap}. + +%% @doc +%% Insert pairs into the headers, replace any values for existing keys. +%% Specifically used in response headers when setting ranges into existing +%% headers. +-spec enter_from_list([{header_key(), binary()}], headers()) -> + headers(). +enter_from_list(HeaderList, #headers{type = T, header_map = HM}) when + T == response +-> + #headers{ + type = response, + header_map = maps:merge(HM, from_list(HeaderList, false)) + }. + +%% @doc +%% Insert pairs into response headers for keys that do not already exist. +-spec default_from_list([{header_key(), binary()}], headers()) -> + headers(). +default_from_list(HeaderList, #headers{type = T, header_map = HM}) when + T == response +-> + #headers{ + type = response, + header_map = maps:merge(from_list(HeaderList, false), HM) + }. + +%% @doc +%% Add a single value for a single key to the response map +-spec enter(header_key(), binary(), headers()) -> + headers(). +enter(HeaderKey, Value, #headers{type = T, header_map = HM}) when + T == response +-> + {HK, HV} = normalize_header({HeaderKey, Value}, false), + #headers{ + type = response, + header_map = maps:put(HK, HV, HM) + }. + +%% @doc +%% Return the value of the given standard header key. `undefined` will be +%% returned for keys that are not present. +%% For non-standard (binary) keys use lookup/2. +%% If the values was a comma-separated list, or multiple headers have been +%% folded together - then a list rather than a single value is returned. +-spec get_value(standard_header_key(), headers()) -> + unicode:chardata() | list(unicode:chardata()) | undefined. +get_value(K, H) when is_atom(K) -> + case maps:get(K, H#headers.header_map, undefined) of + undefined -> + undefined; + {_OK, [V]} -> + V; + {_OK, VL} when is_list(VL) -> + VL + end. + +%% @doc +%% If multiple values may be provided for a field, but it is illegal +%% for those values to differ (e.g. in the case of content-length), only return +%% a value, if there is only one unique value. +-spec get_unique_value(standard_header_key(), headers()) -> + unicode:chardata() | undefined | {error, multiple_values}. +get_unique_value(K, H) -> + case maps:get(K, H#headers.header_map, undefined) of + undefined -> + undefined; + {_OK, [V]} -> + V; + {_OK, VL} when is_list(VL) -> + case sets:to_list(sets:from_list(VL, [{version, 2}])) of + [V] -> + V; + _ -> + {error, multiple_values} + end + end. + +%% @doc +%% some header values consist of primary information supported by secondary +%% information. The primary information is presented before a ';', and the +%% secondary information is `;` separated list +-spec parse_primary_header_value(binary()) -> unicode:chardata(). +parse_primary_header_value(HeaderValue) -> + binary:split(HeaderValue, <<";">>, [global, trim_all]). + +%% @doc +%% Fetch the {original key, values} for a binary (non-standard) header key. +%% There is a boolean flag to indicate if the key has already been subject to +%% casefold. +-spec lookup(binary_header_key(), headers(), boolean()) -> + {binary(), list(unicode:chardata())} | undefined. +lookup(CaseFoldedKey, H, true) when is_binary(CaseFoldedKey) -> + maps:get(CaseFoldedKey, H#headers.header_map, undefined); +lookup(RawKey, Headers, false) when is_binary(RawKey) -> + lookup(normalize_key(RawKey), Headers, true). + +%% @doc +%% Fetch a list of non-standard headers with a given prefix. The list is a +%% list of {K, [V]} where K is the remainder of the original key once the +%% original prefix has been stripped +-spec prefix_fold(binary_header_key(), headers(), boolean()) -> + list({unicode:chardata(), list(unicode:chardata())}). +prefix_fold(CaseFoldPrefix, Headers, true) when is_binary(CaseFoldPrefix) -> + Keys = maps:keys(Headers#headers.header_map), + filter_headers( + Keys, + CaseFoldPrefix, + byte_size(CaseFoldPrefix), + Headers#headers.header_map, + [] + ); +prefix_fold(RawPrefix, Headers, false) -> + prefix_fold(normalize_key(RawPrefix), Headers, true). + +%% @doc +%% Output a binary representing the block of response headers to be pushed to +%% the socket. Includes trailing line feed at end of last line, but not a +%% separating line feed to the response body +-spec output_response_block(headers()) -> binary(). +output_response_block(#headers{type = T, header_map = HM}) when T == response -> + HeaderList = maps:values(HM), + iolist_to_binary( + lists:map( + fun({BK, VL}) -> + << + BK/binary, + (?KV_SEPARATOR)/binary, + (join_values(VL))/binary, + (?L_SEPARATOR)/binary + >> + end, + HeaderList + ) + ). + +-define(COUNT_EXCEEDED, <<"Headers exceeded maximum count of ~w">>). +-define(SIZE_EXCEEDED, <<"Header ~s exceeded maximum size of ~w">>). + +%% @doc +%% Parse a binary block representing the start of a block of request headers, +%% with a buffer function to request more should the block be incomplete. +-spec parse_request_block( + binary(), + buffer_fun(), + {pos_integer(), pos_integer()} +) -> + {ok, headers(), binary()} | riak_api_web_acceptor:halt_response(). +parse_request_block(Buffer, BufferFun, {MaxCount, MaxSize}) -> + parse_request_block(Buffer, BufferFun, {MaxCount, MaxSize}, {[], 0}). + +parse_request_block(_B, _BFun, {MaxCount, _MS}, {_H, C}) when C > MaxCount -> + {halt, 431, [], ?COUNT_EXCEEDED, [MaxCount]}; +parse_request_block(Buffer, BufferFun, {MaxCount, MaxSize}, {HeaderAcc, C}) -> + case erlang:decode_packet(httph_bin, Buffer, []) of + {ok, {http_header, _, _, OrigKey, V}, _} when byte_size(V) > MaxSize -> + {halt, 431, [], ?SIZE_EXCEEDED, [OrigKey, MaxSize]}; + {ok, {http_header, _, Key, _OrigKey, Value}, Rest} when is_atom(Key) -> + parse_request_block( + Rest, + BufferFun, + {MaxCount, MaxSize}, + {[{Key, Value} | HeaderAcc], C + 1} + ); + {ok, {http_header, _, _Key, OrigKey, Value}, Rest} -> + parse_request_block( + Rest, + BufferFun, + {MaxCount, MaxSize}, + {[{OrigKey, Value} | HeaderAcc], C + 1} + ); + {ok, http_eoh, Rest} -> + {ok, make(HeaderAcc), Rest}; + {ok, {http_error, _}, Rest} -> + parse_request_block( + Rest, + BufferFun, + {MaxCount, MaxSize}, + {HeaderAcc, C} + ); + {more, _} -> + parse_request_block( + BufferFun(Buffer), + BufferFun, + {MaxCount, MaxSize}, + {HeaderAcc, C} + ) + end. + +%%%============================================================================ +%%% Internal Functions +%%%============================================================================ + +-spec join_values(list(unicode:chardata())) -> binary(). +-if(?OTP_RELEASE >= 28). +join_values(VL) -> + binary:join(VL, ?V_SEPARATOR). +-else. +join_values(VL) -> + iolist_to_binary(lists:join(?V_SEPARATOR, VL)). +-endif. + +-spec filter_headers( + list(header_key()), + unicode:chardata(), + pos_integer(), + header_map(), + list(header_value()) +) -> + list(header_value()). +filter_headers([], _Prefix, _PL, _HMap, Acc) -> + Acc; +filter_headers([Key | RestKeys], Prefix, PL, HMap, Acc) -> + case Key of + <> -> + {<<_Ignore:PL/binary, Suffix/binary>>, Values} = + maps:get(Key, HMap), + filter_headers(RestKeys, Prefix, PL, HMap, [{Suffix, Values} | Acc]); + _ -> + filter_headers(RestKeys, Prefix, PL, HMap, Acc) + end. + +-spec from_list([{header_key(), binary() | list(binary())}], boolean()) -> + header_map(). +from_list(HeaderList, IsReqHeader) -> + lists:foldl( + fun(Header, Acc) -> + {NK, {RK, HVL}} = normalize_header(Header, IsReqHeader), + maps:update_with( + NK, + fun({ERK, EHVL}) -> {ERK, HVL ++ EHVL} end, + {RK, HVL}, + Acc + ) + end, + maps:new(), + HeaderList + ). + +-spec normalize_header( + {header_key(), binary() | list(binary())}, boolean() +) -> + {header_key(), header_value()}. +normalize_header({KAtom, Value}, _) when is_atom(KAtom) -> + {KAtom, {atom_to_binary(KAtom), normalize_value(Value)}}; +normalize_header({KBin, Value}, true) when is_binary(KBin) -> + {string:casefold(KBin), {KBin, normalize_value(Value)}}; +normalize_header({KBin, Value}, false) when is_binary(KBin) -> + {KBin, {KBin, normalize_value(Value)}}. + +-spec normalize_key + (standard_header_key()) -> standard_header_key(); + (binary_header_key()) -> binary_header_key(). +normalize_key(KAtom) when is_atom(KAtom) -> + KAtom; +normalize_key(KBin) when is_binary(KBin) -> + string:casefold(KBin). + +-spec normalize_value(binary() | list(binary())) -> + list(binary()). +normalize_value(MultipleValues) when is_list(MultipleValues) -> + lists:filter(fun is_binary/1, MultipleValues); +normalize_value(FieldValue) when is_binary(FieldValue) -> + {CP, WS} = + persistent_term:get( + {?MODULE, compile_patterns}, + {?V_SEPARATOR, ?OWS} + ), + lists:map( + fun(V) -> + case binary:split(V, WS, [global, trim_all]) of + [V0] when is_binary(V0) -> + V0; + _ -> + string:trim(V, both) + end + end, + binary:split(FieldValue, CP, [global]) + ). + +%% @doc Call this function when initialising API +-spec compile_separators() -> ok. +compile_separators() -> + CP = binary:compile_pattern(?V_SEPARATOR), + WS = binary:compile_pattern(?OWS), + persistent_term:put({?MODULE, compile_patterns}, {CP, WS}). + +%%%============================================================================ +%%% Eunit tests +%%%============================================================================ + +-ifdef(TEST). +-include_lib("eunit/include/eunit.hrl"). + +split_perf_test() -> + HV1 = <<"SOME-INDEX|HEADER|NOTSPLIT">>, + HV2 = <<"HDR1, HDR2, HDR3">>, + L = [HV1, HV1, HV1, HV1, HV2], + FullL = lists:flatten(lists:map(fun(_I) -> L end, lists:seq(1, 1000))), + {TS1, L1} = + timer:tc( + fun() -> + lists:map( + fun(HV) -> binary:split(HV, ?V_SEPARATOR, [global]) end, + FullL + ) + end + ), + CPVS = binary:compile_pattern(?V_SEPARATOR), + {TS2, L2} = + timer:tc( + fun() -> + lists:map( + fun(HV) -> binary:split(HV, CPVS, [global]) end, + FullL + ) + end + ), + ?assertMatch(L1, L2), + io:format(user, "No-compile ~w compile ~w microseconds", [TS1, TS2]). + +parse_block_test() -> + RequestHeader1 = + << + "content-length: 1024\r\n" + "x-riak-Index-field1_bin: NAME1|DOB1, NAME2|DOB1\r\n" + "x-riak-index-Field1_bin: NAME3|DOB1\r\n" + "X-Riak-Index-field2_bin: POSTCODE1|DOB1\r\n" + >>, + RequestHeader2 = + << + "x-riak-index-field2_bin: POSTCODE2|DOB1\r\n" + "\r\n" + >>, + parse_block_tester(RequestHeader1, RequestHeader2). + +parse_splitblock_test() -> + RequestHeader1 = + << + "content-length: 1024\r\n" + "x-riak-Index-field1_bin: NAME1|DOB1, NAME2|DOB1\r\n" + "x-riak-index-Field1_bin: NAME3|DOB1\r\n" + "X-Riak-Index-field2_bin: POSTCODE1" + >>, + RequestHeader2 = + << + "|DOB1\r\nx-riak-index-field2_bin: POSTCODE2|DOB1\r\n" + "\r\n" + >>, + parse_block_tester(RequestHeader1, RequestHeader2). + +parse_block_tester(RequestHeader1, RequestHeader2) -> + BufferFun = fun(B) -> <> end, + {ok, Headers, <<>>} = + parse_request_block(RequestHeader1, BufferFun, {1024, 2048}), + ?assertMatch( + <<"1024">>, + get_value('Content-Length', Headers) + ), + ?assertMatch( + { + <<"x-riak-index-Field1_bin">>, + [<<"NAME1|DOB1">>, <<"NAME2|DOB1">>, <<"NAME3|DOB1">>] + }, + lookup(<<"x-riak-index-field1_bin">>, Headers, true) + ), + ?assertMatch( + { + <<"x-riak-index-field2_bin">>, + [<<"POSTCODE1|DOB1">>, <<"POSTCODE2|DOB1">>] + }, + lookup(<<"x-riak-index-Field2_bin">>, Headers, false) + ), + ?assertMatch( + <<"1024">>, + get_unique_value('Content-Length', Headers) + ). + +riak_metadata_test() -> + RequestHeader1 = + << + "content-length: 1024\r\n" + "x-riak-Index-field1_bin: NAME1|DOB1, NAME2|DOB1\r\n" + "x-riak-index-Field1_bin: NAME3|DOB1 \r\n" + "X-Riak-Index-field2_bin: POSTCODE1|DOB1\r\n" + >>, + RequestHeader2 = + << + "x-riak-index-field2_bin: POSTCODE2|DOB1\r\n" + "x-riak-meta-key1: METAVALUE1\r\n" + "x-riak-meta-key2: METAVALUE2\r\n" + "\r\n" + >>, + BufferFun = fun(B) -> <> end, + {ok, Headers, <<>>} = + parse_request_block(RequestHeader1, BufferFun, {1024, 2048}), + IndexList = prefix_fold(<<"x-riak-index-">>, Headers, true), + ?assertMatch( + { + <<"Field1_bin">>, + [<<"NAME1|DOB1">>, <<"NAME2|DOB1">>, <<"NAME3|DOB1">>] + }, + lists:keyfind(<<"Field1_bin">>, 1, IndexList) + ), + MetaList = prefix_fold(<<"X-Riak-Meta-">>, Headers, false), + ?assertMatch( + {<<"key1">>, [<<"METAVALUE1">>]}, + lists:keyfind(<<"key1">>, 1, MetaList) + ). + +content_smuggling_test() -> + RequestHeader1 = + << + "content-length: 1024\r\n" + "x-riak-Index-field1_bin: NAME1|DOB1, NAME2|DOB1\r\n" + "x-riak-index-Field1_bin: NAME3|DOB1 \t \r\n" + "X-Riak-Index-field2_bin: POSTCODE1|DOB1\r\n" + "content-length: 16384\r\n" + "\r\n" + >>, + {ok, Headers, <<>>} = + parse_request_block(RequestHeader1, fun() -> <<>> end, {1024, 2048}), + ?assertMatch( + {error, multiple_values}, + get_unique_value('Content-Length', Headers) + ). + +response_header_test() -> + InitHeaders = + [ + {'Server', <<"Riak Web API">>}, + {'Content-Length', <<"1024">>}, + {'Etag', <<"sometag">>}, + {<<"X-Riak-Index-field1_bin">>, [ + <<"NAME1|DOB1">>, <<"NAME2|DOB1">> + ]}, + {<<"X-Riak-Index-field1_bin">>, <<"NAME3|DOB1">>}, + {<<"X-Riak-Index-field2_bin">>, <<"POSTCODE1|DOB1">>} + ], + RespHeaders1 = make_rsp_header(InitHeaders), + DefaultList = + [ + {'Server', <<"Riak Web API 1.0">>}, + {'Date', <<"Mon, 15 Apr 2025 10:06:15 GMT">>} + ], + RespHeaders2 = default_from_list(DefaultList, RespHeaders1), + EntryList = + [ + {'Etag', <<"some_md5_tag">>}, + {'Vary', <<"*">>} + ], + RespHeaders3 = enter_from_list(EntryList, RespHeaders2), + Response = output_response_block(RespHeaders3), + ExpectedResponse = + << + "Date: Mon, 15 Apr 2025 10:06:15 GMT\r\n" + "Server: Riak Web API\r\n" + "Vary: *\r\n" + "Content-Length: 1024\r\n" + "Etag: some_md5_tag\r\n" + "X-Riak-Index-field1_bin: NAME3|DOB1, NAME1|DOB1, NAME2|DOB1\r\n" + "X-Riak-Index-field2_bin: POSTCODE1|DOB1\r\n" + >>, + ?assertMatch(ExpectedResponse, Response). + +-endif. diff --git a/src/riak_api_web_security.erl b/src/riak_api_web_security.erl index e719df1..97c312c 100644 --- a/src/riak_api_web_security.erl +++ b/src/riak_api_web_security.erl @@ -1,39 +1,184 @@ +%% ------------------------------------------------------------------- +%% +%% Copyright (c) 2007-2009 Basho Technologies +%% Copyright (c) 2026 Martin Sumner +%% +%% This file is provided to you under the Apache License, +%% Version 2.0 (the "License"); you may not use this file +%% except in compliance with the License. You may obtain +%% a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% +%% ------------------------------------------------------------------- +%% %% @doc Some security helper functions for Riak API endpoints + -module(riak_api_web_security). +-include_lib("kernel/include/logger.hrl"). + +-export([is_authorised/4]). + +-define(AUTH_PREFIX, "Basic "). +-define(TXT_HEADER, {'Content-Type', <<"text/plain">>}). --export([is_authorized/1]). - -%% @doc Check if the user is authorized --spec is_authorized(any()) -> {true, any()} | false | insecure. -is_authorized(ReqData) -> - case riak_core_security:is_enabled() of - true -> - Scheme = wrq:scheme(ReqData), - case Scheme == https of - true -> - case wrq:get_req_header("Authorization", ReqData) of - "Basic " ++ Base64 -> - UserPass = base64:decode_to_string(Base64), - [User, Pass] = [list_to_binary(X) || X <- - string:tokens(UserPass, ":")], - {ok, Peer} = inet_parse:address(wrq:peer(ReqData)), - case riak_core_security:authenticate(User, Pass, - [{ip, Peer}]) - of - {ok, Sec} -> - {true, Sec}; - {error, _} -> - false - end; - _ -> - false - end; - false -> - %% security is enabled, but they're connecting over HTTP. - %% which means if they authed, the credentials would be in - %% plaintext - insecure +-spec is_authorised( + boolean(), + http | https, + riak_api_web_headers:headers(), + inet:ip_address() +) -> + {ok, riak_core_security:context() | undefined} + | riak_api_web_acceptor:halt_response(). +is_authorised(Enabled, Scheme, ReqHeaders, Peer) -> + is_authorised( + Enabled, + Scheme, + ReqHeaders, + Peer, + fun(User, Pass, Pip) -> + riak_core_security:authenticate(User, Pass, [{ip, Pip}]) + end + ). + +is_authorised(true, https, ReqHeaders, Peer, AuthFun) -> + case riak_api_web_headers:get_unique_value('Authorization', ReqHeaders) of + <> -> + try + UserPass = base64:decode(Base64UP), + [User, Pass] = string:lexemes(UserPass, ":"), + case AuthFun(User, Pass, Peer) of + {ok, SecContext} -> + {ok, SecContext}; + {error, Error} -> + {halt, 401, [?TXT_HEADER], <<"~0p">>, [Error]} + end + catch + _:ExError -> + error_decoding_credentials(ExError) end; - false -> - {true, undefined} %% no security context - end. + undefined -> + {halt, 401, [?TXT_HEADER], <<"No credentials provided">>, []}; + Unexpected -> + error_decoding_credentials(Unexpected) + end; +is_authorised(true, http, _ReqHeaders, _Peer, _AuthFun) -> + {halt, 426, [?TXT_HEADER], <<"Upgrade required to https">>, []}; +is_authorised(false, _, _ReqHeaders, _Peer, _AuthFun) -> + {ok, undefined}. + +error_decoding_credentials(ErrorTerm) -> + ?LOG_WARNING("Error decoding credentials ~0p", [ErrorTerm]), + {halt, 400, [?TXT_HEADER], <<"Error decoding credentials">>, []}. + +%%%============================================================================ +%%% Eunit tests +%%%============================================================================ + +-ifdef(TEST). +-include_lib("eunit/include/eunit.hrl"). + +simple_security_test() -> + User1 = <<"User1">>, + User2 = <<"User2">>, + User3 = <<"User3">>, + Pass1 = <<"Pass1!">>, + Pass2 = <<"Pass2!">>, + Pass3 = <<"Pass3!">>, + AuthMap = #{User1 => Pass1, User2 => Pass2, User3 => Pass3}, + AuthFun = + fun(User, Pass, _IgnorePeer) when is_binary(Pass) -> + case maps:get(User, AuthMap, undefined) of + Pass -> + {ok, ok}; + _ -> + {error, invalid_credentials} + end + end, + Combo1 = base64:encode(iolist_to_binary([User1, <<":">>, Pass1])), + ?assertMatch( + {ok, ok}, + is_authorised( + true, + https, + make_request_headers(Combo1), + {ip, {127, 0, 0, 1}}, + AuthFun + ) + ), + ?assertMatch( + {halt, 400, [?TXT_HEADER], <<"Error decoding credentials">>, []}, + is_authorised( + true, + https, + make_request_headers(iolist_to_binary([Combo1, <<"A">>])), + {ip, {127, 0, 0, 1}}, + AuthFun + ) + ), + BadCombo = base64:encode(iolist_to_binary([User2, <<":">>, Pass1])), + ?assertMatch( + {halt, 401, [?TXT_HEADER], <<"~0p">>, [invalid_credentials]}, + is_authorised( + true, + https, + make_request_headers(BadCombo), + {ip, {127, 0, 0, 1}}, + AuthFun + ) + ), + Combo2 = base64:encode(iolist_to_binary([User2, <<":">>, Pass2])), + MultipleHeaders = + riak_api_web_headers:make( + [ + {'Content-Length', <<"1024">>}, + {<<"X-Riak-VClock">>, <<"ABC123==">>}, + {'Authorization', iolist_to_binary([<<"Basic ">>, Combo1])}, + {'Authorization', iolist_to_binary([<<"Basic ">>, Combo2])} + ] + ), + ?assertMatch( + {halt, 400, [?TXT_HEADER], <<"Error decoding credentials">>, []}, + is_authorised( + true, + https, + MultipleHeaders, + {ip, {127, 0, 0, 1}}, + AuthFun + ) + ), + NoAuthHeaders = + riak_api_web_headers:make( + [ + {'Content-Length', <<"1024">>}, + {<<"X-Riak-VClock">>, <<"ABC123==">>} + ] + ), + ?assertMatch( + {halt, 401, [?TXT_HEADER], <<"No credentials provided">>, []}, + is_authorised( + true, + https, + NoAuthHeaders, + {ip, {127, 0, 0, 1}}, + AuthFun + ) + ). + +make_request_headers(Combo) -> + riak_api_web_headers:make( + [ + {'Content-Length', <<"1024">>}, + {<<"X-Riak-VClock">>, <<"ABC123==">>}, + {'Authorization', iolist_to_binary([<<"Basic ">>, Combo])} + ] + ). + +-endif. diff --git a/src/riak_api_web_socket.erl b/src/riak_api_web_socket.erl new file mode 100644 index 0000000..8390910 --- /dev/null +++ b/src/riak_api_web_socket.erl @@ -0,0 +1,556 @@ +%% ------------------------------------------------------------------- +%% +%% Copyright (c) 2007 Mochi Media, Inc +%% Copyright (c) 2026 Martin Sumner +%% +%% This file is provided to you under the Apache License, +%% Version 2.0 (the "License"); you may not use this file +%% except in compliance with the License. You may obtain +%% a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% +%% ------------------------------------------------------------------- +%% @doc Socket and acceptor pool management for web requests +%% +%% Socket manager intended to abstract away from choice of SSL, and also +%% maintain a pool of accept processes that are ready to accept new connection +%% requests +%% +%% Each acceptor is an `riak_api_web_acceptor` - an as each acceptor accepts +%% a connection, it will prompt this socket server to launch a new acceptor. +%% When a linked acceptor closes (along with the connection), the close message +%% is handled and the closed acceptor is removed from the pool. +%% +%% The intention is that there should always be at least the pool size of +%% acceptors waiting for a connection - unless the max size is reached, and no +%% new acceptors will be started. This means that concurrently no more +%% connections can be handled concurrently than the max pool size. +%% +%% The module was initially based on the: +%% - mochiweb_socket_server +%% - mochiweb_socket +%% - mochiweb_acceptor +%% +%% Patterns used in these modules have been compared with the Elli web server +%% for validation - https://github.com/elli-lib/elli. + +-module(riak_api_web_socket). + +-if(?OTP_RELEASE == 26). +-feature(maybe_expr, enable). +-endif. + +-behaviour(gen_server). + +-export( + [ + start_link/1, + get_max_pool_size/1, + set_max_pool_size/2, + get_active_pool_size/1 + ] +). + +-export( + [ + init/1, + handle_call/3, + handle_cast/2, + handle_info/2 + ] +). + +-export( + [ + get_scheme/1, + accept/2, + recv/3, + recv_line/2, + send/2, + close/1, + stop/1, + get_peer/1, + acceptor_accepted/1 + ] +). + +-include_lib("kernel/include/logger.hrl"). + +-define(POOL_SIZE_DEFAULT, 16). +-define(POOL_SIZE_MAX_DEFAULT, 2048). +-define(DEFAULT_RECV_BUFFER, 131072). +% Setting the receive buffer will also change the buffer +% https://github.com/erlang/otp/issues/9355 + +-record(socket_state, { + port :: inet:port_number(), + listener :: socket(), + pool_size = ?POOL_SIZE_DEFAULT :: pos_integer(), + max_pool_size = ?POOL_SIZE_MAX_DEFAULT :: pos_integer(), + acceptor_pool = sets:new([{version, 2}]) :: sets:set() +}). + +-type socket_option() :: + {ip, inet:ip_address()} + | binary + | {reuseaddr, boolean()} + %% Assumed necessary to allow for rapid restart of supervised + %% process - e.g. allow for next process to listen on socket even + %% when the previous process has not completed the close + | {packet, raw} + | {active, boolean()} + %% After a connection is accepted the socket is manually read to be + %% decoded + | {backlog, pos_integer()} +%% If this is too low it may result in some requests being reset when +%% there is a burst of new connections +. + +-type buffer_option() :: + {recbuf, pos_integer()} + | {sndbuf, pos_integer()} + | {buffer, pos_integer()} +% The size of the user-level buffer used by the driver. +% Not to be confused with options sndbuf and recbuf, which correspond +% to the Kernel socket buffers. For TCP it is recommended to have +% val(buffer) >= val(recbuf) to avoid performance issues because +% of unnecessary copying +. + +-type server_name() :: binary(). +% Name of the root part of the address i.e. +% <<"Protocol://Host:Port">> + +-type option() :: + {acceptor_pool_start_size, pos_integer()} + % The number of acceptors to be ready to accept an new connection. + % This pool size is not a limit, it is is the starting size. As an + % acceptor picks up a new connection request it will prompt for a new + % acceptor to be spawned (and will not return to the pool once it is + % complete). + | {acceptor_pool_max_size, pos_integer()} + % The maximum number of acceptors in the pool - the total number of + % concurrent requests that can be supported on this port + | {ssl, boolean()} + | {ssl_opts, [ssl:tls_server_option()]} + | {ip, inet:ip_address()} + | {port, inet:port_number()} + | {name, server_name()}. + +-type scheme() :: http | https. + +-type web_options() :: list(option()). + +-type socket() :: {http, gen_tcp:socket()} | {https, ssl:sslsocket()}. + +-type tcp_error() :: closed | timeout | system_limit | inet:posix(). +-type tls_error() :: term(). + +-export_type([socket/0, scheme/0]). + +%%%============================================================================ +%%% API +%%%============================================================================ + +-spec start_link(web_options()) -> {ok, pid()}. +start_link(Options) -> + ServerName = + case lists:keyfind(name, 1, Options) of + {name, Name} when is_binary(Name) -> + {local, binary_to_atom(Name)} + end, + {ok, Pid} = gen_server:start_link(ServerName, ?MODULE, Options, []), + {ok, Pid}. + +-spec get_max_pool_size(server_name()) -> pos_integer(). +get_max_pool_size(ServerName) -> + gen_server:call( + binary_to_existing_atom(ServerName), + get_max_pool_size, + infinity + ). + +-spec get_active_pool_size(server_name()) -> pos_integer(). +get_active_pool_size(ServerName) -> + gen_server:call( + binary_to_existing_atom(ServerName), + get_active_pool_size, + infinity + ). + +-spec set_max_pool_size(server_name(), pos_integer()) -> ok. +set_max_pool_size(ServerName, MaxPoolSize) when is_integer(MaxPoolSize) -> + gen_server:cast( + binary_to_existing_atom(ServerName), + {set_max_pool_size, MaxPoolSize} + ). + +-spec acceptor_accepted(pid()) -> ok. +acceptor_accepted(Pid) -> + gen_server:cast(Pid, accepted). + +-spec stop(server_name()) -> ok. +stop(ServerName) -> + gen_server:call( + binary_to_existing_atom(ServerName), + stop, + infinity + ). + +%%%============================================================================ +%%% gen_server callbacks +%%%============================================================================ + +init(Options) -> + process_flag(trap_exit, true), + BufferOpts = + case get_tcp_buffer_options() of + [] -> + [{recbuf, ?DEFAULT_RECV_BUFFER}]; + NonDefaultOpts -> + ?LOG_INFO( + "Non-default TCP buffer options configured for web ~0p", + [NonDefaultOpts] + ), + NonDefaultOpts + end, + {ip, IP} = + case lists:keyfind(ip, 1, Options) of + {ip, IPString} when is_list(IPString) -> + {ok, IPAddr} = inet:parse_address(IPString), + {ip, IPAddr}; + {ip, IPAddr} -> + {ip, IPAddr} + end, + {port, Port} = lists:keyfind(port, 1, Options), + {Protocol, SSLOpts} = + case lists:keyfind(ssl, 1, Options) of + {ssl, true} -> + {ssl_opts, SSLOptsIn} = lists:keyfind(ssl_opts, 1, Options), + {https, SSLOptsIn}; + _ -> + {http, none} + end, + SocketOpts = default_socket_options(IP), + {ok, Listener} = listen(Protocol, Port, SocketOpts, BufferOpts, SSLOpts), + {AcceptorPool, StartSize, MaxSize} = + get_acceptor_pool(Listener, Port, Options), + ?LOG_INFO( + "Acceptor pool for web started on IP ~0p port ~w of size ~w", + [IP, Port, StartSize] + ), + riak_api_web:cache_today(), + riak_api_web_headers:compile_separators(), + riak_api_web_acceptor:compile_detectors(), + { + ok, + #socket_state{ + listener = Listener, + port = Port, + pool_size = StartSize, + max_pool_size = MaxSize, + acceptor_pool = sets:from_list(AcceptorPool, [{version, 2}]) + } + }. + +handle_call(get_max_pool_size, _From, State) -> + {reply, State#socket_state.max_pool_size, State}; +handle_call(get_active_pool_size, _From, State) -> + {reply, sets:size(State#socket_state.acceptor_pool), State}; +handle_call(stop, _From, State) -> + {stop, normal, ok, State}. + +handle_cast({set_max_pool_size, MPS}, State) -> + case State#socket_state.pool_size of + PS when PS =< MPS -> + {noreply, State#socket_state{max_pool_size = MPS}}; + PS -> + ?LOG_WARNING( + "Ignoring change to max pool size ~w to smaller value than " + "starting pool ~w", + [MPS, PS] + ), + {noreply, State} + end; +handle_cast(accepted, State) -> + case State#socket_state.pool_size of + PS when PS < State#socket_state.max_pool_size -> + P = + riak_api_web_acceptor:start_link( + State#socket_state.listener, + State#socket_state.port + ), + { + noreply, + State#socket_state{ + acceptor_pool = + sets:add_element(P, State#socket_state.acceptor_pool), + pool_size = PS + 1 + } + }; + _ -> + ?LOG_WARNING( + "Web connection pool reached limit of ~w", + [State#socket_state.pool_size] + ), + {noreply, State} + end. + +handle_info({'EXIT', Pid, normal}, State) -> + { + noreply, + State#socket_state{ + pool_size = State#socket_state.pool_size - 1, + acceptor_pool = + sets:del_element(Pid, State#socket_state.acceptor_pool) + } + }; +handle_info({'EXIT', Pid, Reason}, State) -> + ?LOG_ERROR("Acceptor ~p unexpectedly crashed: ~0p", [Pid, Reason]), + handle_info({'EXIT', Pid, normal}, State). + +%%%============================================================================ +%%% Internal Functions +%%%============================================================================ + +-spec default_socket_options(inet:ip_address()) -> [socket_option()]. +default_socket_options(IPAddr) -> + [ + {ip, IPAddr}, + binary, + {reuseaddr, true}, + {packet, raw}, + {active, false}, + {backlog, 128} + ]. + +-spec get_acceptor_pool(socket(), inet:port_number(), list(option())) -> + {list(pid()), pos_integer(), pos_integer()}. +get_acceptor_pool(Listener, Port, Options) -> + StartSize = + case lists:keyfind(acceptor_pool_start_size, 1, Options) of + {acceptor_pool_start_size, SS} when is_integer(SS), SS > 0 -> + SS; + false -> + application:get_env( + riak_api, + web_acceptor_pool_start_size, + ?POOL_SIZE_DEFAULT + ) + end, + MaxSize = + case lists:keyfind(acceptor_pool_max_size, 1, Options) of + {acceptor_pool_max_size, MS} when is_integer(MS), MS > 0 -> + MS; + false -> + application:get_env( + riak_api, + web_acceptor_pool_max_size, + ?POOL_SIZE_MAX_DEFAULT + ) + end, + case {StartSize, MaxSize} of + {StartSize, MaxSize} when + is_integer(StartSize), + is_integer(MaxSize), + MaxSize >= StartSize + -> + { + start_acceptor_pool(Listener, Port, StartSize), + StartSize, + MaxSize + }; + InvalidConfig -> + ?LOG_ERROR( + "Invalid configuration of acceptor pool ~0p - " + "starting with defaults", + [InvalidConfig] + ), + { + start_acceptor_pool(Listener, Port, ?POOL_SIZE_DEFAULT), + ?POOL_SIZE_DEFAULT, + ?POOL_SIZE_MAX_DEFAULT + } + end. + +-spec start_acceptor_pool( + socket(), + inet:port_number(), + pos_integer() +) -> + list(pid()). +start_acceptor_pool(Listener, Port, Size) -> + lists:map( + fun(_I) -> + P = riak_api_web_acceptor:start_link(Listener, Port), + true = is_pid(P), + P + end, + lists:seq(1, Size) + ). + +-spec get_tcp_buffer_options() -> list(buffer_option()). +get_tcp_buffer_options() -> + get_tcp_buffer_options( + [ + {buffer, web_kernel_buffer}, + {recbuf, web_receive_buffer}, + {sndbuf, web_send_buffer} + ], + [] + ). + +get_tcp_buffer_options([], BufferOptions) -> + BufferOptions; +get_tcp_buffer_options([{Name, EnVar} | Rest], BufferOptions) -> + case application:get_env(riak_api, EnVar) of + {ok, BSize} when is_integer(BSize) -> + get_tcp_buffer_options(Rest, [{Name, BSize} | BufferOptions]); + _ -> + get_tcp_buffer_options(Rest, BufferOptions) + end. + +-spec get_scheme(socket()) -> scheme(). +get_scheme({Scheme, _Socket}) -> + Scheme. + +-spec listen( + scheme(), + inet:port_number(), + list(socket_option()), + list(buffer_option()), + none | list(ssl:tls_server_option()) +) -> + {ok, socket()} | {error, any()}. +listen(http, Port, SocketOpts, BufferOpts, none) -> + case gen_tcp:listen(Port, SocketOpts ++ BufferOpts) of + {ok, Socket} -> + {ok, {http, Socket}}; + {error, Reason} -> + {error, Reason} + end; +listen(https, Port, SocketOpts, BufferOpts, SSLOpts) when SSLOpts =/= none -> + case ssl:listen(Port, SocketOpts ++ BufferOpts ++ SSLOpts) of + {ok, Socket} -> + {ok, {https, Socket}}; + {error, Reason} -> + {error, Reason} + end. + +-spec accept( + socket(), + pos_integer() +) -> + {ok, socket()} | {error, tcp_error() | tls_error()}. +accept({http, Socket}, Timeout) -> + case gen_tcp:accept(Socket, Timeout) of + {ok, S} -> + {ok, {http, S}}; + {error, Reason} -> + {error, Reason} + end; +accept({https, Socket}, Timeout) -> + case ssl:transport_accept(Socket, Timeout) of + {ok, S} -> + case ssl:handshake(S, Timeout) of + {ok, S1} -> + {ok, {https, S1}}; + {error, Reason} -> + {error, Reason} + end; + {error, Reason} -> + {error, Reason} + end. + +-spec recv( + socket(), + non_neg_integer(), + non_neg_integer() | infinity +) -> + {ok, binary()} | {error, any()}. +recv({http, Socket}, Size, Timeout) -> + case gen_tcp:recv(Socket, Size, Timeout) of + {ok, Data} when is_binary(Data) -> + {ok, Data}; + {error, Error} -> + {error, Error} + end; +recv({https, Socket}, Size, Timeout) -> + case ssl:recv(Socket, Size, Timeout) of + {ok, Data} when is_binary(Data) -> + {ok, Data}; + {error, Error} -> + {error, Error} + end. + +-spec recv_line( + socket(), + non_neg_integer() | infinity +) -> + {ok, binary()} | {error, any()}. +recv_line({http, Socket}, Timeout) -> + maybe + ok ?= inet:setopts(Socket, [{packet, line}]), + {ok, Data} ?= gen_tcp:recv(Socket, 0, Timeout), + ok ?= inet:setopts(Socket, [{packet, raw}]), + true = is_binary(Data), + {ok, Data} + else + {error, Error} -> + {error, Error} + end; +recv_line({https, Socket}, Timeout) -> + maybe + ok ?= ssl:setopts(Socket, [{packet, line}]), + {ok, Data} ?= ssl:recv(Socket, 0, Timeout), + ok ?= ssl:setopts(Socket, [{packet, raw}]), + true = is_binary(Data), + {ok, Data} + else + {error, Error} -> + {error, Error} + end. + +-spec send(socket(), binary()) -> ok | {error, any()}. +send({http, Socket}, Data) -> + gen_tcp:send(Socket, Data); +send({https, Socket}, Data) -> + ssl:send(Socket, Data). + +-spec close(socket()) -> ok | {error, any()}. +close({http, Socket}) -> + gen_tcp:close(Socket); +close({https, Socket}) -> + ssl:close(Socket). + +-spec get_peer( + socket() +) -> + {ok, inet:ip_address(), public_key:cert() | undefined} | {error, any()}. +get_peer({http, Socket}) -> + case inet:peername(Socket) of + {ok, {Addr, _Port}} when is_tuple(Addr) -> + {ok, Addr, undefined}; + {error, Error} -> + {error, Error} + end; +get_peer({https, Socket}) -> + case ssl:peername(Socket) of + {ok, {Addr, _Port}} when is_tuple(Addr) -> + case ssl:peercert(Socket) of + {ok, Cert} -> + {ok, Addr, Cert}; + _ -> + {ok, Addr, undefined} + end; + {error, Error} -> + {error, Error} + end. diff --git a/src/riak_api_wm_urlmap.erl b/src/riak_api_wm_urlmap.erl deleted file mode 100644 index 02767a8..0000000 --- a/src/riak_api_wm_urlmap.erl +++ /dev/null @@ -1,81 +0,0 @@ -%% ------------------------------------------------------------------- -%% -%% riak_api_wm_urlmap: expose the roots of registered Webmachine resources -%% -%% Copyright (c) 2007-2013 Basho Technologies, Inc. All Rights Reserved. -%% -%% This file is provided to you under the Apache License, -%% Version 2.0 (the "License"); you may not use this file -%% except in compliance with the License. You may obtain -%% a copy of the License at -%% -%% http://www.apache.org/licenses/LICENSE-2.0 -%% -%% Unless required by applicable law or agreed to in writing, -%% software distributed under the License is distributed on an -%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -%% KIND, either express or implied. See the License for the -%% specific language governing permissions and limitations -%% under the License. -%% -%% ------------------------------------------------------------------- - -%% @doc This module provides a Webmachine resource that lists the -%% URLs for other resources available on this host. -%% -%% Links to Riak resources will be added to the Link header in -%% the form: -%%``` -%% ; rel="RESOURCE_NAME" -%%''' -%% HTML output of this resource is a list of link tags like: -%%``` -%% RESOURCE_NAME -%%''' -%% JSON output of this resource in an object with elements like: -%%``` -%% "RESOURCE_NAME":"URL" -%%''' --module(riak_api_wm_urlmap). --export([ - init/1, - resource_exists/2, - content_types_provided/2, - to_html/2, - to_json/2 - ]). - --include_lib("webmachine/include/webmachine.hrl"). - -init([]) -> - {ok, service_list()}. - -resource_exists(RD, Services) -> - {true, add_link_header(RD, Services), Services}. - -add_link_header(RD, Services) -> - wrq:set_resp_header( - "Link", - string:join([ ["<",Uri,">; rel=\"",Resource,"\""] - || {Resource, Uri} <- Services ], - ","), - RD). - -content_types_provided(RD, Services) -> - {[{"text/html", to_html},{"application/json", to_json}], RD, Services}. - -to_html(RD, Services) -> - {[""], - RD, Services}. - -to_json(RD, Services) -> - {mochijson:encode({struct, Services}), RD, Services}. - -service_list() -> - Dispatch = webmachine_router:get_routes(), - lists:usort( - [{atom_to_list(Resource), "/"++UriBase} - || {[UriBase|_], Resource, _} <- Dispatch]). diff --git a/test/riak_api_web_ets_store.erl b/test/riak_api_web_ets_store.erl new file mode 100644 index 0000000..cad4df8 --- /dev/null +++ b/test/riak_api_web_ets_store.erl @@ -0,0 +1,565 @@ +%% ------------------------------------------------------------------- +%% +%% Copyright (c) 2026 Martin Sumner +%% +%% This file is provided to you under the Apache License, +%% Version 2.0 (the "License"); you may not use this file +%% except in compliance with the License. You may obtain +%% a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% +%% ------------------------------------------------------------------- +%% @doc Test handler that responds with random data. + +-module(riak_api_web_ets_store). + +-if(?OTP_RELEASE == 26). +-feature(maybe_expr, enable). +-endif. + +-behaviour(riak_api_web_handler). + +-export( + [ + match_route/3, + check_permissions/5, + parse_query_params/2, + parse_request_headers/2, + process_request/2, + record_request/3, + slice_stream_fun/1 + ] +). + +-ifdef(TEST). +-export( + [ + setup/0, + generator/1, + cleanup/1 + ] +). +-endif. + +-define(SLICE_SIZE, 10 * 1024). + +-record(context, { + key :: unicode:chardata(), + method :: 'GET' | 'PUT', + type :: object | file, + slice_list = [] :: list({range(), guid()}), + last_slice_end = 0 :: non_neg_integer() +}). + +-type context() :: #context{}. +-type guid() :: binary(). +-type range() :: {non_neg_integer(), non_neg_integer()}. + +%% @doc match_route for the module +-spec match_route( + riak_api_web_acceptor:method(), + unicode:chardata(), + list(unicode:chardata()) +) -> + nomatch + | {method_not_allowed, list(riak_api_web_acceptor:method())} + | {ok, riak_api_web_handler:limits(), context()}. +match_route(Method, _P, [<<"ets_object">>, <<"key">>, Key]) when + Method == 'GET'; Method == 'PUT' +-> + { + ok, + {10, 1024, 16 * 1024}, + #context{key = Key, method = Method, type = object} + }; +match_route(_, _, [<<"ets_object">>, <<"key">>, _Key]) -> + {method_not_allowed, ['GET', 'PUT']}; +match_route(Method, _P, [<<"ets_file">>, <<"filename">>, Key]) when + Method == 'GET'; Method == 'PUT' +-> + { + ok, + {10, 1024, 1024 * 1024}, + #context{key = Key, method = Method, type = object} + }; +match_route(_, _, _) -> + nomatch. + +%% @doc check_permissions for using this module or route +-spec check_permissions( + riak_api_web_headers:headers(), + riak_api_web_socket:scheme(), + riak_api_web_handler:peer_ip(), + public_key:cert() | undefined, + context() +) -> + {ok, context()}. +check_permissions(_Hdrs, _Scheme, _Peer, _Cert, Ctx) -> + {ok, Ctx}. + +%% @doc parse and validate query params, passed as a map +-spec parse_query_params( + riak_api_web_handler:query_params(), + context() +) -> + {ok, context()} | riak_api_web_acceptor:halt_response(). +parse_query_params(_Params, Ctx) -> + {ok, Ctx}. + +%% @doc parse and validate the request headers +-spec parse_request_headers( + riak_api_web_headers:headers(), + context() +) -> + {ok, context()} | riak_api_web_acceptor:halt_response(). +parse_request_headers(_ReqHeaders, Ctx) -> + {ok, Ctx}. + +%% @doc Process the request and produce a response +-spec process_request( + riak_api_web_body:req_body(), + context() +) -> + { + ok, + { + riak_api_web_acceptor:response_code(), + riak_api_web_headers:header_list(), + riak_api_web_handler:response_body(), + boolean(), + riak_api_web_body:req_body() + }, + context() + }. +process_request( + RqBdy, Ctx = #context{key = Key, method = 'GET', type = object} +) -> + case ets:lookup(?MODULE, {object, Key}) of + [{{object, Key}, Value}] -> + {ok, {200, [], Value, true, RqBdy}, Ctx}; + [] -> + {ok, {404, [], <<>>, true, RqBdy}, Ctx} + end; +process_request( + RqBdy, Ctx = #context{key = Key, method = 'PUT', type = object} +) -> + case riak_api_web_body:get_body(RqBdy, all, 10000) of + {Value, UpdRqBdy} when is_binary(Value) -> + ets:insert(?MODULE, {{object, Key}, Value}), + ETag = base64:encode(crypto:hash(md5, Value), #{mode => urlsafe}), + {ok, {204, [{'Etag', ETag}], <<>>, true, UpdRqBdy}, Ctx}; + {error, content_too_large} -> + {ok, {413, [], <<>>, false, RqBdy}, Ctx} + end; +process_request( + RqBdy, Ctx = #context{key = Key, method = 'GET', type = file} +) -> + case ets:lookup(?MODULE, {file, Key}) of + [{{file, Key}, SliceList}] -> + { + ok, + { + 200, + [], + {stream, slice_stream_fun(lists:sort(SliceList))}, + true, + RqBdy + }, + Ctx + }; + [] -> + {ok, {404, [], <<>>, true, RqBdy}, Ctx} + end; +process_request( + RqBdy, Ctx = #context{key = Key, method = 'PUT', type = file} +) -> + case riak_api_web_body:get_body(RqBdy, ?SLICE_SIZE, 10000) of + {Slice, UpdRqBdy} when is_binary(Slice) -> + SliceKey = generate_uuid(), + SliceSize = byte_size(Slice), + ets:insert_new(?MODULE, {{slice, SliceKey}, Slice}), + process_request( + UpdRqBdy, + Ctx#context{ + slice_list = + [ + { + {Ctx#context.last_slice_end, SliceSize}, + SliceKey + } + | Ctx#context.slice_list + ], + last_slice_end = Ctx#context.last_slice_end + SliceSize + } + ); + {done, UpdRqBdy} -> + ets:insert(?MODULE, {{file, Key}, Ctx#context.slice_list}), + ETag = + base64:encode( + crypto:hash(md5, term_to_binary(Ctx#context.slice_list)), + #{mode => urlsafe} + ), + {ok, {204, [{'Etag', ETag}], <<>>, true, UpdRqBdy}, Ctx}; + {error, content_too_large} -> + {ok, {413, [], <<>>, false, RqBdy}, Ctx} + end. + +generate_uuid() -> + <> = crypto:strong_rand_bytes(16), + L = io_lib:format( + "~8.16.0b-~4.16.0b-4~3.16.0b-~4.16.0b-~12.16.0b", + [A, B, C band 16#0fff, D band 16#3fff bor 16#8000, E] + ), + list_to_binary(L). + +slice_stream_fun([]) -> + fun() -> done end; +slice_stream_fun(List) -> + fun() -> + [{_Range, SliceKey} | Rest] = List, + [{{slice, SliceKey}, Slice}] = ets:lookup(?MODULE, {slice, SliceKey}), + {Slice, slice_stream_fun(Rest)} + end. + +%% @doc Record the output of the interaction +-spec record_request( + riak_api_web_handler:timings(), + riak_api_web_handler:completion(), + context() +) -> + ok. +record_request(Timings, Completion, Ctx) -> + {A, B, C} = Timings, + io:format( + user, + "Request ~w ~w with timings ~0p~n", + [Ctx#context.method, Completion, {B - A, C - B, C - A}] + ). + +%%%============================================================================ +%%% Eunit tests +%%%============================================================================ + +-ifdef(TEST). +-include_lib("eunit/include/eunit.hrl"). +-include_lib("stdlib/include/assert.hrl"). + +basic_handler_test_() -> + {setup, fun setup/0, fun cleanup/1, fun generator/1}. + +setup() -> + inets:start(), + TestPort = find_available_port(lists:seq(8000, 8999)), + IPAddr = {127, 0, 0, 1}, + SpecName = riak_api_web:spec_name(http, IPAddr, TestPort), + Options = + [ + {name, SpecName}, + {ip, IPAddr}, + {port, TestPort}, + {acceptor_pool_start_size, 4} + ], + {ok, _Pid} = riak_api_web_socket:start_link(Options), + riak_api_web:add_routes([{20, ?MODULE}]), + ets:new( + ?MODULE, + [named_table, public, {read_concurrency, true}] + ), + {ok, _HTTPC} = inets:start(httpc, [{profile, test_client}]), + ok = httpc:set_options([{verbose, false}], test_client), + {SpecName, IPAddr, TestPort}. + +generator({_SpecName, IPAddr, Port}) -> + [ + put_then_get(IPAddr, Port), + put_too_big(IPAddr, Port), + put_big_header(IPAddr, Port), + put_then_get_file(IPAddr, Port), + raw_put_then_get_file(IPAddr, Port), + raw_put_toobig_object(IPAddr, Port) + ]. + +cleanup({SpecName, _IPAddr, _Port}) -> + ok = inets:stop(), + ?assertMatch(4, riak_api_web_socket:get_active_pool_size(SpecName)), + riak_api_web_socket:stop(SpecName), + ok. + +raw_put_toobig_object({A, B, C, D}, Port) -> + fun() -> + {ok, Socket} = + gen_tcp:connect( + {A, B, C, D}, + Port, + [binary, {packet, raw}, {active, false}] + ), + RequestHead = + << + "PUT /ets_object/key/K0006 HTTP/1.1\r\n" + "Connection: close\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + >>, + gen_tcp:send(Socket, RequestHead), + _Hash = send_chunked_4KBobject(Socket), + ok = inet:setopts(Socket, [{packet, line}]), + {ok, L1} = gen_tcp:recv(Socket, 0, 10000), + ?assertMatch( + <<"HTTP/1.1 413 Content Too Large\r\n">>, + L1 + ), + ok = inet:setopts(Socket, [{packet, raw}]), + {ok, _RspHdrs} = gen_tcp:recv(Socket, 0, 10000), + ok = gen_tcp:close(Socket) + end. + +raw_put_then_get_file({A, B, C, D}, Port) -> + fun() -> + {ok, Socket} = + gen_tcp:connect( + {A, B, C, D}, + Port, + [binary, {packet, raw}, {active, false}] + ), + RequestHead = + << + "PUT /ets_file/filename/K0005 HTTP/1.1\r\n" + "Connection: close\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + >>, + gen_tcp:send(Socket, RequestHead), + Hash = send_chunked_4KBobject(Socket), + ok = inet:setopts(Socket, [{packet, line}]), + {ok, L1} = gen_tcp:recv(Socket, 0, 10000), + ?assertMatch( + <<"HTTP/1.1 204 No Content\r\n">>, + L1 + ), + ok = inet:setopts(Socket, [{packet, raw}]), + {ok, _RspHdrs} = gen_tcp:recv(Socket, 0, 10000), + ok = gen_tcp:close(Socket), + URI = + lists:flatten( + io_lib:format( + "http://~w.~w.~w.~w:~w/ets_file/filename/~s", + [A, B, C, D, Port, <<"K0005">>] + ) + ), + {ok, {{"HTTP/1.1", 200, "OK"}, _FetchHeaders, FetchBody}} = + httpc:request( + get, + {URI, []}, + [], + [{body_format, binary}], + test_client + ), + ?assert(is_binary(FetchBody)), + ?assertMatch(41020, byte_size(FetchBody)), + ReturnedHash = crypto:hash(md5, FetchBody), + ?assertMatch(Hash, ReturnedHash) + end. + +send_chunked_4KBobject(Socket) -> + TestValue = crypto:strong_rand_bytes((10 * 4092) + 100), + Hash = crypto:hash(md5, TestValue), + << + Chunk1:4092/binary, + Chunk2:4092/binary, + Chunk3:4092/binary, + Chunk4:4092/binary, + Chunk5:4092/binary, + Chunk6:4092/binary, + Chunk7:4092/binary, + Chunk8:4092/binary, + Chunk9:4092/binary, + Chunk10:4092/binary, + Chunk11:100/binary + >> = TestValue, + lists:foreach( + fun(Chunk) -> + Size = integer_to_binary(byte_size(Chunk), 16), + Bin = iolist_to_binary([Size, <<"\r\n">>, Chunk, <<"\r\n">>]), + gen_tcp:send(Socket, Bin) + end, + [ + Chunk1, + Chunk2, + Chunk3, + Chunk4, + Chunk5, + Chunk6, + Chunk7, + Chunk8, + Chunk9, + Chunk10, + Chunk11 + ] + ), + gen_tcp:send(Socket, <<"0\r\n\r\n">>), + Hash. + +put_then_get_file({A, B, C, D}, Port) -> + fun() -> + Key = <<"K0004">>, + URI = + lists:flatten( + io_lib:format( + "http://~w.~w.~w.~w:~w/ets_file/filename/~s", + [A, B, C, D, Port, Key] + ) + ), + Value = crypto:strong_rand_bytes(100 * 1024), + Hash = crypto:hash(md5, Value), + {ok, {{"HTTP/1.1", 204, "No Content"}, _Headers, <<>>}} = + httpc:request( + put, + {URI, [], "application/binary", Value}, + [], + [{body_format, binary}], + test_client + ), + {ok, {{"HTTP/1.1", 200, "OK"}, _FetchHeaders, FetchBody}} = + httpc:request( + get, + {URI, []}, + [], + [{body_format, binary}], + test_client + ), + ?assert(is_binary(FetchBody)), + ?assertMatch(102400, byte_size(FetchBody)), + ReturnedHash = crypto:hash(md5, FetchBody), + ?assertMatch(Hash, ReturnedHash) + end. + +put_then_get({A, B, C, D}, Port) -> + fun() -> + Key = <<"K0001">>, + URI = + lists:flatten( + io_lib:format( + "http://~w.~w.~w.~w:~w/ets_object/key/~s", + [A, B, C, D, Port, Key] + ) + ), + {ok, {{"HTTP/1.1", 404, "Not Found"}, Rsp1Headers, _Rsp1Body}} = + httpc:request( + get, + {URI, []}, + [], + [], + test_client + ), + ?assertMatch( + {"connection", "keep-alive"}, + lists:keyfind("connection", 1, Rsp1Headers) + ), + ?assertMatch( + {"server", "RiakAPI/4.0 SilverMachine"}, + lists:keyfind("server", 1, Rsp1Headers) + ), + Value = crypto:strong_rand_bytes(64), + ExpectedVTag = + binary_to_list( + base64:encode(crypto:hash(md5, Value), #{mode => urlsafe}) + ), + {ok, {{"HTTP/1.1", 204, "No Content"}, Rsp2Headers, <<>>}} = + httpc:request( + put, + {URI, [], "application/binary", Value}, + [], + [{body_format, binary}], + test_client + ), + ?assertMatch( + {"etag", ExpectedVTag}, + lists:keyfind("etag", 1, Rsp2Headers) + ), + {ok, {{"HTTP/1.1", 200, "OK"}, _Rsp3Headers, Rsp3Body}} = + httpc:request( + get, + {URI, []}, + [], + [{body_format, binary}], + test_client + ), + ?assert(is_binary(Rsp3Body)) + end. + +put_too_big({A, B, C, D}, Port) -> + fun() -> + Key = <<"K0002">>, + URI = + lists:flatten( + io_lib:format( + "http://~w.~w.~w.~w:~w/ets_object/key/~s", + [A, B, C, D, Port, Key] + ) + ), + Value = crypto:strong_rand_bytes(64 * 1024), + {ok, {{"HTTP/1.1", 413, "Content Too Large"}, _Rsp2Headers, <<>>}} = + httpc:request( + put, + {URI, [], "application/binary", Value}, + [], + [{body_format, binary}], + test_client + ) + end. + +put_big_header({A, B, C, D}, Port) -> + fun() -> + Key = <<"K0003">>, + URI = + lists:flatten( + io_lib:format( + "http://~w.~w.~w.~w:~w/ets_object/key/~s", + [A, B, C, D, Port, Key] + ) + ), + HeaderValue = + base64:encode(crypto:strong_rand_bytes(2048), #{mode => urlsafe}), + Value = crypto:strong_rand_bytes(64), + { + ok, + {{"HTTP/1.1", 431, "Request Header Fields Too Large"}, _, RspBdy} + } = + httpc:request( + put, + { + URI, + [{"X-Riak-Vclock", HeaderValue}], + "application/binary", + Value + }, + [], + [{body_format, binary}], + test_client + ), + ?assertMatch( + <<"Header x-riak-vclock exceeded maximum size of 1024">>, + RspBdy + ) + end. + +find_available_port([]) -> + no_port_found; +find_available_port([Port | Rest]) -> + case gen_tcp:listen(Port, []) of + {ok, Sock} -> + ok = gen_tcp:close(Sock), + Port; + _ -> + find_available_port(Rest) + end. + +-endif. diff --git a/test/riak_api_web_get_random.erl b/test/riak_api_web_get_random.erl new file mode 100644 index 0000000..999550b --- /dev/null +++ b/test/riak_api_web_get_random.erl @@ -0,0 +1,512 @@ +%% ------------------------------------------------------------------- +%% +%% Copyright (c) 2026 Martin Sumner +%% +%% This file is provided to you under the Apache License, +%% Version 2.0 (the "License"); you may not use this file +%% except in compliance with the License. You may obtain +%% a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% +%% ------------------------------------------------------------------- +%% @doc Test handler that responds with random data. + +-module(riak_api_web_get_random). + +-if(?OTP_RELEASE == 26). +-feature(maybe_expr, enable). +-endif. + +-behaviour(riak_api_web_handler). + +-export( + [ + match_route/3, + check_permissions/5, + parse_query_params/2, + parse_request_headers/2, + process_request/2, + record_request/3 + ] +). + +-ifdef(TEST). +-export( + [ + setup/0, + generator/1, + cleanup/1, + request_single_value/3 + ] +). +-endif. + +-record(context, + { + request_id :: non_neg_integer()|undefined, + required_size :: non_neg_integer()|undefined + } +). + +-type context() :: #context{}. + +-define(ID_HEADER_LWR, <<"x-riak-request_id">>). + +%% @doc match_route for the module +-spec match_route( + riak_api_web_acceptor:method(), + unicode:chardata(), + list(unicode:chardata()) +) -> + nomatch | + {method_not_allowed, list(riak_api_web_acceptor:method())} | + {ok, riak_api_web_handler:limits(), context()}. +match_route('GET', <<"/random_data">>, _SP) -> + {ok, {10, 1024, 128 * 1024}, #context{}}; +match_route(_, <<"/random_data">>, _SP) -> + {method_not_allowed, ['GET']}; +match_route(_, _, _) -> + nomatch. + +%% @doc check_permissions for using this module or route +-spec + check_permissions( + riak_api_web_headers:headers(), + riak_api_web_socket:scheme(), + riak_api_web_handler:peer_ip(), + public_key:cert() | undefined, + context() + ) -> + {ok, context()}. +check_permissions(_Hdrs, _Scheme, _Peer, _Cert, Ctx) -> + {ok, Ctx}. + +%% @doc parse and validate query params, passed as a map +-spec + parse_query_params( + riak_api_web_handler:query_params(), + context() + ) -> + {ok, context()}|riak_api_web_acceptor:halt_response(). +parse_query_params([], #context{required_size = undefined}) -> + {halt, 400, [], <<"no required_size parameter">>, []}; +parse_query_params([], Ctx) -> + {ok, Ctx}; +parse_query_params([{<<"required_size">>, RS}|Rest], Ctx) -> + try + case binary_to_integer(RS) of + RSI when is_integer(RSI), RSI >= 0 -> + parse_query_params(Rest, Ctx#context{required_size = RSI}); + _BadRS -> + {halt, 400, [], <<"invalid required_size ~0p">>, [RS]} + end + catch + _ : _ -> + {halt, 400, [], <<"invalid required_size ~0p">>, [RS]} + end; +parse_query_params([_Other|Rest], Ctx) -> + parse_query_params(Rest, Ctx). + +%% @doc parse and validate the request headers +-spec + parse_request_headers( + riak_api_web_headers:headers(), + context() + ) -> + {ok, context()}|riak_api_web_acceptor:halt_response(). +parse_request_headers(ReqHeaders, Ctx) -> + case riak_api_web_headers:lookup(?ID_HEADER_LWR, ReqHeaders, true) of + undefined -> + ErrorMsg = <<"request requires x-riak-request_id header">>, + {halt, 400, [], ErrorMsg, []}; + {_OrigKey, [RequestIDStr]} when is_binary(RequestIDStr) -> + try + RequestID = binary_to_integer(RequestIDStr), + true = RequestID > 0, + {ok, Ctx#context{request_id = RequestID}} + catch + error:badarg -> + {halt, 400, [], <<"invalid non-numeric request_id">>, []}; + error:{badmatch,false} -> + {halt, 400, [], <<"invalid negative request_id">>, []} + end; + {_OrigKey, MultipleIDs} when is_list(MultipleIDs) -> + {halt, 400, [], <<"multiple request_id provided">>, []} + end. + +%% @doc Process the request and produce a response +-spec + process_request( + riak_api_web_body:req_body(), + context() + ) -> + { + ok, + { + riak_api_web_acceptor:response_code(), + riak_api_web_headers:header_list(), + riak_api_web_handler:response_body(), + boolean(), + riak_api_web_body:req_body() + }, + context() + }. +process_request(RqBdy, Ctx = #context{request_id = RqID, required_size = RS}) + when is_integer(RqID), is_integer(RS), RS > 0 -> + Body = crypto:strong_rand_bytes(RS), + RspHdr = + {<<"X-Riak-request_id">>, integer_to_binary(RqID)}, + { + ok, + {200, [RspHdr], Body, true, RqBdy}, + Ctx + }. + +%% @doc Record the output of the interaction +-spec record_request( + riak_api_web_handler:timings(), + riak_api_web_handler:completion(), + context() +) -> + ok. +record_request(Timings, Completion, _Ctx) -> + {A, B, C} = Timings, + io:format( + user, + "Request ~w with timings ~0p~n", + [Completion, {B - A, C - B, C - A}] + ). + + +%%%============================================================================ +%%% Eunit tests +%%%============================================================================ + +-ifdef(TEST). +-include_lib("eunit/include/eunit.hrl"). + +basic_handler_test_() -> + {setup, fun setup/0, fun cleanup/1, fun generator/1}. + +-define(REQUEST_BIN(ID, Size, KeepAlive), + io_lib:format( + << + "GET /random_data?required_size=~w HTTP/1.1\r\n" + "X-Riak-request_id: ~w\r\n" + "Connection: ~s\r\n" + "Content-Length: 0\r\n" + "\r\n" + >>, + [Size, ID, KeepAlive] + ) +). + +-define(REQUEST_BIN_V10(ID, Size, KeepAlive), + io_lib:format( + << + "GET /random_data?required_size=~w HTTP/1.0\r\n" + "X-Riak-request_id: ~w\r\n" + "Connection: ~s\r\n" + "Content-Length: 0\r\n" + "\r\n" + >>, + [Size, ID, KeepAlive] + ) +). + +-define(BAD_VERSION, + << + "GET /random_data?required_size=~w HTTP1.1\r\n" + "X-Riak-request_id: 1\r\n" + "Connection: close\r\n" + "Content-Length: 0\r\n" + "\r\n" + >> +). + +-define(WRONG_URL, + << + "GET /randon_data?required_size=~w HTTP/1.1\r\n" + "X-Riak-request_id: 1\r\n" + "Connection: close\r\n" + "Content-Length: 0\r\n" + "\r\n" + >> +). + +-define(POST_NOT_GET, + << + "POST /random_data?required_size=~w HTTP/1.1\r\n" + "X-Riak-request_id: 1\r\n" + "Connection: close\r\n" + "Content-Length: 0\r\n" + "\r\n" + >> +). + +setup() -> + inets:start(), + TestPort = find_available_port(lists:seq(8000, 8999)), + IPAddr = {127, 0, 0, 1}, + SpecName = riak_api_web:spec_name(http, IPAddr, TestPort), + Options = + [ + {name, SpecName}, + {ip, IPAddr}, + {port, TestPort}, + {acceptor_pool_start_size, 4}, + {acceptor_pool_max_size, 8} + ], + {ok, _Pid} = riak_api_web_socket:start_link(Options), + riak_api_web:add_routes(TestPort, [{10, ?MODULE}]), + {ok, _HTTPC} = inets:start(httpc, [{profile, test_client}]), + ok = httpc:set_options([{verbose, false}], test_client), + {SpecName, IPAddr, TestPort} + . + +generator({_SpecName, IPAddr, Port}) -> + [ + request_single_value(IPAddr, Port, 32), + request_single_value(IPAddr, Port, 64), + request_single_value(IPAddr, Port, 2048), + pipeline_request_values(IPAddr, Port, 16), + request_keepalive_v10(IPAddr, Port, 256), + request_error(IPAddr, Port, ?WRONG_URL, 404), + request_error(IPAddr, Port, ?POST_NOT_GET, 405), + request_error(IPAddr, Port, ?BAD_VERSION, 400), + request_with_httpc(IPAddr, Port, 128), + request_with_httpc(IPAddr, Port, 16) + ]. + +cleanup({SpecName, _IPAddr, _Port}) -> + ok = inets:stop(), + ?assertMatch(4, riak_api_web_socket:get_active_pool_size(SpecName)), + riak_api_web_socket:stop(SpecName), + ok. + +request_error(IPAddr, Port, Msg, ExpectedCode) -> + fun() -> + {ok, Socket} = + gen_tcp:connect( + IPAddr, + Port, + [binary, {packet, raw}, {active, false}] + ), + ok = gen_tcp:send(Socket, Msg), + {ok, Data} = gen_tcp:recv(Socket, 0), + ?assertMatch(ok, validate_error(Data, ExpectedCode, Socket)), + ok = gen_tcp:close(Socket) + end. + +request_with_httpc({A, B, C, D}, Port, Size) -> + fun() -> + URI = + lists:flatten( + io_lib:format( + "http://~w.~w.~w.~w:~w/random_data?required_size=~w", + [A, B, C, D, Port, Size] + ) + ), + {ok, {{"HTTP/1.1", 200, "OK"}, ResponseHeaders, ResponseBody}} = + httpc:request( + get, + { + URI, + [{"X-Riak-request_id", integer_to_binary(1)}] + }, + [], + [], + test_client + ), + ?assertMatch( + {"connection", "keep-alive"}, + lists:keyfind("connection", 1, ResponseHeaders) + ), + ?assertMatch( + {"server", "RiakAPI/4.0 SilverMachine"}, + lists:keyfind("server", 1, ResponseHeaders) + ), + SizeL = integer_to_list(Size), + ?assertMatch( + {"content-length", SizeL}, + lists:keyfind("content-length", 1, ResponseHeaders) + ), + ?assertMatch( + {"x-riak-request_id", "1"}, + lists:keyfind("x-riak-request_id", 1, ResponseHeaders) + ), + ?assertMatch(Size, length(ResponseBody)) + end. + +request_single_value(IPAddr, Port, Size) -> + fun() -> + {ok, Socket} = + gen_tcp:connect( + IPAddr, + Port, + [binary, {packet, raw}, {active, false}, {recbuf, 64 * 1024}] + ), + Request = ?REQUEST_BIN(1, Size, <<"close">>), + ok = gen_tcp:send(Socket, Request), + {ok, Data} = gen_tcp:recv(Socket, 0), + ?assertMatch(<<>>, validate_response(Data, Size, Socket)), + ok = gen_tcp:close(Socket) + end. + +request_keepalive_v10(IPAddr, Port, Size) -> + fun() -> + {ok, Socket} = + gen_tcp:connect( + IPAddr, + Port, + [binary, {packet, raw}, {active, false}] + ), + Request = + list_to_binary( + lists:flatten(?REQUEST_BIN_V10(1, Size, <<"keep-alive">>)) + ), + ok = gen_tcp:send(Socket, Request), + {ok, Data1} = gen_tcp:recv(Socket, 0), + ?assertMatch( + <<>>, + validate_response(Data1, Size, Socket, <<"HTTP/1.0 200 OK\r\n">>) + ), + ok = gen_tcp:send(Socket, Request), + {ok, Data2} = gen_tcp:recv(Socket, 0), + ?assertMatch( + <<>>, + validate_response(Data2, Size, Socket, <<"HTTP/1.0 200 OK\r\n">>) + ), + ok = gen_tcp:close(Socket) + end. + +pipeline_request_values(IPAddr, Port, Size) -> + fun() -> + {ok, Socket} = + gen_tcp:connect( + IPAddr, + Port, + [binary, {packet, raw}, {active, false}] + ), + Requests = + lists:map( + fun(I) -> ?REQUEST_BIN(I, Size, <<"keep-alive">>) end, + lists:seq(1, 5) + ), + Request = iolist_to_binary(Requests), + ok = gen_tcp:send(Socket, Request), + {ok, Data} = gen_tcp:recv(Socket, 0), + R1 = validate_response(Data, Size, Socket), + R2 = validate_response(R1, Size, Socket), + R3 = validate_response(R2, Size, Socket), + R4 = validate_response(R3, Size, Socket), + <<>> = validate_response(R4, Size, Socket), + ok = gen_tcp:close(Socket) + end. + +extract_headers(Data, Socket, ExpectedResponseLine) -> + maybe + {ok, L1, R1} ?= erlang:decode_packet(line, Data, []), + ?assertMatch(L1, ExpectedResponseLine), + {ok, L2, R2} ?= erlang:decode_packet(line, R1, []), + {ok, L3, R3} ?= erlang:decode_packet(line, R2, []), + {ok, L4, R4} ?= erlang:decode_packet(line, R3, []), + {ok, L5, R5} ?= erlang:decode_packet(line, R4, []), + {ok, MaybeL6, R6} ?= erlang:decode_packet(line, R5, []), + {ok, L6, Rem} ?= + case MaybeL6 of + <<"\r\n">> -> + {ok, none, R6}; + MaybeL6 -> + case erlang:decode_packet(line, R6, []) of + {ok, <<"\r\n">>, R7} -> + {ok, MaybeL6, R7}; + {more, _} -> + {more, undefined} + end + end, + + { + lists:map( + fun(S) -> hd(string:split(S, <<":">>, leading)) end, + lists:filter( + fun(H) -> H =/= none end, + lists:sort([L2, L3, L4, L5, L6]) + ) + ), + Rem + } + else + {more, _} -> + {ok, More} = gen_tcp:recv(Socket, 0), + extract_headers( + <>, + Socket, + ExpectedResponseLine + ) + end. + +validate_response(Data, Size, Socket) -> + validate_response(Data, Size, Socket, <<"HTTP/1.1 200 OK\r\n">>). + +validate_response(Data, Size, Socket, StatusLine) -> + {HeaderKeys, Rem} = + extract_headers(Data, Socket, StatusLine), + ?assertMatch( + [ + <<"Connection">>, + <<"Content-Length">>, + <<"Date">>, + <<"Server">>, + <<"X-Riak-request_id">> + ], + HeaderKeys + ), + {ok, RspBody, Rest} = erlang:decode_packet(0, Rem, []), + <> = RspBody, + ?assertMatch(Size, byte_size(ExpectedBody)), + <>. + +validate_error(Data, ExpectedCode, Socket) -> + {ExpectedResponseLine, AdditionalHeaderKeys} = + case ExpectedCode of + 400 -> + {<<"HTTP/1.0 400 Bad Request\r\n">>, []}; + % As it was a bad version - can't assume 1.1 + 404 -> + {<<"HTTP/1.1 404 Not Found\r\n">>, []}; + 405 -> + {<<"HTTP/1.1 405 Method Not Allowed\r\n">>, [<<"Allow">>]} + end, + {HeaderKeys, _Rem} = extract_headers(Data, Socket, ExpectedResponseLine), + ExpectedHeaderKeys = + lists:sort( + [ + <<"Connection">>, + <<"Content-Length">>, + <<"Date">>, + <<"Server">> + ] ++ AdditionalHeaderKeys + ), + ?assertMatch(ExpectedHeaderKeys, HeaderKeys). + +find_available_port([]) -> + no_port_found; +find_available_port([Port|Rest]) -> + case gen_tcp:listen(Port, []) of + {ok, Sock} -> + ok = gen_tcp:close(Sock), + Port; + _ -> + find_available_port(Rest) + end. + +-endif. \ No newline at end of file diff --git a/test/riak_api_web_trigger.erl b/test/riak_api_web_trigger.erl new file mode 100644 index 0000000..949e526 --- /dev/null +++ b/test/riak_api_web_trigger.erl @@ -0,0 +1,540 @@ +%% ------------------------------------------------------------------- +%% +%% Copyright (c) 2026 Martin Sumner +%% +%% This file is provided to you under the Apache License, +%% Version 2.0 (the "License"); you may not use this file +%% except in compliance with the License. You may obtain +%% a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% +%% ------------------------------------------------------------------- +%% @doc Test handler that is used for triggering error conditions + +-module(riak_api_web_trigger). + +-behaviour(riak_api_web_handler). + +-export( + [ + match_route/3, + check_permissions/5, + parse_query_params/2, + parse_request_headers/2, + process_request/2, + record_request/3 + ] +). + +-ifdef(TEST). +-export( + [ + setup/0, + generator/1, + cleanup/1 + ] +). +-endif. + +-record(context, { + mishandle_nonzero_body = false :: boolean(), + response_code = 200 :: 200 | 201 | 204 +}). + +-type context() :: #context{}. + +%% @doc match_route for the module +-spec match_route( + riak_api_web_acceptor:method(), + unicode:chardata(), + list(unicode:chardata()) +) -> + nomatch + | {method_not_allowed, list(riak_api_web_acceptor:method())} + | {ok, riak_api_web_handler:limits(), context()}. +match_route('PUT', _P, [<<"with_limits">>, HC, HS, BS]) when + is_binary(HC), is_binary(HS), is_binary(BS) +-> + { + ok, + { + binary_to_integer(HC), + binary_to_integer(HS), + binary_to_integer(BS) + }, + #context{} + }; +match_route(_, _, _) -> + nomatch. + +%% @doc check_permissions for using this module or route +-spec check_permissions( + riak_api_web_headers:headers(), + riak_api_web_socket:scheme(), + riak_api_web_handler:peer_ip(), + public_key:cert() | undefined, + context() +) -> + {ok, context()}. +check_permissions(_Hdrs, _Scheme, _Peer, _Cert, Ctx) -> + {ok, Ctx}. + +%% @doc parse and validate query params, passed as a map +-spec parse_query_params( + riak_api_web_handler:query_params(), + context() +) -> + {ok, context()} | riak_api_web_acceptor:halt_response(). +parse_query_params(QueryParams, Ctx) -> + {ok, Ctx1} = + case lists:keyfind(<<"mishandle_nonzero_body">>, 1, QueryParams) of + {<<"mishandle_nonzero_body">>, true} -> + {ok, Ctx#context{mishandle_nonzero_body = true}}; + _ -> + {ok, Ctx} + end, + case lists:keyfind(<<"response_code">>, 1, QueryParams) of + {<<"response_code">>, RC} when is_binary(RC) -> + {ok, Ctx1#context{response_code = binary_to_integer(RC)}}; + _ -> + {ok, Ctx1} + end. + +%% @doc parse and validate the request headers +-spec parse_request_headers( + riak_api_web_headers:headers(), + context() +) -> + {ok, context()} | riak_api_web_acceptor:halt_response(). +parse_request_headers(_ReqHeaders, Ctx) -> + {ok, Ctx}. + +%% @doc Process the request and produce a response +-spec process_request( + riak_api_web_body:req_body() | none, + context() +) -> + { + ok, + { + riak_api_web_acceptor:response_code(), + riak_api_web_headers:header_list(), + riak_api_web_handler:response_body(), + boolean(), + riak_api_web_body:req_body() | none + }, + context() + }. +process_request(RqBdy, Ctx) -> + case {Ctx#context.mishandle_nonzero_body, RqBdy} of + {true, RqBdy} when RqBdy =/= none -> + {ok, {200, [], <<>>, true, none}, Ctx}; + {false, none} -> + {ok, {Ctx#context.response_code, [], <<>>, true, RqBdy}, Ctx}; + {false, RqBdy} -> + case riak_api_web_body:get_body(RqBdy, all, 10000) of + {Buffer, UpdBdy} when Buffer =/= error -> + {ok, {Ctx#context.response_code, [], <<>>, true, UpdBdy}, + Ctx} + end + end. + +%% @doc Record the output of the interaction +-spec record_request( + riak_api_web_handler:timings(), + riak_api_web_handler:completion(), + context() +) -> + ok. +record_request(_Timings, _Completion, _Ctx) -> + ok. + +%%%============================================================================ +%%% Eunit tests +%%%============================================================================ + +-ifdef(TEST). +-include_lib("eunit/include/eunit.hrl"). + +basic_handler_test_() -> + {setup, fun setup/0, fun cleanup/1, fun generator/1}. + +setup() -> + inets:start(), + TestPort = find_available_port(lists:seq(8000, 8999)), + IPAddr = {127, 0, 0, 1}, + SpecName = riak_api_web:spec_name(http, IPAddr, TestPort), + Options = + [ + {name, SpecName}, + {ip, IPAddr}, + {port, TestPort}, + {acceptor_pool_start_size, 4} + ], + {ok, _Pid} = riak_api_web_socket:start_link(Options), + riak_api_web:add_routes([{10, ?MODULE}]), + {ok, _HTTPC} = inets:start(httpc, [{profile, test_client}]), + ok = httpc:set_options([{verbose, false}], test_client), + {SpecName, IPAddr, TestPort}. + +generator({_SpecName, IPAddr, Port}) -> + [ + too_many_headers(IPAddr, Port), + header_too_large(IPAddr, Port), + non_zero_body(IPAddr, Port), + zero_body(IPAddr, Port), + mishandle_nonzero_body(IPAddr, Port), + handle_bad_uri(IPAddr, Port), + handle_bad_content_length(IPAddr, Port), + handle_connection_header_confusion(IPAddr, Port), + trigger_alternative_response_code( + IPAddr, + Port, + <<"1.0">>, + 201, + <<"HTTP/1.0 201 Created\r\n">> + ), + trigger_alternative_response_code( + IPAddr, + Port, + <<"1.0">>, + 204, + <<"HTTP/1.0 204 No Content\r\n">> + ), + trigger_alternative_response_code( + IPAddr, + Port, + <<"1.0">>, + 202, + <<"HTTP/1.0 202 Accepted\r\n">> + ), + trigger_alternative_response_code( + IPAddr, + Port, + <<"1.1">>, + 201, + <<"HTTP/1.1 201 Created\r\n">> + ), + trigger_alternative_response_code( + IPAddr, + Port, + <<"1.1">>, + 204, + <<"HTTP/1.1 204 No Content\r\n">> + ), + trigger_alternative_response_code( + IPAddr, + Port, + <<"1.1">>, + 202, + <<"HTTP/1.1 202 Accepted\r\n">> + ) + ]. + +request_bin(HC, HS, BS, HeaderSize, BodySize) -> + request_bin(HC, HS, BS, <<"">>, HeaderSize, BodySize). + +request_bin(HC, HS, BS, QP, HeaderSize, BodySize) -> + <> = + base64:encode(crypto:strong_rand_bytes(HeaderSize)), + Body = + case BodySize of + "A" -> + crypto:strong_rand_bytes(10); + _ -> + crypto:strong_rand_bytes(BodySize) + end, + Rq = + io_lib:format( + << + "PUT /with_limits/~w/~w/~w?~s HTTP/1.1\r\n" + "Connection: close\r\n" + "Content-Length: ~w\r\n" + "X-Riak-BigHeader: ~s\r\n" + "Content-Type: application/octet-stream\r\n" + "\r\n" + "~w" + >>, + [HC, HS, BS, QP, BodySize, Header, Body] + ), + iolist_to_binary(Rq). + +request_bin_cc(HC, HS, BS, HeaderSize, BodySize, Version) -> + <> = + base64:encode(crypto:strong_rand_bytes(HeaderSize)), + Body = crypto:strong_rand_bytes(BodySize), + Rq = + io_lib:format( + << + "PUT /with_limits/~w/~w/~w?QP HTTP/~s\r\n" + "Connection: close\r\n" + "Connection: keep-alive\r\n" + "Content-Length: ~w\r\n" + "X-Riak-BigHeader: ~s\r\n" + "Content-Type: application/octet-stream\r\n" + "\r\n" + "~w" + >>, + [HC, HS, BS, Version, BodySize, Header, Body] + ), + iolist_to_binary(Rq). + +request_bin_rc(HC, HS, BS, RC, HeaderSize, BodySize, Version) -> + <> = + base64:encode(crypto:strong_rand_bytes(HeaderSize)), + Body = crypto:strong_rand_bytes(BodySize), + Rq = + io_lib:format( + << + "PUT /with_limits/~w/~w/~w?response_code=~w HTTP/~s\r\n" + "Connection: keep-alive\r\n" + "Content-Length: ~w\r\n" + "X-Riak-BigHeader: ~s\r\n" + "Content-Type: application/octet-stream\r\n" + "\r\n" + "~w" + >>, + [HC, HS, BS, RC, Version, BodySize, Header, Body] + ), + iolist_to_binary(Rq). + +too_many_headers(IPAddr, Port) -> + fun() -> + {ok, Socket} = + gen_tcp:connect( + IPAddr, + Port, + [binary, {packet, raw}, {active, false}] + ), + Request = request_bin(1, 1024, 1024, 32, 64), + ok = gen_tcp:send(Socket, Request), + {ok, Data} = gen_tcp:recv(Socket, 0), + {ok, L1, R1} = erlang:decode_packet(line, Data, []), + ?assertMatch( + <<"HTTP/1.1 431 Request Header Fields Too Large\r\n">>, + L1 + ), + ?assertNotMatch( + nomatch, + string:find(R1, <<"Headers exceeded maximum count of 1">>) + ), + ok = gen_tcp:close(Socket) + end. + +header_too_large(IPAddr, Port) -> + fun() -> + {ok, Socket} = + gen_tcp:connect( + IPAddr, + Port, + [binary, {packet, raw}, {active, false}] + ), + Request = request_bin(16, 64, 1024, 256, 64), + ok = gen_tcp:send(Socket, Request), + {ok, Data} = gen_tcp:recv(Socket, 0), + {ok, L1, R1} = erlang:decode_packet(line, Data, []), + ?assertMatch( + <<"HTTP/1.1 431 Request Header Fields Too Large\r\n">>, + L1 + ), + ?assertNotMatch( + nomatch, + string:find( + R1, + <<"Header X-Riak-BigHeader exceeded maximum size of 64">> + ) + ), + ok = gen_tcp:close(Socket) + end. + +non_zero_body(IPAddr, Port) -> + fun() -> + {ok, Socket} = + gen_tcp:connect( + IPAddr, + Port, + [binary, {packet, raw}, {active, false}] + ), + Request = request_bin(16, 2048, 0, 32, 64), + ok = gen_tcp:send(Socket, Request), + {ok, Data} = gen_tcp:recv(Socket, 0), + {ok, L1, _R1} = erlang:decode_packet(line, Data, []), + ?assertMatch( + <<"HTTP/1.1 413 Content Too Large\r\n">>, + L1 + ), + ok = gen_tcp:close(Socket) + end. + +zero_body(IPAddr, Port) -> + fun() -> + {ok, Socket} = + gen_tcp:connect( + IPAddr, + Port, + [binary, {packet, raw}, {active, false}] + ), + Request = request_bin(16, 2048, 0, 32, 0), + ok = gen_tcp:send(Socket, Request), + {ok, Data} = gen_tcp:recv(Socket, 0), + {ok, L1, _R1} = erlang:decode_packet(line, Data, []), + ?assertMatch( + <<"HTTP/1.1 200 OK\r\n">>, + L1 + ), + ok = gen_tcp:close(Socket) + end. + +mishandle_nonzero_body(IPAddr, Port) -> + fun() -> + {ok, Socket} = + gen_tcp:connect( + IPAddr, + Port, + [binary, {packet, raw}, {active, false}] + ), + QP = <<"mishandle_nonzero_body">>, + Request = request_bin(16, 2048, 1024, QP, 32, 64), + ok = gen_tcp:send(Socket, Request), + {ok, Data} = gen_tcp:recv(Socket, 0), + {ok, L1, _R1} = erlang:decode_packet(line, Data, []), + ?assertMatch( + <<"HTTP/1.1 500 Internal Server Error\r\n">>, + L1 + ), + ok = gen_tcp:close(Socket) + end. + +handle_bad_uri(IPAddr, Port) -> + fun() -> + {ok, Socket} = + gen_tcp:connect( + IPAddr, + Port, + [binary, {packet, raw}, {active, false}] + ), + Request = request_bin(<<"badly_encoded_con%0tent">>, 2048, 0, 32, 64), + ok = gen_tcp:send(Socket, Request), + {ok, Data} = gen_tcp:recv(Socket, 0), + {ok, L1, _R1} = erlang:decode_packet(line, Data, []), + ?assertMatch( + <<"HTTP/1.1 400 Bad Request\r\n">>, + L1 + ), + ok = gen_tcp:close(Socket) + end. + +handle_bad_content_length(IPAddr, Port) -> + fun() -> + {ok, Socket} = + gen_tcp:connect( + IPAddr, + Port, + [binary, {packet, raw}, {active, false}] + ), + Request = request_bin(8, 512, 96, 32, "A"), + ok = gen_tcp:send(Socket, Request), + {ok, Data} = gen_tcp:recv(Socket, 0), + {ok, L1, _R1} = erlang:decode_packet(line, Data, []), + ?assertMatch( + <<"HTTP/1.1 400 Bad Request\r\n">>, + L1 + ), + ok = gen_tcp:close(Socket) + end. + +handle_connection_header_confusion(IPAddr, Port) -> + fun() -> + {ok, Socket10} = + gen_tcp:connect( + IPAddr, + Port, + [binary, {packet, raw}, {active, false}] + ), + Request10 = request_bin_cc(8, 512, 96, 32, 32, <<"1.0">>), + ok = gen_tcp:send(Socket10, Request10), + {ok, Data10} = gen_tcp:recv(Socket10, 0), + {ok, L1, R1} = erlang:decode_packet(line, Data10, []), + ?assertMatch( + <<"HTTP/1.0 200 OK\r\n">>, + L1 + ), + ?assertMatch( + nomatch, + string:find(R1, <<"Connection: keep-alive">>) + ), + ?assertNotMatch( + nomatch, + string:find(R1, <<"Connection: close">>) + ), + ok = gen_tcp:close(Socket10), + {ok, Socket11} = + gen_tcp:connect( + IPAddr, + Port, + [binary, {packet, raw}, {active, false}] + ), + Request11 = request_bin_cc(8, 512, 96, 32, 32, <<"1.1">>), + ok = gen_tcp:send(Socket11, Request11), + {ok, Data11} = gen_tcp:recv(Socket11, 0), + {ok, L2, R2} = erlang:decode_packet(line, Data11, []), + ?assertMatch( + <<"HTTP/1.1 200 OK\r\n">>, + L2 + ), + ?assertNotMatch( + nomatch, + string:find(R2, <<"Connection: keep-alive">>) + ), + ?assertMatch( + nomatch, + string:find(R2, <<"Connection: close">>) + ), + ok = gen_tcp:close(Socket10) + end. + +trigger_alternative_response_code(IPAddr, Port, Version, RC, RM) -> + fun() -> + {ok, Socket} = + gen_tcp:connect( + IPAddr, + Port, + [binary, {packet, raw}, {active, false}] + ), + Request = request_bin_rc(16, 512, 256, RC, 32, 64, Version), + ok = gen_tcp:send(Socket, Request), + {ok, Data} = gen_tcp:recv(Socket, 0), + {ok, L1, _R1} = erlang:decode_packet(line, Data, []), + ?assertMatch( + RM, + L1 + ), + ok = gen_tcp:close(Socket) + end. + +cleanup({SpecName, _IPAddr, _Port}) -> + ok = inets:stop(), + ?assertMatch(4, riak_api_web_socket:get_active_pool_size(SpecName)), + riak_api_web_socket:stop(SpecName), + ok. + +find_available_port([]) -> + no_port_found; +find_available_port([Port | Rest]) -> + case gen_tcp:listen(Port, []) of + {ok, Sock} -> + ok = gen_tcp:close(Sock), + Port; + _ -> + find_available_port(Rest) + end. + +-endif.