diff --git a/flask_rest_jsonapi/schema.py b/flask_rest_jsonapi/schema.py index f7bff523..a4a8dd49 100644 --- a/flask_rest_jsonapi/schema.py +++ b/flask_rest_jsonapi/schema.py @@ -23,14 +23,35 @@ def compute_schema(schema_cls, default_kwargs, qs, include): schema_kwargs = default_kwargs schema_kwargs['include_data'] = tuple() + + # manage sparse fieldsets + only_arg = None + if schema_kwargs.get('only') is not None: + only_arg = set(schema_kwargs['only']) + if schema_cls.opts.type_ in qs.fields: + # Validation handled by QSManager class, safe to assume any fields we see here exist + sparse_fields = set(qs.fields[schema_cls.opts.type_]) + if only_arg is not None: + only_arg &= sparse_fields + else: + only_arg = sparse_fields + if only_arg is not None: + # make sure id field is in only parameter unless marshamllow will raise an Exception + only_arg.add('id') + schema_kwargs['only'] = only_arg + # collect sub-related_includes related_includes = {} - if include: + available_fields = ( + schema_kwargs['only'] + if 'only' in schema_kwargs + else schema_cls._declared_fields + ) for include_path in include: field = include_path.split('.')[0] - if field not in schema_cls._declared_fields: + if field not in available_fields: raise InvalidInclude("{} has no attribute {}".format(schema_cls.__name__, field)) elif not isinstance(schema_cls._declared_fields[field], Relationship): raise InvalidInclude("{} is not a relationship attribute of {}".format(field, schema_cls.__name__)) @@ -41,24 +62,9 @@ def compute_schema(schema_cls, default_kwargs, qs, include): if '.' in include_path: related_includes[field] += ['.'.join(include_path.split('.')[1:])] - # make sure id field is in only parameter unless marshamllow will raise an Exception - if schema_kwargs.get('only') is not None and 'id' not in schema_kwargs['only']: - schema_kwargs['only'] += ('id',) - # create base schema instance schema = schema_cls(**schema_kwargs) - # manage sparse fieldsets - if schema.opts.type_ in qs.fields: - tmp_only = set(schema.declared_fields.keys()) & set(qs.fields[schema.opts.type_]) - if schema.only: - tmp_only &= set(schema.only) - schema.only = tuple(tmp_only) - - # make sure again that id field is in only parameter unless marshamllow will raise an Exception - if schema.only is not None and 'id' not in schema.only: - schema.only += ('id',) - # manage compound documents if include: for include_path in include: diff --git a/tests/test_sqlalchemy_data_layer.py b/tests/test_sqlalchemy_data_layer.py index 3662ca8f..1ccf4729 100644 --- a/tests/test_sqlalchemy_data_layer.py +++ b/tests/test_sqlalchemy_data_layer.py @@ -623,6 +623,25 @@ def test_compute_schema_propagate_context(person_schema, computer_schema): schema = flask_rest_jsonapi.schema.compute_schema(person_schema, dict(context=dict(foo='bar')), qsm, ['computers']) assert schema.declared_fields['computers'].__dict__['_Relationship__schema'].__dict__['context'] == dict(foo='bar') +@pytest.mark.parametrize( + 'querystring,default_kwargs,expected_fields', + [ + ({}, {}, {'id', 'name', 'birth_date', 'computers', 'tags', 'single_tag'}), + ({}, {'only': ()}, {'id'}), + ({}, {'only': ('name',)}, {'id', 'name'}), + ({'fields[person]': 'name,tags'}, {}, {'id', 'name', 'tags'}), + ({'fields[person]': 'name,tags'}, {'only': ('name', 'birth_date')}, {'id', 'name'}), + ], +) +def test_compute_schema_sparse_fieldsets(person_schema, querystring, default_kwargs, expected_fields): + qs = QSManager(querystring, person_schema) + schema = flask_rest_jsonapi.schema.compute_schema(person_schema, default_kwargs, qs, []) + assert set(schema.dump_fields.keys()) == expected_fields + +def test_compute_schema_sparse_fieldset_related(person_schema, computer_schema): + qs = QSManager({'fields[computer]': 'owner'}, person_schema) + schema = flask_rest_jsonapi.schema.compute_schema(person_schema, {}, qs, ['computers']) + assert (schema.declared_fields['computers'].__dict__['_Relationship__schema'].dump_fields.keys()) == {'id', 'owner'} # test good cases def test_get_list(client, register_routes, person, person_2): @@ -630,7 +649,7 @@ def test_get_list(client, register_routes, person, person_2): querystring = urlencode({ 'page[number]': 1, 'page[size]': 1, - 'fields[person]': 'name,birth_date', + 'fields[person]': 'name,birth_date,computers', 'sort': '-name', 'include': 'computers.owner', 'filter': json.dumps( @@ -666,6 +685,14 @@ def test_get_list(client, register_routes, person, person_2): }) response = client.get('/persons' + '?' + querystring, content_type='application/vnd.api+json') assert response.status_code == 200, response.json['errors'] + assert response.json == { + 'data': [], + 'jsonapi': {'version': '1.0'}, + 'links': { + 'self': 'http://localhost/persons?page%5Bnumber%5D=1&page%5Bsize%5D=1&fields%5Bperson%5D=name%2Cbirth_date%2Ccomputers&sort=-name&include=computers.owner&filter=%5B%7B%22and%22%3A+%5B%7B%22name%22%3A+%22computers%22%2C+%22op%22%3A+%22any%22%2C+%22val%22%3A+%7B%22name%22%3A+%22serial%22%2C+%22op%22%3A+%22eq%22%2C+%22val%22%3A+%220000%22%7D%7D%2C+%7B%22or%22%3A+%5B%7B%22name%22%3A+%22name%22%2C+%22op%22%3A+%22like%22%2C+%22val%22%3A+%22%25test%25%22%7D%2C+%7B%22name%22%3A+%22name%22%2C+%22op%22%3A+%22like%22%2C+%22val%22%3A+%22%25test2%25%22%7D%5D%7D%5D%7D%5D' + }, + 'meta': {'count': 0}, + }, 'no items match filter' def test_get_list_with_simple_filter(client, register_routes, person, person_2): @@ -678,6 +705,21 @@ def test_get_list_with_simple_filter(client, register_routes, person, person_2): }) response = client.get('/persons' + '?' + querystring, content_type='application/vnd.api+json') assert response.status_code == 200, response.json['errors'] + assert response.json == { + 'data': [ + { + 'attributes': {'birth_date': None, 'name': 'test'}, + 'id': '1', + 'links': {'self': '/persons/1'}, + 'type': 'person', + } + ], + 'jsonapi': {'version': '1.0'}, + 'links': { + 'self': 'http://localhost/persons?page%5Bnumber%5D=1&page%5Bsize%5D=1&fields%5Bperson%5D=name%2Cbirth_date&sort=-name&filter%5Bname%5D=test' + }, + 'meta': {'count': 1}, + } def test_get_list_disable_pagination(client, register_routes):