Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 75 additions & 2 deletions tests/test_trip.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,41 @@ def test_create_trip_add_note(self):

self.assertEqual(data['trip']['notes'], [note_id])

def test_create_trip_add_existing_note(self):
response = self.client.post('/trips', data=json.dumps(self.trip), headers=self.headers_json)
self.assertEqual(response.status_code, 201)
data = json.loads(response.data.decode('utf-8'))
trip_id = data['id']

# Create note
response = self.client.post('/notes', data=json.dumps(self.note), headers=self.headers_json)
self.assertEqual(response.status_code, 201)
data = json.loads(response.data.decode('utf-8'))
note_id = data['id']

# Add note to trip
response = self.client.patch(
f'/trips/{trip_id}/notes', data=json.dumps({'note_id': note_id}), headers=self.headers_json
)
self.assertEqual(response.status_code, 201)

response = self.client.get(f'/trips/{trip_id}', headers=self.headers)
self.assertEqual(response.status_code, 200)
data = json.loads(response.data.decode('utf-8'))

self.assertEqual(len(data['trip']['notes']), 1)
self.assertEqual(data['trip']['notes'], [note_id])

# Add note again to trip
response = self.client.patch(
f'/trips/{trip_id}/notes', data=json.dumps({'note_id': note_id}), headers=self.headers_json
)
self.assertEqual(response.status_code, 200)
data = json.loads(response.data.decode('utf-8'))

self.assertEqual(len(data['trip']['notes']), 1)
self.assertEqual(data['trip']['notes'], [note_id])

def test_create_trip_add_route(self):
response = self.client.post('/trips', data=json.dumps(self.trip), headers=self.headers_json)
self.assertEqual(response.status_code, 201)
Expand All @@ -232,10 +267,39 @@ def test_create_trip_add_route(self):
self.assertEqual(response.status_code, 201)
data = json.loads(response.data.decode('utf-8'))

response = self.client.get(f'/trips/{trip_id}', headers=self.headers)
self.assertEqual(len(data['trip']['routes']), 1)
self.assertEqual(data['trip']['routes'], [route_id])

def test_create_trip_add_existing_route(self):
response = self.client.post('/trips', data=json.dumps(self.trip), headers=self.headers_json)
self.assertEqual(response.status_code, 201)
data = json.loads(response.data.decode('utf-8'))
trip_id = data['id']

# Create route
response = self.client.post('/routes', data=json.dumps(self.route), headers=self.headers_json)
self.assertEqual(response.status_code, 201)
data = json.loads(response.data.decode('utf-8'))
route_id = data['id']

# Add route to trip
response = self.client.patch(
f'/trips/{trip_id}/routes', data=json.dumps({'route_id': route_id}), headers=self.headers_json
)
self.assertEqual(response.status_code, 201)
data = json.loads(response.data.decode('utf-8'))

self.assertEqual(len(data['trip']['routes']), 1)
self.assertEqual(data['trip']['routes'], [route_id])

# Add existing route to trip
response = self.client.patch(
f'/trips/{trip_id}/routes', data=json.dumps({'route_id': route_id}), headers=self.headers_json
)
self.assertEqual(response.status_code, 200)
data = json.loads(response.data.decode('utf-8'))

self.assertEqual(len(data['trip']['routes']), 1)
self.assertEqual(data['trip']['routes'], [route_id])

def test_create_trip_add_item_list(self):
Expand All @@ -260,10 +324,19 @@ def test_create_trip_add_item_list(self):
self.assertEqual(response.status_code, 201)
data = json.loads(response.data.decode('utf-8'))

response = self.client.get(f'/trips/{trip_id}', headers=self.headers)
self.assertEqual(len(data['trip']['item_lists']), 1)
self.assertEqual(data['trip']['item_lists'], [item_list_id])

# Add existing item_list to trip
response = self.client.patch(
f'/trips/{trip_id}/item_lists',
data=json.dumps({'item_list_id': item_list_id}),
headers=self.headers_json,
)
self.assertEqual(response.status_code, 200)
data = json.loads(response.data.decode('utf-8'))

