Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/firebase_functions/alerts_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import firebase_functions.private.util as _util

from firebase_functions.core import T, CloudEvent as _CloudEvent
from firebase_functions.core import T, CloudEvent as _CloudEvent, _with_init
from firebase_functions.options import FirebaseAlertOptions

# Explicitly import AlertType to make it available in the public API.
Expand Down Expand Up @@ -95,7 +95,7 @@ def on_alert_published_inner_decorator(func: OnAlertPublishedCallable):
@_functools.wraps(func)
def on_alert_published_wrapped(raw: _ce.CloudEvent):
from firebase_functions.private._alerts_fn import alerts_event_from_ce
func(alerts_event_from_ce(raw))
_with_init(func)(alerts_event_from_ce(raw))

_util.set_func_endpoint_attr(
on_alert_published_wrapped,
Expand Down
47 changes: 47 additions & 0 deletions src/firebase_functions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import datetime as _datetime
import typing as _typing

from . import logger as _logger

T = _typing.TypeVar("T")


Expand Down Expand Up @@ -80,3 +82,48 @@ class Change(_typing.Generic[T]):
"""
The state of data after the change.
"""


_did_init = False
_init_callback: _typing.Callable[[], _typing.Any] | None = None


def init(callback: _typing.Callable[[], _typing.Any]) -> None:
"""
Registers a function that should be run when in a production environment
before executing any functions code.
Calling this decorator more than once leads to undefined behavior.
"""

global _did_init
global _init_callback

_init_callback = callback

if _did_init:
_logger.warn(
"Setting init callback more than once. Only the most recent callback will be called"
)

_init_callback = callback
Comment thread
exaby73 marked this conversation as resolved.
_did_init = False


def _with_init(
fn: _typing.Callable[...,
_typing.Any]) -> _typing.Callable[..., _typing.Any]:
"""
A decorator that runs the init callback before running the decorated function.
"""

def wrapper(*args, **kwargs):
global _did_init

if not _did_init:
if _init_callback is not None:
_init_callback()
_did_init = True

return fn(*args, **kwargs)

return wrapper
2 changes: 1 addition & 1 deletion src/firebase_functions/db_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _db_endpoint_handler(
subject=event_attributes["subject"],
params=params,
)
func(database_event)
_core._with_init(func)(database_event)


@_util.copy_func_kwargs(DatabaseOptions)
Expand Down
4 changes: 2 additions & 2 deletions src/firebase_functions/eventarc_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import firebase_functions.options as _options
import firebase_functions.private.util as _util
from firebase_functions.core import CloudEvent
from firebase_functions.core import CloudEvent, _with_init


@_util.copy_func_kwargs(_options.EventarcTriggerOptions)
Expand Down Expand Up @@ -73,7 +73,7 @@ def on_custom_event_published_wrapped(raw: _ce.CloudEvent):
),
type=event_dict["type"],
)
func(event)
_with_init(func)(event)

