Package restkit :: Package filters :: Module oauth2
[hide private]
[frames] | [no frames]

Source Code for Module restkit.filters.oauth2

  1  # -*- coding: utf-8 - 
  2  # 
  3  # This file is part of restkit released under the MIT license.  
  4  # See the NOTICE for more information. 
  5   
  6  import re 
  7  import urlparse 
  8  try: 
  9      from urlparse import parse_qsl 
 10  except ImportError: 
 11      from cgi import parse_qsl 
 12   
 13  from urlparse import urlunparse 
 14       
 15  from restkit.util import replace_header 
 16  from restkit.util.oauth2 import Consumer, Request, SignatureMethod_HMAC_SHA1,\ 
 17  Token 
 18   
def validate_consumer(consumer):
    """ Return ``consumer`` unchanged if it is an ``oauth2.Consumer``.

    :param consumer: the object to check.
    :returns: ``consumer`` itself.
    :raises ValueError: when ``consumer`` is not a ``Consumer`` instance.
    """
    if isinstance(consumer, Consumer):
        return consumer
    raise ValueError("Invalid consumer.")
24
def validate_token(token):
    """ Return ``token`` unchanged if it is None or an ``oauth2.Token``.

    :param token: the object to check; ``None`` is accepted.
    :returns: ``token`` itself.
    :raises ValueError: when ``token`` is neither ``None`` nor a ``Token``.
    """
    # ``None`` short-circuits before ``Token`` is consulted.
    if token is None or isinstance(token, Token):
        return token
    raise ValueError("Invalid token.")
30 31
class OAuthFilter(object):
    """ Request filter that signs requests on matching paths with OAuth.

    Depending on the request, the OAuth parameters are written into an
    urlencoded form body, into the query string (GET/HEAD), or into an
    ``Authorization`` header.
    """

    def __init__(self, path, consumer, token=None, method=None):
        """ Init OAuthFilter

        :param path: path or regexp on which oauth is applied. A trailing
            ``*`` matches any suffix, so ``*`` alone matches every path.
        :param consumer: oauth consumer, instance of oauth2.Consumer
        :param token: oauth token, instance of oauth2.Token
        :param method: oauth signature method

        token and method signature are optional. Consumer should be an
        instance of `oauth2.Consumer`, token an instance of `oauth2.Token`
        and signature method an instance of `oauth2.SignatureMethod`. When
        no method is given, HMAC-SHA1 is used.

        :raises ValueError: if consumer or token has the wrong type.
        """
        if path.endswith('*'):
            # "prefix*" -> regexp "prefix.*"; only the last '*' is a wildcard.
            self.match = re.compile("%s.*" % path.rsplit('*', 1)[0])
        else:
            # Anchor at the end so "/foo" does not also match "/foobar".
            self.match = re.compile("%s$" % path)
        self.consumer = validate_consumer(consumer)
        self.token = validate_token(token)
        self.method = method or SignatureMethod_HMAC_SHA1()

    def on_path(self, req):
        """ Return True if the request path matches this filter's pattern.

        An empty path is treated as "/". The pattern is matched from the
        start of the path (``re.match`` semantics).
        """
        path = req.uri.path or "/"
        return self.match.match(path) is not None

    @staticmethod
    def _is_form_encoded(headers):
        """ Return True if ``headers`` declare an urlencoded form body.

        ``headers`` is an iterable of ``(name, value)`` pairs. Header names
        and the media type are compared case-insensitively. When the header
        is repeated, the last occurrence wins.
        """
        ctype = None
        for name, value in headers:
            if name.lower() == 'content-type':
                ctype = value
        return ctype is not None and \
            ctype.lower().startswith('application/x-www-form-urlencoded')

    def on_request(self, req, tries):
        """ Sign ``req`` in place with OAuth if it is on a matching path.

        :param req: the outgoing request. Its ``headers`` (a list of
            ``(name, value)`` pairs), ``body``, ``uri`` and ``method`` are
            read, and ``body``/``headers``/``url``/``final_url``/``uri`` may
            be rewritten.
        :param tries: retry counter supplied by the caller; nothing is
            signed when it is below 2.
        """
        # NOTE(review): the direction of ``tries`` (counting up or down) is
        # defined by the caller; confirm why signing is skipped below 2.
        if tries < 2:
            return

        if not self.on_path(req):
            return

        params = {}
        form = False
        if req.body and self._is_form_encoded(req.headers):
            # we are in a form: try to get oauth params from the body
            form = True
            params = dict(parse_qsl(req.body))

        # update params from query parameters
        params.update(parse_qsl(req.uri.query))

        # Pass the URL without query/fragment; query parameters are already
        # carried separately in ``params``.
        raw_url = urlunparse((req.uri.scheme, req.uri.netloc,
                req.uri.path, '', '', ''))

        oauth_req = Request.from_consumer_and_token(self.consumer,
                token=self.token, http_method=req.method,
                http_url=raw_url, parameters=params)

        oauth_req.sign_request(self.method, self.consumer, self.token)

        if form:
            # Re-encode the form body with the oauth params included.
            req.body = oauth_req.to_postdata()
            req.headers = replace_header('Content-Length', len(req.body),
                    req.headers)
        elif req.method in ('GET', 'HEAD'):
            # No body to carry the params: move them into the query string.
            req.url = req.final_url = oauth_req.to_url()
            req.uri = urlparse.urlparse(req.url)
        else:
            for k, v in oauth_req.to_header().items():
                if not isinstance(v, basestring):
                    v = str(v)
                # Replace rather than append so that running the filter
                # again on the same request does not stack up duplicate
                # Authorization headers.
                req.headers = replace_header(k.title(), v, req.headers)
104