self.assertEqual(len(data['trip']['item_lists']), 1)
self.assertEqual(data['trip']['item_lists'], [item_list_id])

def test_change_trip_owner(self):
Expand Down
13 changes: 5 additions & 8 deletions turplanlegger/database/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,25 +539,22 @@ def add_trip_note_reference(self, trip_id, note_id):
insert_ref = """
INSERT INTO trips_notes_references (trip_id, note_id)
VALUES (%(trip_id)s, %(note_id)s)
RETURNING *
"""
return self._insert(insert_ref, {'trip_id': trip_id, 'note_id': note_id})
return self._insert(insert_ref, {'trip_id': trip_id, 'note_id': note_id}, False)

def add_trip_item_list_reference(self, trip_id, item_list_id):
insert_ref = """
INSERT INTO trips_item_lists_references (trip_id, item_list_id)
VALUES (%(trip_id)s, %(item_list_id)s)
RETURNING *
"""
return self._insert(insert_ref, {'trip_id': trip_id, 'item_list_id': item_list_id})
return self._insert(insert_ref, {'trip_id': trip_id, 'item_list_id': item_list_id}, False)

def add_trip_route_reference(self, trip_id, route_id):
insert_ref = """
INSERT INTO trips_routes_references (trip_id, route_id)
VALUES (%(trip_id)s, %(route_id)s)
RETURNING *
"""
return self._insert(insert_ref, {'trip_id': trip_id, 'route_id': route_id})
return self._insert(insert_ref, {'trip_id': trip_id, 'route_id': route_id}, False)

def get_trip(self, trip_id: int, deleted=False):
select = 'SELECT * FROM trips WHERE id = %s'
Expand Down Expand Up @@ -682,14 +679,14 @@ def delete_trip_date(self, trip_date_id):
return self._updateone(update, {'id': trip_date_id}, returning=True)

# Helpers
def _insert(self, query, vars):
def _insert(self, query, vars, returning=True):
"""
Insert, with return.
"""
self._log('_insert', query, vars)
with self.conn.transaction():
self.cur.execute(query, vars)
return self.cur.fetchone()
return self.cur.fetchone() if returning else None

def _fetchone(self, query, vars):
"""
Expand Down
10 changes: 5 additions & 5 deletions turplanlegger/models/trip.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def delete(self) -> bool:
def update(self, updated_fields) -> None:
return db.update_trip(self, updated_fields)

def add_note_reference(self, note_id: int) -> 'Trip':
def add_note_reference(self, note_id: int) -> None:
"""Adds a note to the trip instance

Args:
Expand All @@ -157,9 +157,9 @@ def add_note_reference(self, note_id: int) -> 'Trip':
dict of notes from the database
"""
db.add_trip_note_reference(self.id, note_id)
self.notes = db.get_trip_notes(self.id)
self.notes = [item.note_id for item in db.get_trip_notes(self.id)]

def add_route_reference(self, route_id: int) -> 'Trip':
def add_route_reference(self, route_id: int) -> None:
"""Adds a route to the trip instance