_util.set_func_endpoint_attr(
on_custom_event_published_wrapped,
Expand Down
2 changes: 2 additions & 0 deletions src/firebase_functions/firestore_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ def _firestore_endpoint_handler(
params=params,
)

func = _core._with_init(func)

if event_type.endswith(".withAuthContext"):
database_event_with_auth_context = AuthEvent(**vars(database_event),
auth_type=event_auth_type,
Expand Down
4 changes: 2 additions & 2 deletions src/firebase_functions/https_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def _on_call_handler(func: _C2,
instance_id_token=request.headers.get(
"Firebase-Instance-ID-Token"),
)
result = func(context)
result = _core._with_init(func)(context)
return _jsonify(result=result)
# Disable broad exceptions lint since we want to handle all exceptions here
# and wrap as an HttpsError.
Expand Down Expand Up @@ -447,7 +447,7 @@ def on_request_wrapped(request: Request) -> Response:
methods=options.cors.cors_methods,
origins=options.cors.cors_origins,
)(func)(request)
return func(request)
return _core._with_init(func)(request)

_util.set_func_endpoint_attr(
on_request_wrapped,
Expand Down
9 changes: 5 additions & 4 deletions src/firebase_functions/private/_identity_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cloud functions to handle Eventarc events."""

# pylint: disable=protected-access
import typing as _typing
import datetime as _dt
import time as _time
import json as _json

from firebase_functions.core import _with_init
from firebase_functions.https_fn import HttpsError, FunctionsErrorCode

import firebase_functions.private.util as _util
Expand Down Expand Up @@ -351,8 +352,8 @@ def before_operation_handler(
jwt_token = request.json["data"]["jwt"]
decoded_token = _token_verifier.verify_auth_blocking_token(jwt_token)
event = _auth_blocking_event_from_token_data(decoded_token)
auth_response: BeforeCreateResponse | BeforeSignInResponse | None = func(
event)
auth_response: BeforeCreateResponse | BeforeSignInResponse | None = _with_init(
func)(event)
if not auth_response:
return _jsonify({})
auth_response_dict = _validate_auth_response(event_type, auth_response)
Expand All @@ -362,7 +363,7 @@ def before_operation_handler(
# pylint: disable=broad-except
except Exception as exception:
if not isinstance(exception, HttpsError):
_logging.error("Unhandled error", exception)
_logging.error("Unhandled error %s", exception)
exception = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL")
status = exception._http_error_code.status
return _make_response(_jsonify(error=exception._as_dict()), status)
4 changes: 2 additions & 2 deletions src/firebase_functions/pubsub_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import firebase_functions.private.util as _util

from firebase_functions.core import CloudEvent, T
from firebase_functions.core import CloudEvent, T, _with_init
from firebase_functions.options import PubSubOptions


Expand Down Expand Up @@ -151,7 +151,7 @@ def _message_handler(
type=event_dict["type"],
)

func(event)
_with_init(func)(event)


@_util.copy_func_kwargs(PubSubOptions)
Expand Down
4 changes: 2 additions & 2 deletions src/firebase_functions/remote_config_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import firebase_functions.private.util as _util

from firebase_functions.core import CloudEvent
from firebase_functions.core import CloudEvent, _with_init
from firebase_functions.options import EventHandlerOptions


Expand Down Expand Up @@ -189,7 +189,7 @@ def _config_handler(func: _C1, raw: _ce.CloudEvent) -> None:
type=event_dict["type"],
)

func(event)
_with_init(func)(event)


@_util.copy_func_kwargs(EventHandlerOptions)
Expand Down
3 changes: 2 additions & 1 deletion src/firebase_functions/scheduler_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
make_response as _make_response,
)

from firebase_functions.core import _with_init
# Export for user convenience.
# pylint: disable=unused-import
from firebase_functions.options import Timezone
Expand Down Expand Up @@ -108,7 +109,7 @@ def on_schedule_wrapped(request: _Request) -> _Response:
schedule_time=schedule_time,
)
try:
func(event)
_with_init(func)(event)
return _make_response()
# Disable broad exceptions lint since we want to handle all exceptions.
# pylint: disable=broad-except
Expand Down
4 changes: 2 additions & 2 deletions src/firebase_functions/storage_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import cloudevents.http as _ce

import firebase_functions.private.util as _util
from firebase_functions.core import CloudEvent
from firebase_functions.core import CloudEvent, _with_init
from firebase_functions.options import StorageOptions

_event_type_archived = "google.cloud.storage.object.v1.archived"
Expand Down Expand Up @@ -255,7 +255,7 @@ def _message_handler(
type=event_attributes["type"],
)

func(event)
_with_init(func)(event)


@_util.copy_func_kwargs(StorageOptions)
Expand Down
4 changes: 2 additions & 2 deletions src/firebase_functions/test_lab_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import firebase_functions.private.util as _util

from firebase_functions.core import CloudEvent
from firebase_functions.core import CloudEvent, _with_init
from firebase_functions.options import EventHandlerOptions


Expand Down Expand Up @@ -246,7 +246,7 @@ def _event_handler(func: _C1, raw: _ce.CloudEvent) -> None:
type=event_dict["type"],
)

func(event)
_with_init(func)(event)


@_util.copy_func_kwargs(EventHandlerOptions)
Expand Down
43 changes: 43 additions & 0 deletions tests/test_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Tests for the db module.
"""

import unittest
from unittest import mock
from cloudevents.http import CloudEvent
from firebase_functions import core, db_fn


class TestDb(unittest.TestCase):
"""
Tests for the db module.
"""

def test_calls_init_function(self):
hello = None

@core.init
def init():
nonlocal hello
hello = "world"

func = mock.Mock(__name__="example_func")
decorated_func = db_fn.on_value_created(reference="path")(func)

event = CloudEvent(attributes={
"specversion": "1.0",
"id": "id",
"source": "source",
"subject": "subject",
"type": "type",
"time": "2024-04-10T12:00:00.000Z",
"instance": "instance",
"ref": "ref",
"firebasedatabasehost": "firebasedatabasehost",
"location": "location",
},
data={"delta": "delta"})

decorated_func(event)

self.assertEqual(hello, "world")
34 changes: 34 additions & 0 deletions tests/test_eventarc_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
"""Eventarc trigger function tests."""
import unittest
from unittest.mock import Mock

from cloudevents.http import CloudEvent as _CloudEvent

from firebase_functions import core
from firebase_functions.core import CloudEvent
from firebase_functions.eventarc_fn import on_custom_event_published

Expand Down Expand Up @@ -83,3 +86,34 @@ def test_on_custom_event_published_wrapped(self):
event_arg.type,
"firebase.extensions.storage-resize-images.v1.complete",
)

def test_calls_init_function(self):
hello = None

@core.init
def init():
nonlocal hello
hello = "world"

func = Mock(__name__="example_func")
raw_event = _CloudEvent(
attributes={
"specversion": "1.0",
"type": "firebase.extensions.storage-resize-images.v1.complete",
"source": "https://example.com/testevent",
"id": "1234567890",
"subject": "test_subject",
"time": "2023-03-11T13:25:37.403Z",
},
data={
"some_key": "some_value",
},
)

decorated_func = on_custom_event_published(
event_type="firebase.extensions.storage-resize-images.v1.complete",
)(func)

decorated_func(raw_event)

self.assertEqual(hello, "world")
51 changes: 51 additions & 0 deletions tests/test_firestore_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,54 @@ def test_firestore_endpoint_handler_calls_function_with_correct_args(self):
self.assertIsInstance(event, AuthEvent)
self.assertEqual(event.auth_type, "unauthenticated")
self.assertEqual(event.auth_id, "foo")

def test_calls_init_function(self):
with patch.dict("sys.modules", mocked_modules):
from firebase_functions import firestore_fn, core
from cloudevents.http import CloudEvent

func = Mock(__name__="example_func")

hello = None

@core.init
def init():
nonlocal hello
hello = "world"

attributes = {
"specversion":
"1.0",
# pylint: disable=protected-access
"type":
firestore_fn._event_type_created,
"source":
"https://example.com/testevent",
"time":
"2023-03-11T13:25:37.403Z",
"subject":
"test_subject",
"datacontenttype":
"application/json",
"location":
"projects/project-id/databases/(default)/documents/foo/{bar}",
"project":
"project-id",
"namespace":
"(default)",
"document":
"foo/{bar}",
"database":
"projects/project-id/databases/(default)",
"authtype":
"unauthenticated",
"authid":
"foo"
}
raw_event = CloudEvent(attributes=attributes, data=json.dumps({}))
decorated_func = firestore_fn.on_document_created(
document="/foo/{bar}")(func)

decorated_func(raw_event)

self.assertEqual(hello, "world")
Loading