diff --git a/tools/server/server.cpp b/tools/server/server.cpp index f32f3c86aad2c..129d013ac75f7 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2251,6 +2251,14 @@ struct server_context { slot.has_next_token = true; } + // if context shifting is disabled, make sure that we don't run out of context + if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped due to running out of context, n_past = %d, n_ctx = %d\n", slot.n_past, slot.n_ctx); + } + // check the limits if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { slot.stop = STOP_TYPE_LIMIT; diff --git a/tools/server/tests/unit/test_ctx_shift.py b/tools/server/tests/unit/test_ctx_shift.py index be93a6d31f410..2431ac70882d7 100644 --- a/tools/server/tests/unit/test_ctx_shift.py +++ b/tools/server/tests/unit/test_ctx_shift.py @@ -65,3 +65,21 @@ def test_ctx_shift_disabled_long_prompt(): assert res.status_code != 200 assert "error" in res.body assert "exceeds the available context size" in res.body["error"]["message"] + +def test_ctx_shift_disabled_stream(): + global server + server.disable_ctx_shift = True + server.start() + res = server.make_stream_request("POST", "/v1/completions", data={ + "n_predict": 256, + "prompt": "Once", + "stream": True, + }) + content = "" + for data in res: + choice = data["choices"][0] + if choice["finish_reason"] == "length": + assert len(content) > 0 + else: + assert choice["finish_reason"] is None + content += choice["text"]