Args:
Expand All @@ -169,11 +169,11 @@ def add_route_reference(self, route_id: int) -> 'Trip':
dict of routes from the database
"""
db.add_trip_route_reference(self.id, route_id)
self.routes = db.get_trip_routes(self.id)
self.routes = [item.route_id for item in db.get_trip_routes(self.id)]

def add_item_list_reference(self, item_list_id: int) -> 'Trip':
db.add_trip_item_list_reference(self.id, item_list_id)
self.routes = db.get_trip_item_lists(self.id)
self.item_lists = [item.item_list_id for item in db.get_trip_item_lists(self.id)]

@staticmethod
def update_trip_dates(dates: JSON, trip: 'Trip') -> 'TRIP_DATE_UPDATE_STATUS':
Expand Down
43 changes: 28 additions & 15 deletions turplanlegger/views/trips.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,21 +120,20 @@ def add_note_to_trip(trip_id: int):
if not note:
raise ApiProblem('Failed to add note to trip', 'Note was not found', 404)

if Permission.verify(note.owner, note.permissions, g.user.id, AccessLevel.READ) is PermissionResult.NOT_FOUND:
raise ApiProblem('Failed to add note to trip', 'Note was not found', 404)
if note.id in trip.notes:
return jsonify(status='ok', count=1, trip=trip.serialize), 200

# I'm lazy, so I'm keeping this here until I fix private attribute for Note
# if note.private is True:
# note_perms = Permission.verify(note.owner, note.permissions, g.user.id, AccessLevel.READ)
# if note_perms is PermissionResult.NOT_FOUND:
# raise ApiProblem('Failed to add note to trip', 'Note was not found', 404)
if note.private is True:
note_perms = Permission.verify(note.owner, note.permissions, g.user.id, AccessLevel.READ)
if note_perms is PermissionResult.NOT_FOUND:
raise ApiProblem('Failed to add note to trip', 'Note was not found', 404)

try:
trip.add_note_reference(note.id)
except Exception as e:
raise ApiProblem('Failed to add note to trip', str(e), 500)

return jsonify(trip.serialize), 201
return jsonify(status='ok', count=1, trip=trip.serialize), 201


@api.route('/trips/<trip_id>/routes', methods=['PATCH'])
Expand All @@ -159,12 +158,15 @@ def add_route_to_trip(trip_id: int):
if Permission.verify(route.owner, route.permissions, g.user.id, AccessLevel.READ) is PermissionResult.NOT_FOUND:
raise ApiProblem('Failed to add route to trip', 'Route was not found', 404)

if route.id in trip.routes:
return jsonify(status='ok', count=1, trip=trip.serialize), 200

try:
trip.add_route_reference(route.id)
except Exception as e:
raise ApiProblem('Failed to add route to trip', str(e), 500)

return jsonify(trip.serialize), 201
return jsonify(status='ok', count=1, trip=trip.serialize), 201


@api.route('/trips/<trip_id>/item_lists', methods=['PATCH'])
Expand All @@ -174,21 +176,32 @@ def add_item_list_to_trip(trip_id: int):
if not trip:
raise ApiProblem('Failed to add item list to trip', 'Trip was not found', 404)

perms = Permission.verify(trip.owner, trip.permissions, g.user.id, AccessLevel.MODIFY)
if perms is PermissionResult.NOT_FOUND:
if trip.private is True:
raise ApiProblem('Trip not found', 'The requested trip was not found', 404)
raise ApiProblem('Insufficient permissions', 'Not sufficient permissions to modify the trip', 403)
if perms is PermissionResult.INSUFFICIENT_PERMISSIONS:
raise ApiProblem('Insufficient permissions', 'Not sufficient permissions to modify the trip', 403)

item_list = ItemList.find_item_list(request.json.get('item_list_id', None))
if not item_list:
raise ApiProblem('Failed to add item list to trip', 'Item list was not found', 404)

if (
Permission.verify(item_list.owner, item_list.permissions, g.user.id, AccessLevel.READ)
is PermissionResult.NOT_FOUND
):
raise ApiProblem('Failed to add item list to trip', 'Item list was not found', 404)
if item_list.private is True:
note_perms = Permission.verify(item_list.owner, item_list.permissions, g.user.id, AccessLevel.READ)
if note_perms is PermissionResult.NOT_FOUND:
raise ApiProblem('Failed to add item list to trip', 'Item list was not found', 404)

if item_list.id in trip.item_lists:
return jsonify(status='ok', count=1, trip=trip.serialize), 200

try:
trip.add_item_list_reference(item_list.id)
except Exception as e:
raise ApiProblem('Failed to add item list to trip', str(e), 500)

return jsonify(trip.serialize), 201
return jsonify(status='ok', count=1, trip=trip.serialize), 201


@api.route('/trips/<trip_id>/owner', methods=['PATCH'])
Expand Down