From c2f72acc3335dfaa11fb4b8df0d5cce538db965a Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Wed, 25 Dec 2024 17:02:24 +0800 Subject: [PATCH] api: return not found when region doesn't exist (#8869) close tikv/pd#8868 Signed-off-by: lhy1024 Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/mcs/scheduling/server/apis/v1/api.go | 4 ++++ server/api/region.go | 12 ++++++++++++ server/api/region_test.go | 8 +++++++- tests/integrations/mcs/scheduling/api_test.go | 3 +++ 4 files changed, 26 insertions(+), 1 deletion(-) diff --git a/pkg/mcs/scheduling/server/apis/v1/api.go b/pkg/mcs/scheduling/server/apis/v1/api.go index 3d2d0005a24..535fa79ee0c 100644 --- a/pkg/mcs/scheduling/server/apis/v1/api.go +++ b/pkg/mcs/scheduling/server/apis/v1/api.go @@ -1476,6 +1476,10 @@ func getRegionByID(c *gin.Context) { c.String(http.StatusBadRequest, err.Error()) return } + if regionID == 0 { + c.String(http.StatusBadRequest, errs.ErrRegionInvalidID.FastGenByArgs().Error()) + return + } regionInfo := svr.GetBasicCluster().GetRegion(regionID) if regionInfo == nil { c.String(http.StatusNotFound, errs.ErrRegionNotFound.FastGenByArgs(regionID).Error()) diff --git a/server/api/region.go b/server/api/region.go index afc32d2e762..a439cbfb349 100644 --- a/server/api/region.go +++ b/server/api/region.go @@ -67,8 +67,16 @@ func (h *regionHandler) GetRegionByID(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } + if regionID == 0 { + h.rd.JSON(w, http.StatusBadRequest, errs.ErrRegionInvalidID.FastGenByArgs()) + return + } regionInfo := rc.GetRegion(regionID) + if regionInfo == nil { + h.rd.JSON(w, http.StatusNotFound, errs.ErrRegionNotFound.FastGenByArgs(regionID).Error()) + return + } b, err := response.MarshalRegionInfoJSON(r.Context(), regionInfo) if err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) @@ -101,6 +109,10 @@ func (h *regionHandler) GetRegion(w http.ResponseWriter, r *http.Request) { } regionInfo := rc.GetRegionByKey(paramsByte[0]) + if regionInfo == nil { + h.rd.JSON(w, http.StatusNotFound, errs.ErrRegionNotFound.FastGenByArgs().Error()) + return + } b, err := response.MarshalRegionInfoJSON(r.Context(), regionInfo) if err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) diff --git a/server/api/region_test.go b/server/api/region_test.go index ae91b41ef5e..4e0929636e8 100644 --- a/server/api/region_test.go +++ b/server/api/region_test.go @@ -80,7 +80,11 @@ func (suite *regionTestSuite) TestRegion() { r.UpdateBuckets(buckets, r.GetBuckets()) re := suite.Require() mustRegionHeartbeat(re, suite.svr, r) - url := fmt.Sprintf("%s/region/id/%d", suite.urlPrefix, r.GetID()) + url := fmt.Sprintf("%s/region/id/%d", suite.urlPrefix, 0) + re.NoError(tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusBadRequest))) + url = fmt.Sprintf("%s/region/id/%d", suite.urlPrefix, 2333) + re.NoError(tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusNotFound))) + url = fmt.Sprintf("%s/region/id/%d", suite.urlPrefix, r.GetID()) r1 := &response.RegionInfo{} r1m := make(map[string]any) re.NoError(tu.ReadGetJSON(re, testDialClient, url, r1)) @@ -96,6 +100,8 @@ func (suite *regionTestSuite) TestRegion() { re.Equal(core.HexRegionKeyStr([]byte("a")), keys[0].(string)) re.Equal(core.HexRegionKeyStr([]byte("b")), keys[1].(string)) + url = fmt.Sprintf("%s/region/key/%s", suite.urlPrefix, "c") + re.NoError(tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusNotFound))) url = fmt.Sprintf("%s/region/key/%s", suite.urlPrefix, "a") r2 := &response.RegionInfo{} re.NoError(tu.ReadGetJSON(re, testDialClient, url, r2)) diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index 14b867a587d..abace06bb78 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -714,6 +714,9 @@ func (suite *apiTestSuite) checkRegions(cluster *tests.TestCluster) { err = testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal(3., resp["count"]) + urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/regions/0", scheServerAddr) + testutil.CheckGetJSON(tests.TestDialClient, urlPrefix, nil, + testutil.Status(re, http.StatusBadRequest)) urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/regions/233", scheServerAddr) testutil.CheckGetJSON(tests.TestDialClient, urlPrefix, nil, testutil.Status(re, http.StatusNotFound), testutil.StringContain(re, "not found"))