httpc.lua 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. local skynet = require "skynet"
  2. local socket = require "http.sockethelper"
  3. local internal = require "http.internal"
  4. local dns = require "skynet.dns"
  5. local string = string
  6. local table = table
  7. local httpc = {}
  8. local async_dns
  --- Enable skynet's asynchronous DNS resolver for subsequent connections.
  -- Once called, connect() resolves hostnames through skynet.dns before
  -- opening the socket instead of passing the name to socket.connect.
  -- @param server  DNS server address, forwarded to dns.server
  -- @param port    DNS server port, forwarded to dns.server
  function httpc.dns(server, port)
    -- Flip the module flag before configuring the resolver (original order).
    async_dns = true
    dns.server(server, port)
  end
  --- Split an optional "http://" / "https://" scheme (case-insensitive) off a host.
  -- @param host  string such as "https://example.com:443" or "example.com"
  -- @return "http" or "https", followed by host with the scheme removed.
  --         Strings without a recognised scheme default to "http" unchanged.
  local function check_protocol(host)
    local scheme_prefix = host:match("^[Hh][Tt][Tt][Pp][Ss]?://")
    if not scheme_prefix then
      return "http", host
    end
    -- The matched prefix is anchored at position 1, so dropping its length
    -- strips exactly that prefix.
    local rest = host:sub(#scheme_prefix + 1)
    local scheme = string.lower(scheme_prefix)
    if scheme == "https://" then
      return "https", rest
    elseif scheme == "http://" then
      return "http", rest
    end
    -- Defensive: the pattern above can only yield the two schemes handled.
    error(string.format("Invalid protocol: %s", scheme))
  end
  -- Client TLS context, created on the first HTTPS connection and reused after.
  local SSLCTX_CLIENT = nil

  --- Build the I/O interface table used by http.internal for one connection.
  -- @param protocol  "http" or "https"
  -- @param fd        connected socket
  -- @param hostname  name passed to tls.newtls for HTTPS (may be nil for IP hosts)
  -- @return table with read/write/readall, plus init/close hooks for HTTPS
  local function gen_interface(protocol, fd, hostname)
    if protocol == "http" then
      -- Plain TCP: no handshake or teardown hooks (init/close stay absent).
      return {
        read = socket.readfunc(fd),
        write = socket.writefunc(fd),
        readall = function()
          return socket.readall(fd)
        end,
      }
    end

    if protocol ~= "https" then
      error(string.format("Invalid protocol: %s", protocol))
    end

    -- Loaded lazily so plain-HTTP users never need the TLS module.
    local tls = require "http.tlshelper"
    if not SSLCTX_CLIENT then
      SSLCTX_CLIENT = tls.newctx()
    end
    local tls_ctx = tls.newtls("client", SSLCTX_CLIENT, hostname)
    return {
      init = tls.init_requestfunc(fd, tls_ctx),
      close = tls.closefunc(tls_ctx),
      read = tls.readfunc(fd, tls_ctx),
      write = tls.writefunc(fd, tls_ctx),
      readall = tls.readallfunc(fd, tls_ctx),
    }
  end
  --- Open a (possibly TLS) connection for an HTTP request.
  -- `host` may carry an "http://" or "https://" prefix and an optional ":port"
  -- suffix; without a port, 80 (http) or 443 (https) is used.
  -- @param host     target such as "https://example.com:8443"
  -- @param timeout  optional; forwarded to socket.connect and, if set, arms a
  --                 skynet.timeout that shuts the fd down unless the request
  --                 has finished (interface.finish) by then
  -- @return fd, interface table, host with the scheme stripped
  -- @raise on an empty host, connect failure, interface construction failure or
  --        TLS handshake (init) failure; once the fd is open it is closed first
  local function connect(host, timeout)
    local protocol
    protocol, host = check_protocol(host)
    local hostaddr, port = host:match"([^:]+):?(%d*)$"
    if not hostaddr then
      error(string.format("Invalid host: %s", host))
    end
    if port == "" then
      port = protocol=="http" and 80 or protocol=="https" and 443
    else
      port = tonumber(port)
    end
    local hostname
    -- Only a dotted-numeric literal counts as an IP address. The former test
    -- (".*%d+$") also matched names ending in a digit such as "node1", which
    -- then skipped async DNS and were sent without a TLS server name.
    if not hostaddr:match("^[%d%.]+$") then
      hostname = hostaddr
      if async_dns then
        hostaddr = dns.resolve(hostname)
      end
    end
    local fd = socket.connect(hostaddr, port, timeout)
    if not fd then
      error(string.format("%s connect error host:%s, port:%s, timeout:%s", protocol, hostaddr, port, timeout))
    end

    -- Don't leak the fd if building the interface (e.g. TLS setup) fails.
    local built, interface = pcall(gen_interface, protocol, fd, hostname)
    if not built then
      socket.close(fd)
      error(interface, 0)
    end

    if timeout then
      skynet.timeout(timeout, function()
        if not interface.finish then
          socket.shutdown(fd) -- shutdown the socket fd, need close later.
        end
      end)
    end

    if interface.init then
      -- A failed handshake must release the fd and TLS state before rethrowing;
      -- marking finish also disarms the timeout callback above.
      local ok, err = pcall(interface.init)
      if not ok then
        interface.finish = true
        socket.close(fd)
        if interface.close then
          interface.close()
          interface.close = nil
        end
        error(err, 0)
      end
    end
    return fd, interface, host
  end
  --- Release a connection opened by connect().
  -- @param interface  table returned by gen_interface
  -- @param fd         the connected socket
  local function close_interface(interface, fd)
    -- Mark finished so a pending connect() timeout callback won't shut down the fd.
    interface.finish = true
    socket.close(fd)
    local close_hook = interface.close
    if close_hook then
      close_hook()
      -- Clear the hook so a second close_interface call doesn't run it again.
      interface.close = nil
    end
  end
  --- Send an HTTP request and read the whole response body.
  -- @param method      HTTP verb, e.g. "GET"
  -- @param hostname    target host, optionally with scheme and ":port"
  -- @param url         request path
  -- @param recvheader  optional table filled with response headers (by http.internal)
  -- @param header      optional request headers
  -- @param content     optional request body
  -- @return status code, response body
  -- @raise the error from the request/response phase; the connection is
  --        always closed first
  function httpc.request(method, hostname, url, recvheader, header, content)
    local fd, interface, host = connect(hostname, httpc.timeout)
    local succeeded, code, body, resp_header =
      pcall(internal.request, interface, method, host, url, recvheader, header, content)
    if succeeded then
      succeeded, body = pcall(internal.response, interface, code, body, resp_header)
    end
    close_interface(interface, fd)
    if not succeeded then
      -- A request-phase failure leaves its message in `code`; a
      -- response-phase failure leaves it in `body`.
      error(body or code)
    end
    return code, body
  end
  --- Send a HEAD request; no response body is read.
  -- @return status code
  -- @raise the request error, after the connection has been closed
  function httpc.head(hostname, url, recvheader, header, content)
    local fd, interface, host = connect(hostname, httpc.timeout)
    local succeeded, result =
      pcall(internal.request, interface, "HEAD", host, url, recvheader, header, content)
    close_interface(interface, fd)
    if not succeeded then
      error(result)
    end
    return result
  end
  --- Send an HTTP request and return a stream over the response body.
  -- The connection stays open until the stream's _onclose hook runs.
  -- Note: the connect() timeout is disarmed once the request is sent, so the
  -- stream itself is not subject to it.
  -- @return stream object from internal.response_stream
  -- @raise the request or stream-setup error; the connection is closed first
  function httpc.request_stream(method, hostname, url, recvheader, header, content)
    local fd, interface, host = connect(hostname, httpc.timeout)
    local ok, statuscode, body, resp_header = pcall(internal.request, interface, method, host, url, recvheader, header, content)
    interface.finish = true -- don't shutdown fd in timeout
    local function close_fd()
      close_interface(interface, fd)
    end
    if not ok then
      close_fd()
      error(statuscode)
    end
    -- todo: stream support timeout
    -- Guard stream construction too: if it raises, nothing else would ever
    -- close the fd/TLS state.
    local stream
    ok, stream = pcall(internal.response_stream, interface, statuscode, body, resp_header)
    if not ok then
      close_fd()
      error(stream, 0) -- level 0: propagate the original message unchanged
    end
    stream._onclose = close_fd
    return stream
  end
  --- GET shorthand: forwards (hostname, url, recvheader, header, content)
  -- unchanged to httpc.request with method "GET".
  function httpc.get(...)
    -- Looked up at call time, so a replaced httpc.request is honoured.
    local request = httpc.request
    return request("GET", ...)
  end
  --- Percent-encode every byte that is not an ASCII letter, digit or "_".
  -- string.gsub (not s:gsub) is used so numbers are accepted and coerced.
  -- @param s  string or number
  -- @return the encoded string (a single value)
  local function escape(s)
    local function percent_encode(ch)
      return string.format("%%%02X", string.byte(ch))
    end
    local encoded = string.gsub(s, "([^A-Za-z0-9_])", percent_encode)
    return encoded
  end
  --- POST a form as application/x-www-form-urlencoded.
  -- @param host        target host, optionally with scheme and ":port"
  -- @param url         request path
  -- @param form        table of field -> value (order follows pairs())
  -- @param recvheader  optional table filled with response headers
  -- @return status code, response body (see httpc.request)
  function httpc.post(host, url, form, recvheader)
    local header = { ["content-type"] = "application/x-www-form-urlencoded" }
    local fields = {}
    for key, value in pairs(form) do
      fields[#fields + 1] = escape(key) .. "=" .. escape(value)
    end
    local payload = table.concat(fields, "&")
    return httpc.request("POST", host, url, recvheader, header, payload)
  end

  -- Module export.
  return httpc