diff options
Diffstat (limited to 'api4')
-rw-r--r-- | api4/post.go | 55 | ||||
-rw-r--r-- | api4/post_test.go | 119 |
2 files changed, 164 insertions, 10 deletions
diff --git a/api4/post.go b/api4/post.go index a43f2f20f..329241139 100644 --- a/api4/post.go +++ b/api4/post.go @@ -60,24 +60,67 @@ func getPostsForChannel(c *Context, w http.ResponseWriter, r *http.Request) { return } + afterPost := r.URL.Query().Get("after") + beforePost := r.URL.Query().Get("before") + sinceString := r.URL.Query().Get("since") + + var since int64 + var parseError error + + if len(sinceString) > 0 { + since, parseError = strconv.ParseInt(sinceString, 10, 64) + if parseError != nil { + c.SetInvalidParam("since") + return + } + } + if !app.SessionHasPermissionToChannel(c.Session, c.Params.ChannelId, model.PERMISSION_READ_CHANNEL) { c.SetPermissionError(model.PERMISSION_READ_CHANNEL) return } - etag := app.GetPostsEtag(c.Params.ChannelId) + var list *model.PostList + var err *model.AppError + etag := "" - if HandleEtag(etag, "Get Posts", w, r) { - return + if since > 0 { + list, err = app.GetPostsSince(c.Params.ChannelId, since) + } else if len(afterPost) > 0 { + etag = app.GetPostsEtag(c.Params.ChannelId) + + if HandleEtag(etag, "Get Posts After", w, r) { + return + } + + list, err = app.GetPostsAfterPost(c.Params.ChannelId, afterPost, c.Params.Page, c.Params.PerPage) + } else if len(beforePost) > 0 { + etag = app.GetPostsEtag(c.Params.ChannelId) + + if HandleEtag(etag, "Get Posts Before", w, r) { + return + } + + list, err = app.GetPostsBeforePost(c.Params.ChannelId, beforePost, c.Params.Page, c.Params.PerPage) + } else { + etag = app.GetPostsEtag(c.Params.ChannelId) + + if HandleEtag(etag, "Get Posts", w, r) { + return + } + + list, err = app.GetPostsPage(c.Params.ChannelId, c.Params.Page, c.Params.PerPage) } - if list, err := app.GetPostsPage(c.Params.ChannelId, c.Params.Page, c.Params.PerPage); err != nil { + if err != nil { c.Err = err return - } else { + } + + if len(etag) > 0 { w.Header().Set(model.HEADER_ETAG_SERVER, etag) - w.Write([]byte(list.ToJson())) } + w.Write([]byte(list.ToJson())) } func getPost(c *Context, w http.ResponseWriter, r *http.Request) { diff --git a/api4/post_test.go b/api4/post_test.go index 5c5c98d3d..9e0880004 100644 --- a/api4/post_test.go +++ b/api4/post_test.go @@ -174,9 +174,12 @@ func TestGetPostsForChannel(t *testing.T) { post1 := th.CreatePost() post2 := th.CreatePost() - post3 := th.CreatePost() - post4 := &model.Post{ChannelId: th.BasicChannel.Id, Message: "a" + model.NewId() + "a", RootId: post1.Id} - post4, _ = Client.CreatePost(post4) + post3 := &model.Post{ChannelId: th.BasicChannel.Id, Message: "a" + model.NewId() + "a", RootId: post1.Id} + post3, _ = Client.CreatePost(post3) + + time := model.GetMillis() + + post4 := th.CreatePost() posts, resp := Client.GetPostsForChannel(th.BasicChannel.Id, 0, 60, "") CheckNoError(t, resp) @@ -207,7 +210,7 @@ func TestGetPostsForChannel(t *testing.T) { t.Fatal("wrong number returned") } - if _, ok := posts.Posts[post4.Id]; !ok { + if _, ok := posts.Posts[post3.Id]; !ok { t.Fatal("missing comment") } @@ -229,6 +232,30 @@ func TestGetPostsForChannel(t *testing.T) { t.Fatal("should be no posts") } + post5 := th.CreatePost() + + posts, resp = Client.GetPostsSince(th.BasicChannel.Id, time) + CheckNoError(t, resp) + + found := make([]bool, 2) + for _, p := range posts.Posts { + if p.CreateAt < time { + t.Fatal("bad create at for post returned") + } + + if p.Id == post4.Id { + found[0] = true + } else if p.Id == post5.Id { + found[1] = true + } + } + + for _, f := range found { + if !f { + t.Fatal("missing post") + } + } + _, resp = Client.GetPostsForChannel("", 0, 60, "") CheckBadRequestStatus(t, resp) @@ -246,6 +273,90 @@ func TestGetPostsForChannel(t *testing.T) { CheckNoError(t, resp) } +func TestGetPostsAfterAndBefore(t *testing.T) { + th := Setup().InitBasic() + defer TearDown() + Client := th.Client + + post1 := th.CreatePost() + post2 := th.CreatePost() + post3 := th.CreatePost() + post4 := th.CreatePost() + post5 := th.CreatePost() + + posts, resp := Client.GetPostsBefore(th.BasicChannel.Id, post3.Id, 0, 100, "") + CheckNoError(t, resp) + + found := make([]bool, 2) + for _, p := range posts.Posts { + if p.Id == post1.Id { + found[0] = true + } else if p.Id == post2.Id { + found[1] = true + } + + if p.Id == post4.Id || p.Id == post5.Id { + t.Fatal("returned posts after") + } + } + + for _, f := range found { + if !f { + t.Fatal("missing post") + } + } + + posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post3.Id, 1, 1, "") + CheckNoError(t, resp) + + if len(posts.Posts) != 1 { + t.Fatal("too many posts returned") + } + + posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, "junk", 1, 1, "") + CheckNoError(t, resp) + + if len(posts.Posts) != 0 { + t.Fatal("should have no posts") + } + + posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post3.Id, 0, 100, "") + CheckNoError(t, resp) + + found = make([]bool, 2) + for _, p := range posts.Posts { + if p.Id == post4.Id { + found[0] = true + } else if p.Id == post5.Id { + found[1] = true + } + + if p.Id == post1.Id || p.Id == post2.Id { + t.Fatal("returned posts before") + } + } + + for _, f := range found { + if !f { + t.Fatal("missing post") + } + } + + posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post3.Id, 1, 1, "") + CheckNoError(t, resp) + + if len(posts.Posts) != 1 { + t.Fatal("too many posts returned") + } + + posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, "junk", 1, 1, "") + CheckNoError(t, resp) + + if len(posts.Posts) != 0 { + t.Fatal("should have no posts") + } +} + func TestGetPost(t *testing.T) { th := Setup().InitBasic().InitSystemAdmin() defer TearDown() |