package framework

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"io/ioutil"
	"net/http"
	"strconv"
	"sync"
	"time"
)

// Context wraps a single HTTP request/response pair together with the
// per-request state used by the framework. It exposes the methods of
// context.Context by delegating to the request's own context.
type Context struct {
	request        *http.Request
	responseWriter http.ResponseWriter
	ctx            context.Context
	handler        ControllerHandler

	hasTimeout bool        // timeout flag
	writerMux  *sync.Mutex // write-protection lock, exposed through WriterMux
}

// NewContext builds a Context for the given request and response writer.
// It captures r.Context() and allocates a fresh write mutex.
func NewContext(r *http.Request, w http.ResponseWriter) *Context {
	return &Context{
		request:        r,
		responseWriter: w,
		ctx:            r.Context(),
		writerMux:      &sync.Mutex{},
	}
}

// #region implement context.Context

// BaseContext returns the context attached to the underlying request.
func (ctx *Context) BaseContext() context.Context {
	return ctx.request.Context()
}

// Deadline reports the deadline of the request's context, if any.
func (ctx *Context) Deadline() (deadline time.Time, ok bool) {
	return ctx.BaseContext().Deadline()
}

// Done returns the Done channel of the request's context.
func (ctx *Context) Done() <-chan struct{} {
	return ctx.BaseContext().Done()
}

// Err returns the error of the request's context.
func (ctx *Context) Err() error {
	return ctx.BaseContext().Err()
}

// Value looks key up in the request's context.
func (ctx *Context) Value(key interface{}) interface{} {
	return ctx.BaseContext().Value(key)
}

// Base

// WriterMux returns the mutex allocated for this Context.
func (ctx *Context) WriterMux() *sync.Mutex {
	return ctx.writerMux
}

// GetRequest returns the wrapped *http.Request.
func (ctx *Context) GetRequest() *http.Request {
	return ctx.request
}

// GetResponse returns the wrapped http.ResponseWriter.
func (ctx *Context) GetResponse() http.ResponseWriter {
	return ctx.responseWriter
}

// SetHasTimeout marks the context as timed out. Once set, Json becomes a
// no-op.
func (ctx *Context) SetHasTimeout() {
	ctx.hasTimeout = true
}

// HasTimeout reports whether SetHasTimeout has been called.
func (ctx *Context) HasTimeout() bool {
	return ctx.hasTimeout
}

// #region query url

// QueryInt returns the last value of the URL query parameter key parsed as
// an int. def is returned when the key is absent, has no values, or its
// last value is not a valid integer.
func (ctx *Context) QueryInt(key string, def int) int {
	vals, ok := ctx.QueryAll()[key]
	if !ok || len(vals) == 0 {
		return def
	}
	n, err := strconv.Atoi(vals[len(vals)-1])
	if err != nil {
		return def
	}
	return n
}

// QueryString returns the last value of the URL query parameter key, or def
// when the key is absent or has no values.
func (ctx *Context) QueryString(key string, def string) string {
	vals, ok := ctx.QueryAll()[key]
	if !ok || len(vals) == 0 {
		return def
	}
	return vals[len(vals)-1]
}

// QueryArray returns every value of the URL query parameter key, or def
// when the key is absent.
func (ctx *Context) QueryArray(key string, def []string) []string {
	if vals, ok := ctx.QueryAll()[key]; ok {
		return vals
	}
	return def
}

// QueryAll returns all parameters found in the URL query string. An empty
// map is returned when the Context has no request.
func (ctx *Context) QueryAll() map[string][]string {
	if ctx.request != nil {
		return ctx.request.URL.Query()
	}
	return map[string][]string{}
}

// #region form post

// FormInt returns the last value of the posted form field key parsed as an
// int. def is returned when the key is absent, has no values, or its last
// value is not a valid integer.
func (ctx *Context) FormInt(key string, def int) int {
	vals, ok := ctx.FormAll()[key]
	if !ok || len(vals) == 0 {
		return def
	}
	n, err := strconv.Atoi(vals[len(vals)-1])
	if err != nil {
		return def
	}
	return n
}

// FormString returns the last value of the posted form field key, or def
// when the key is absent or has no values.
func (ctx *Context) FormString(key string, def string) string {
	vals, ok := ctx.FormAll()[key]
	if !ok || len(vals) == 0 {
		return def
	}
	return vals[len(vals)-1]
}

// FormArray returns every value of the posted form field key, or def when
// the key is absent.
func (ctx *Context) FormArray(key string, def []string) []string {
	if vals, ok := ctx.FormAll()[key]; ok {
		return vals
	}
	return def
}

// FormAll returns the fields of the posted form body. An empty map is
// returned when the Context has no request.
//
// The form is parsed on first use: http.Request.PostForm stays nil until
// ParseForm has run, so returning it unparsed would make every Form*
// accessor fall back to its default.
func (ctx *Context) FormAll() map[string][]string {
	if ctx.request == nil {
		return map[string][]string{}
	}
	if ctx.request.PostForm == nil {
		// Best effort: the signature has no way to report a parse error,
		// and ParseForm leaves PostForm non-nil even when it fails.
		_ = ctx.request.ParseForm()
	}
	return ctx.request.PostForm
}

// #region application/json post

// BindJson reads the whole request body and unmarshals it as JSON into obj.
// The body is replaced with an in-memory copy afterwards so that it can be
// read again. It returns an error when the Context has no request, when the
// body cannot be read, or when the body is not valid JSON for obj.
func (ctx *Context) BindJson(obj interface{}) error {
	if ctx.request == nil {
		return errors.New("ctx.request empty")
	}

	body, err := ioutil.ReadAll(ctx.request.Body)
	if err != nil {
		return err
	}
	// Restore the body so later readers still see the payload.
	ctx.request.Body = ioutil.NopCloser(bytes.NewBuffer(body))

	return json.Unmarshal(body, obj)
}

// #endregion

// #region response

// Json writes obj to the response as JSON with the given status code and an
// "application/json" Content-Type. It does nothing and returns nil when the
// context has been marked as timed out.
//
// obj is marshalled before any header is sent: the status line can only be
// written once, so encoding first is what allows a marshal failure to be
// reported as a 500 instead of being sent under the success status.
func (ctx *Context) Json(status int, obj interface{}) error {
	if ctx.HasTimeout() {
		return nil
	}

	byt, err := json.Marshal(obj)
	if err != nil {
		ctx.responseWriter.WriteHeader(500)
		return err
	}

	ctx.responseWriter.Header().Set("Content-Type", "application/json")
	ctx.responseWriter.WriteHeader(status)
	if _, err = ctx.responseWriter.Write(byt); err != nil {
		return err
	}
	return nil
}

// HTML is currently a stub: it writes nothing and always returns nil.
func (ctx *Context) HTML(status int, obj interface{}, template string) error {
	return nil
}

// Text is currently a stub: it writes nothing and always returns nil.
func (ctx *Context) Text(status int, obj string) error {
	return nil
}