diff --git a/tests/test_trip.py b/tests/test_trip.py index c0fb26d..798b71b 100644 --- a/tests/test_trip.py +++ b/tests/test_trip.py @@ -213,6 +213,41 @@ def test_create_trip_add_note(self): self.assertEqual(data['trip']['notes'], [note_id]) + def test_create_trip_add_existing_note(self): + response = self.client.post('/trips', data=json.dumps(self.trip), headers=self.headers_json) + self.assertEqual(response.status_code, 201) + data = json.loads(response.data.decode('utf-8')) + trip_id = data['id'] + + # Create note + response = self.client.post('/notes', data=json.dumps(self.note), headers=self.headers_json) + self.assertEqual(response.status_code, 201) + data = json.loads(response.data.decode('utf-8')) + note_id = data['id'] + + # Add note to trip + response = self.client.patch( + f'/trips/{trip_id}/notes', data=json.dumps({'note_id': note_id}), headers=self.headers_json + ) + self.assertEqual(response.status_code, 201) + + response = self.client.get(f'/trips/{trip_id}', headers=self.headers) + self.assertEqual(response.status_code, 200) + data = json.loads(response.data.decode('utf-8')) + + self.assertEqual(len(data['trip']['notes']), 1) + self.assertEqual(data['trip']['notes'], [note_id]) + + # Add note again to trip + response = self.client.patch( + f'/trips/{trip_id}/notes', data=json.dumps({'note_id': note_id}), headers=self.headers_json + ) + self.assertEqual(response.status_code, 200) + data = json.loads(response.data.decode('utf-8')) + + self.assertEqual(len(data['trip']['notes']), 1) + self.assertEqual(data['trip']['notes'], [note_id]) + def test_create_trip_add_route(self): response = self.client.post('/trips', data=json.dumps(self.trip), headers=self.headers_json) self.assertEqual(response.status_code, 201) @@ -232,10 +267,39 @@ def test_create_trip_add_route(self): self.assertEqual(response.status_code, 201) data = json.loads(response.data.decode('utf-8')) - response = self.client.get(f'/trips/{trip_id}', headers=self.headers) + self.assertEqual(len(data['trip']['routes']), 1) + self.assertEqual(data['trip']['routes'], [route_id]) + + def test_create_trip_add_existing_route(self): + response = self.client.post('/trips', data=json.dumps(self.trip), headers=self.headers_json) + self.assertEqual(response.status_code, 201) + data = json.loads(response.data.decode('utf-8')) + trip_id = data['id'] + + # Create route + response = self.client.post('/routes', data=json.dumps(self.route), headers=self.headers_json) + self.assertEqual(response.status_code, 201) + data = json.loads(response.data.decode('utf-8')) + route_id = data['id'] + + # Add route to trip + response = self.client.patch( + f'/trips/{trip_id}/routes', data=json.dumps({'route_id': route_id}), headers=self.headers_json + ) + self.assertEqual(response.status_code, 201) + data = json.loads(response.data.decode('utf-8')) + + self.assertEqual(len(data['trip']['routes']), 1) + self.assertEqual(data['trip']['routes'], [route_id]) + + # Add existing route to trip + response = self.client.patch( + f'/trips/{trip_id}/routes', data=json.dumps({'route_id': route_id}), headers=self.headers_json + ) self.assertEqual(response.status_code, 200) data = json.loads(response.data.decode('utf-8')) + self.assertEqual(len(data['trip']['routes']), 1) self.assertEqual(data['trip']['routes'], [route_id]) def test_create_trip_add_item_list(self): @@ -260,10 +324,19 @@ def test_create_trip_add_item_list(self): self.assertEqual(response.status_code, 201) data = json.loads(response.data.decode('utf-8')) - response = self.client.get(f'/trips/{trip_id}', headers=self.headers) + self.assertEqual(len(data['trip']['item_lists']), 1) + self.assertEqual(data['trip']['item_lists'], [item_list_id]) + + # Add existing item_list to trip + response = self.client.patch( + f'/trips/{trip_id}/item_lists', + data=json.dumps({'item_list_id': item_list_id}), + headers=self.headers_json, + ) self.assertEqual(response.status_code, 200) data = json.loads(response.data.decode('utf-8')) + self.assertEqual(len(data['trip']['item_lists']), 1) self.assertEqual(data['trip']['item_lists'], [item_list_id]) def test_change_trip_owner(self): diff --git a/turplanlegger/database/base.py b/turplanlegger/database/base.py index ffa8924..9489a4d 100644 --- a/turplanlegger/database/base.py +++ b/turplanlegger/database/base.py @@ -539,25 +539,22 @@ def add_trip_note_reference(self, trip_id, note_id): insert_ref = """ INSERT INTO trips_notes_references (trip_id, note_id) VALUES (%(trip_id)s, %(note_id)s) - RETURNING * """ - return self._insert(insert_ref, {'trip_id': trip_id, 'note_id': note_id}) + return self._insert(insert_ref, {'trip_id': trip_id, 'note_id': note_id}, False) def add_trip_item_list_reference(self, trip_id, item_list_id): insert_ref = """ INSERT INTO trips_item_lists_references (trip_id, item_list_id) VALUES (%(trip_id)s, %(item_list_id)s) - RETURNING * """ - return self._insert(insert_ref, {'trip_id': trip_id, 'item_list_id': item_list_id}) + return self._insert(insert_ref, {'trip_id': trip_id, 'item_list_id': item_list_id}, False) def add_trip_route_reference(self, trip_id, route_id): insert_ref = """ INSERT INTO trips_routes_references (trip_id, route_id) VALUES (%(trip_id)s, %(route_id)s) - RETURNING * """ - return self._insert(insert_ref, {'trip_id': trip_id, 'route_id': route_id}) + return self._insert(insert_ref, {'trip_id': trip_id, 'route_id': route_id}, False) def get_trip(self, trip_id: int, deleted=False): select = 'SELECT * FROM trips WHERE id = %s' @@ -682,14 +679,14 @@ def delete_trip_date(self, trip_date_id): return self._updateone(update, {'id': trip_date_id}, returning=True) # Helpers - def _insert(self, query, vars): + def _insert(self, query, vars, returning=True): """ Insert, with return. """ self._log('_insert', query, vars) with self.conn.transaction(): self.cur.execute(query, vars) - return self.cur.fetchone() + return self.cur.fetchone() if returning else None def _fetchone(self, query, vars): """ diff --git a/turplanlegger/models/trip.py b/turplanlegger/models/trip.py index 74d0e4d..fa85cdd 100644 --- a/turplanlegger/models/trip.py +++ b/turplanlegger/models/trip.py @@ -147,7 +147,7 @@ def delete(self) -> bool: def update(self, updated_fields) -> None: return db.update_trip(self, updated_fields) - def add_note_reference(self, note_id: int) -> 'Trip': + def add_note_reference(self, note_id: int) -> None: """Adds a note to the trip instance Args: @@ -157,9 +157,9 @@ def add_note_reference(self, note_id: int) -> 'Trip': dict of notes from the database """ db.add_trip_note_reference(self.id, note_id) - self.notes = db.get_trip_notes(self.id) + self.notes = [item.note_id for item in db.get_trip_notes(self.id)] - def add_route_reference(self, route_id: int) -> 'Trip': + def add_route_reference(self, route_id: int) -> None: """Adds a route to the trip instance Args: @@ -169,11 +169,11 @@ def add_route_reference(self, route_id: int) -> 'Trip': dict of routes from the database """ db.add_trip_route_reference(self.id, route_id) - self.routes = db.get_trip_routes(self.id) + self.routes = [item.route_id for item in db.get_trip_routes(self.id)] def add_item_list_reference(self, item_list_id: int) -> 'Trip': db.add_trip_item_list_reference(self.id, item_list_id) - self.routes = db.get_trip_item_lists(self.id) + self.item_lists = [item.item_list_id for item in db.get_trip_item_lists(self.id)] @staticmethod def update_trip_dates(dates: JSON, trip: 'Trip') -> 'TRIP_DATE_UPDATE_STATUS': diff --git a/turplanlegger/views/trips.py b/turplanlegger/views/trips.py index e7e9b0a..2f5c96b 100644 --- a/turplanlegger/views/trips.py +++ b/turplanlegger/views/trips.py @@ -120,21 +120,20 @@ def add_note_to_trip(trip_id: int): if not note: raise ApiProblem('Failed to add note to trip', 'Note was not found', 404) - if Permission.verify(note.owner, note.permissions, g.user.id, AccessLevel.READ) is PermissionResult.NOT_FOUND: - raise ApiProblem('Failed to add note to trip', 'Note was not found', 404) + if note.id in trip.notes: + return jsonify(status='ok', count=1, trip=trip.serialize), 200 - # I'm lazy, so I'm keeping this here until I fix private attribute for Note - # if note.private is True: - # note_perms = Permission.verify(note.owner, note.permissions, g.user.id, AccessLevel.READ) - # if note_perms is PermissionResult.NOT_FOUND: - # raise ApiProblem('Failed to add note to trip', 'Note was not found', 404) + if note.private is True: + note_perms = Permission.verify(note.owner, note.permissions, g.user.id, AccessLevel.READ) + if note_perms is PermissionResult.NOT_FOUND: + raise ApiProblem('Failed to add note to trip', 'Note was not found', 404) try: trip.add_note_reference(note.id) except Exception as e: raise ApiProblem('Failed to add note to trip', str(e), 500) - return jsonify(trip.serialize), 201 + return jsonify(status='ok', count=1, trip=trip.serialize), 201 @api.route('/trips//routes', methods=['PATCH']) @@ -159,12 +158,15 @@ def add_route_to_trip(trip_id: int): if Permission.verify(route.owner, route.permissions, g.user.id, AccessLevel.READ) is PermissionResult.NOT_FOUND: raise ApiProblem('Failed to add route to trip', 'Route was not found', 404) + if route.id in trip.routes: + return jsonify(status='ok', count=1, trip=trip.serialize), 200 + try: trip.add_route_reference(route.id) except Exception as e: raise ApiProblem('Failed to add route to trip', str(e), 500) - return jsonify(trip.serialize), 201 + return jsonify(status='ok', count=1, trip=trip.serialize), 201 @api.route('/trips//item_lists', methods=['PATCH']) @@ -174,21 +176,32 @@ def add_item_list_to_trip(trip_id: int): if not trip: raise ApiProblem('Failed to add item list to trip', 'Trip was not found', 404) + perms = Permission.verify(trip.owner, trip.permissions, g.user.id, AccessLevel.MODIFY) + if perms is PermissionResult.NOT_FOUND: + if trip.private is True: + raise ApiProblem('Trip not found', 'The requested trip was not found', 404) + raise ApiProblem('Insufficient permissions', 'Not sufficient permissions to modify the trip', 403) + if perms is PermissionResult.INSUFFICIENT_PERMISSIONS: + raise ApiProblem('Insufficient permissions', 'Not sufficient permissions to modify the trip', 403) + item_list = ItemList.find_item_list(request.json.get('item_list_id', None)) if not item_list: raise ApiProblem('Failed to add item list to trip', 'Item list was not found', 404) - if ( - Permission.verify(item_list.owner, item_list.permissions, g.user.id, AccessLevel.READ) - is PermissionResult.NOT_FOUND - ): - raise ApiProblem('Failed to add item list to trip', 'Item list was not found', 404) + if item_list.private is True: + note_perms = Permission.verify(item_list.owner, item_list.permissions, g.user.id, AccessLevel.READ) + if note_perms is PermissionResult.NOT_FOUND: + raise ApiProblem('Failed to add item list to trip', 'Item list was not found', 404) + + if item_list.id in trip.item_lists: + return jsonify(status='ok', count=1, trip=trip.serialize), 200 + try: trip.add_item_list_reference(item_list.id) except Exception as e: raise ApiProblem('Failed to add item list to trip', str(e), 500) - return jsonify(trip.serialize), 201 + return jsonify(status='ok', count=1, trip=trip.serialize), 201 @api.route('/trips//owner', methods=['PATCH'])