Welcome to the Chocolatey Community Package Repository! The packages found in this section of the site are provided, maintained, and moderated by the community.
Moderation
Every version of each package undergoes a rigorous moderation process before it goes live that typically includes:
- Security, consistency, and quality checking
- Installation testing
- Virus checking through VirusTotal
- Human moderators who give final review and sign off
More detail at Security and Moderation.
Organizational Use
If you are an organization using Chocolatey, we want your experience to be fully reliable. Due to the nature of this publicly offered repository, reliability cannot be guaranteed. Packages offered here are subject to distribution rights, which means they may need to reach out further to the internet to the official locations to download files at runtime.
Fortunately, distribution rights do not apply for internal use. With any edition of Chocolatey (including the free open source edition), you can host your own packages and cache or internalize existing community packages.
Disclaimer
Your use of the packages on this site means you understand they are not supported or guaranteed in any way.
Downloads:
16,958
Downloads of v 2.7.11:
1,592
Last Update:
17 Nov 2020
Package Maintainer(s):
Software Author(s):
- IronPython Contributors
- Microsoft
Tags:
ironpython python dynamic dlr

IronPython
All Checks are Passing
2 Passing Test
To install IronPython, run the following command from the command line or from PowerShell:
choco install ironpython
To upgrade IronPython, run the following command from the command line or from PowerShell:
choco upgrade ironpython
To uninstall IronPython, run the following command from the command line or from PowerShell:
choco uninstall ironpython
NOTE: This applies to both open source and commercial editions of Chocolatey.
1. Ensure you are set for organizational deployment
Please see the organizational deployment guide
2. Get the package into your environment
Open Source or Commercial:
- Proxy Repository - Create a proxy NuGet repository on Nexus or Artifactory Pro, or a proxy Chocolatey repository on ProGet. Point your upstream to https://community.chocolatey.org/api/v2. Packages are cached automatically on first access. Make sure your choco clients use your proxy repository as a source and NOT the default community repository. See source command for more information.
- You can also just download the package and push it to a repository.
Open Source:
- Download the package.
- Follow manual internalization instructions.
Package Internalizer (C4B):
- Run choco download ironpython --internalize --source=https://community.chocolatey.org/api/v2 (additional options)
- Run choco push --source="'http://internal/odata/repo'" for package and dependencies
- Automate package internalization
3. Enter your internal repository url
(this should look similar to https://community.chocolatey.org/api/v2)
4. Choose your deployment method:
choco upgrade ironpython -y --source="'STEP 3 URL'" [other options]
See options you can pass to upgrade.
See best practices for scripting.
Add this to a PowerShell script or use a Batch script with tools and in places where you are calling directly to Chocolatey. If you are integrating, keep in mind enhanced exit codes.
If you do use a PowerShell script, use the following to ensure bad exit codes are shown as failures:
choco upgrade ironpython -y --source="'STEP 3 URL'"
$exitCode = $LASTEXITCODE
Write-Verbose "Exit code was $exitCode"
$validExitCodes = @(0, 1605, 1614, 1641, 3010)
if ($validExitCodes -contains $exitCode) {
  Exit 0
}
Exit $exitCode
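The exit-code handling above can also be sketched in Python, for cases where the wrapper script is not PowerShell. This is only an illustration: the names `VALID_EXIT_CODES` and `normalize_exit_code` are hypothetical, and the list of accepted codes is copied from the PowerShell snippet rather than taken from any Chocolatey API.

```python
# Exit codes the PowerShell snippet above treats as success.
VALID_EXIT_CODES = frozenset([0, 1605, 1614, 1641, 3010])


def normalize_exit_code(exit_code):
    """Map accepted exit codes to 0; pass every other code through unchanged."""
    if exit_code in VALID_EXIT_CODES:
        return 0
    return exit_code
```

A caller would pass the return code of the `choco upgrade` process to `normalize_exit_code` and exit with the result.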
- name: Ensure ironpython installed
  win_chocolatey:
    name: ironpython
    state: present
    version: 2.7.11
    source: STEP 3 URL
See docs at https://docs.ansible.com/ansible/latest/modules/win_chocolatey_module.html.
chocolatey_package 'ironpython' do
  action :install
  version '2.7.11'
  source 'STEP 3 URL'
end
See docs at https://docs.chef.io/resource_chocolatey_package.html.
Chocolatey::Ensure-Package
(
    Name: ironpython,
    Version: 2.7.11,
    Source: STEP 3 URL
);
Requires Otter Chocolatey Extension. See docs at https://inedo.com/den/otter/chocolatey.
cChocoPackageInstaller ironpython
{
    Name = 'ironpython'
    Ensure = 'Present'
    Version = '2.7.11'
    Source = 'STEP 3 URL'
}
Requires cChoco DSC Resource. See docs at https://github.com/chocolatey/cChoco.
package { 'ironpython':
  provider => 'chocolatey',
  ensure => '2.7.11',
  source => 'STEP 3 URL',
}
Requires Puppet Chocolatey Provider module. See docs at https://forge.puppet.com/puppetlabs/chocolatey.
salt '*' chocolatey.install ironpython version="2.7.11" source="STEP 3 URL"
See docs at https://docs.saltstack.com/en/latest/ref/modules/all/salt.modules.chocolatey.html.
5. If applicable - Chocolatey configuration/installation
See infrastructure management matrix for Chocolatey configuration elements and examples.
This package was approved as a trusted package on 19 Nov 2020.
IronPython is an open-source implementation of the Python programming language which is tightly integrated with the .NET Framework. IronPython can use the .NET Framework and Python libraries, and other .NET languages can use Python code just as easily.
md5: 33C7A7897AC17C6BB2DD7A70756E8801 | sha1: ACFC7A4B095CC5541494FBA59F407CDC98C6DCB4 | sha256: B530D7EE6B5B5CD4BAB686B2A068EAAEC1757AD355B0400D2AACC23C4E2CD530 | sha512: 491B4F33BFD1E63D93FF832B3493F69D8AE93B9D2DAC5DAD9461A76E464DDE05212F133E603302964C821207C1F79EEB4BF3107201E5B6D418D0271FC269A8CB
md5: EBB90F59291A5675E8AB1CA03D563024 | sha1: DBE17AB1C7ADA5439693A26DFD9805B7C6006916 | sha256: 5F32B1D69F5E6E741CB15F5054811A580ADED15149164309CAA1D0E8CAA7D364 | sha512: A7B3CA7C2C8A9E0AE6296F76E19670D2417B1FE30E4668D980AC2F79B5CFFA6E566B8922640184BC68B0C17CE958F479C3AE8F452A88B1CD76047406DBEC4E02
md5: 2EC8B6F478E732D34F6458B47CC83016 | sha1: 9D833125C1B4FCAE8B49C8DEB13D681D0BBC63F7 | sha256: 0718C32F2184FE2B177A9F1DFADFCFFCC2C1755DE84F659730BA237E13339AF7 | sha512: 6AD6A512471702967E67B76075EF7E458500A7EC6A3DA08841D8108ED287B1F9A20144420303EF95D12910F5318B3A7E8051E35F947732F3910CD281D2E366AA
md5: A0003530EE16B4C72866B7C859C9CF6E | sha1: 9EBE5B79B76736ABBA529CF4BE6D5E73880233F1 | sha256: 3CF7F28E42FB3FD0C820BAC823C45B2F8B8FE1DEBB8D2E0F3830D9A2311BC4C1 | sha512: E7F06611FC2690B5042D300C583273DF343313E4EBCD6AC2B40D9B2375FDDB101A2366CF856E412A4ACB9E62CDD611F146DEB2A72E07512C49269C3B52F1BD32
md5: ED68E4C824307F42B1895019D6374A24 | sha1: D70BA060CF885A30F0AEA52EEF0CDD1F79B23E14 | sha256: 870578E3BF025EAA7F21C0DBAC4224D8AF7FE267F57248D9402E736ED349D868 | sha512: F00B7354975FD65A7938055AD19A21FB2C2E9C0DC2ECCE9C0AA4C645D5533BDBD488A6945ECF7829F7484E520A1F255E8C4ABBF5F15224B0C5041D44D15D0176
md5: 180494415D6F7C2A468CFD1163D5785F | sha1: 8BE5E5BB95D444456B704A74583063D9CFF76B71 | sha256: CA1FBC863C36A8E977736B00424253B417834290C15DEAC3280176285A49ECB3 | sha512: 021A149A1FD9B2B75E0BD68FF29B1347FDE80E10EE742A465658330D56C69936A294E487FEBF3051A17E96A6B364FD71BB0C2FF45FE768478C457C1279E3EFD4
md5: 3DBA49A55D5ECB7988EB3D4BD8381E02 | sha1: EFF3D86D912E2FF13F9FD5549DE4D0E64AF72955 | sha256: 3E442CDA613415AEDF80B8A1CFA4181BF4B85C548C043B88334E4067DD6600A6 | sha512: CD4ABF34D7949B5CCE606E7D0118C1D2D18E9110DE0BAA4AD963F9414D9FA4A09758C9874C400524E4299EEAE036553A4D39CBA8DBAB3F15B4FEF790EBF8212E
md5: BF4C00F06EDFEAA068F2EA9D56E77012 | sha1: 2FF1C80250C46E6479F68E57F15FCB809FDAB5C2 | sha256: 4D86F4B5747A3C67D830BED6E3CC5A0E2E9FD3169A8915EF97FFCB27286645D6 | sha512: 73D451599812D3698D24F2D0471F316BF64D13B177929EAF6E53F508BCD053B9644F56DE54EA3E5C0B7D87DD2240B4512B8FEB7789033D1E28DAD3498B1AA016
md5: F80C98A91E564C456DC62A5C5022C792 | sha1: 1E24946058EE93BB1920E672CE99C15F65A02ED2 | sha256: C0F852FA065B8D7E2F54A1845C4B80A65A05B4CF5AC670CBB5754173B1A86E40 | sha512: DD3B4BFAA8BC8C6896B379ABD129FF2A7C06F19D9F5064DE2732E62973F9E253FD9375E183BF5ED66A8CCA503A9041EE47A6B309FDED8E1EC41465F2802B2258
md5: 56C02CA018022884C6A6ECBF21853BA9 | sha1: E82520170BF37C5C26CF58A88BCD00EE31EAC953 | sha256: 87CA4725F12E8C030392BE0164A521940ED353BC60CD34725F8BC0747BB7C069 | sha512: A6ECABDD5E9A75BBA2FC4C582B23FD9DF533771DABADDC7AC9E6743852F0C850F524FBF767DE5550DC7E7BBFC767D484B0411FCB8EF36CD00CC058A00E25ED50
# Copyright 2007 Google, Inc. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.

"""Abstract Base Classes (ABCs) according to PEP 3119."""

import types

from _weakrefset import WeakSet

# Instance of old-style class
class _C: pass
_InstanceType = type(_C())


def abstractmethod(funcobj):
    """A decorator indicating abstract methods.

    Requires that the metaclass is ABCMeta or derived from it. A
    class that has a metaclass derived from ABCMeta cannot be
    instantiated unless all of its abstract methods are overridden.
    The abstract methods can be called using any of the normal
    'super' call mechanisms.

    Usage:

        class C:
            __metaclass__ = ABCMeta
            @abstractmethod
            def my_abstract_method(self, ...):
                ...
    """
    funcobj.__isabstractmethod__ = True
    return funcobj


class abstractproperty(property):
    """A decorator indicating abstract properties.

    Requires that the metaclass is ABCMeta or derived from it. A
    class that has a metaclass derived from ABCMeta cannot be
    instantiated unless all of its abstract properties are overridden.
    The abstract properties can be called using any of the normal
    'super' call mechanisms.

    Usage:

        class C:
            __metaclass__ = ABCMeta
            @abstractproperty
            def my_abstract_property(self):
                ...

    This defines a read-only property; you can also define a read-write
    abstract property using the 'long' form of property declaration:

        class C:
            __metaclass__ = ABCMeta
            def getx(self): ...
            def setx(self, value): ...
            x = abstractproperty(getx, setx)
    """
    __isabstractmethod__ = True


class ABCMeta(type):

    """Metaclass for defining Abstract Base Classes (ABCs).

    Use this metaclass to create an ABC. An ABC can be subclassed
    directly, and then acts as a mix-in class. You can also register
    unrelated concrete classes (even built-in classes) and unrelated
    ABCs as 'virtual subclasses' -- these and their descendants will
    be considered subclasses of the registering ABC by the built-in
    issubclass() function, but the registering ABC won't show up in
    their MRO (Method Resolution Order) nor will method
    implementations defined by the registering ABC be callable (not
    even via super()).

    """

    # A global counter that is incremented each time a class is
    # registered as a virtual subclass of anything. It forces the
    # negative cache to be cleared before its next use.
    _abc_invalidation_counter = 0

    def __new__(mcls, name, bases, namespace):
        cls = super(ABCMeta, mcls).__new__(mcls, name, bases, namespace)
        # Compute set of abstract method names
        abstracts = set(name
                     for name, value in namespace.items()
                     if getattr(value, "__isabstractmethod__", False))
        for base in bases:
            for name in getattr(base, "__abstractmethods__", set()):
                value = getattr(cls, name, None)
                if getattr(value, "__isabstractmethod__", False):
                    abstracts.add(name)
        cls.__abstractmethods__ = frozenset(abstracts)
        # Set up inheritance registry
        cls._abc_registry = WeakSet()
        cls._abc_cache = WeakSet()
        cls._abc_negative_cache = WeakSet()
        cls._abc_negative_cache_version = ABCMeta._abc_invalidation_counter
        return cls

    def register(cls, subclass):
        """Register a virtual subclass of an ABC."""
        if not isinstance(subclass, (type, types.ClassType)):
            raise TypeError("Can only register classes")
        if issubclass(subclass, cls):
            return  # Already a subclass
        # Subtle: test for cycles *after* testing for "already a subclass";
        # this means we allow X.register(X) and interpret it as a no-op.
        if issubclass(cls, subclass):
            # This would create a cycle, which is bad for the algorithm below
            raise RuntimeError("Refusing to create an inheritance cycle")
        cls._abc_registry.add(subclass)
        ABCMeta._abc_invalidation_counter += 1  # Invalidate negative cache

    def _dump_registry(cls, file=None):
        """Debug helper to print the ABC registry."""
        print >> file, "Class: %s.%s" % (cls.__module__, cls.__name__)
        print >> file, "Inv.counter: %s" % ABCMeta._abc_invalidation_counter
        for name in sorted(cls.__dict__.keys()):
            if name.startswith("_abc_"):
                value = getattr(cls, name)
                print >> file, "%s: %r" % (name, value)

    def __instancecheck__(cls, instance):
        """Override for isinstance(instance, cls)."""
        # Inline the cache checking when it's simple.
        subclass = getattr(instance, '__class__', None)
        if subclass is not None and subclass in cls._abc_cache:
            return True
        subtype = type(instance)
        # Old-style instances
        if subtype is _InstanceType:
            subtype = subclass
        if subtype is subclass or subclass is None:
            if (cls._abc_negative_cache_version ==
                ABCMeta._abc_invalidation_counter and
                subtype in cls._abc_negative_cache):
                return False
            # Fall back to the subclass check.
            return cls.__subclasscheck__(subtype)
        return (cls.__subclasscheck__(subclass) or
                cls.__subclasscheck__(subtype))

    def __subclasscheck__(cls, subclass):
        """Override for issubclass(subclass, cls)."""
        # Check cache
        if subclass in cls._abc_cache:
            return True
        # Check negative cache; may have to invalidate
        if cls._abc_negative_cache_version < ABCMeta._abc_invalidation_counter:
            # Invalidate the negative cache
            cls._abc_negative_cache = WeakSet()
            cls._abc_negative_cache_version = ABCMeta._abc_invalidation_counter
        elif subclass in cls._abc_negative_cache:
            return False
        # Check the subclass hook
        ok = cls.__subclasshook__(subclass)
        if ok is not NotImplemented:
            assert isinstance(ok, bool)
            if ok:
                cls._abc_cache.add(subclass)
            else:
                cls._abc_negative_cache.add(subclass)
            return ok
        # Check if it's a direct subclass
        if cls in getattr(subclass, '__mro__', ()):
            cls._abc_cache.add(subclass)
            return True
        # Check if it's a subclass of a registered class (recursive)
        for rcls in cls._abc_registry:
            if issubclass(subclass, rcls):
                cls._abc_cache.add(subclass)
                return True
        # Check if it's a subclass of a subclass (recursive)
        for scls in cls.__subclasses__():
            if issubclass(subclass, scls):
                cls._abc_cache.add(subclass)
                return True
        # No dice; update negative cache
        cls._abc_negative_cache.add(subclass)
        return False
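
As a rough illustration of the behaviour described in the ABCMeta docstrings above, the sketch below defines an ABC, shows that it cannot be instantiated while an abstract method is unimplemented, and registers an unrelated class as a virtual subclass. The class names `Shape`, `Square` and `Legacy` and the helper `can_instantiate` are made up for this example. The ABC is created by calling the metaclass directly so that the same code runs on Python 2 (including IronPython 2.7) and Python 3, instead of using the `__metaclass__` attribute shown in the docstrings.

```python
from abc import ABCMeta, abstractmethod

# Calling the metaclass directly avoids the version-specific
# __metaclass__ (Python 2) / metaclass= (Python 3) syntax.
Shape = ABCMeta('Shape', (object,), {
    'area': abstractmethod(lambda self: None),
})


class Square(Shape):
    # Concrete subclass: overrides every abstract method.
    def __init__(self, side):
        self.side = side

    def area(self):
        return self.side * self.side


class Legacy(object):
    # Unrelated class that does not inherit from Shape.
    pass


def can_instantiate(cls, *args):
    # Instantiating a class with unimplemented abstract methods raises TypeError.
    try:
        cls(*args)
        return True
    except TypeError:
        return False


# Register Legacy as a virtual subclass: issubclass() and isinstance()
# now succeed, but Shape does not appear in Legacy's MRO.
Shape.register(Legacy)
```

Registration only affects `issubclass()` and `isinstance()` checks; `Legacy` inherits nothing from `Shape`, which is why the docstring calls it a virtual subclass.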
"""Stuff to parse AIFF-C and AIFF files.
Unless explicitly stated otherwise, the description below is true
both for AIFF-C files and AIFF files.
An AIFF-C file has the following structure.
  +-----------------+
  | FORM            |
  +-----------------+
  | <size>          |
  +----+------------+
  |    | AIFC       |
  |    +------------+
  |    | <chunks>   |
  |    |    .       |
  |    |    .       |
  |    |    .       |
  +----+------------+
An AIFF file has the string "AIFF" instead of "AIFC".
A chunk consists of an identifier (4 bytes) followed by a size (4 bytes,
big endian order), followed by the data. The size field does not include
the size of the 8 byte header.
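
The chunk layout described above can be sketched as a small standalone reader. This is not the aifc module's implementation (which delegates to the `chunk` module); the function names `iter_chunks`, `parse_form` and `make_chunk` are hypothetical, and the one-byte padding of odd-sized chunks is an assumption taken from the general IFF convention rather than from the text above. The sample bytes are hand-built for the demonstration.

```python
import io
import struct


def iter_chunks(stream):
    # Yield (identifier, data) pairs until the stream is exhausted.
    while True:
        header = stream.read(8)
        if len(header) < 8:
            return
        # 4-byte identifier followed by a 4-byte big-endian size.
        chunk_id, size = struct.unpack('>4sL', header)
        data = stream.read(size)
        if size % 2:
            stream.read(1)  # assumed IFF pad byte after odd-sized data
        yield chunk_id, data


def parse_form(raw):
    # Return the form type (AIFF or AIFC) and the list of chunks inside it.
    stream = io.BytesIO(raw)
    form_id, form_size = struct.unpack('>4sL', stream.read(8))
    if form_id != b'FORM':
        raise ValueError('file does not start with FORM id')
    kind = stream.read(4)
    body = io.BytesIO(stream.read(form_size - 4))
    return kind, list(iter_chunks(body))


def make_chunk(chunk_id, payload):
    # Build one chunk; the size field excludes the 8-byte header.
    pad = b'\x00' if len(payload) % 2 else b''
    return chunk_id + struct.pack('>L', len(payload)) + payload + pad


# Hand-built sample: 1 channel, 2 frames, 16-bit samples, 44100 Hz.
comm = struct.pack('>hLh', 1, 2, 16) + b'\x40\x0e\xac\x44' + b'\x00' * 6
ssnd = struct.pack('>LL', 0, 0) + b'\x00\x01\x00\x02'
body = b'AIFF' + make_chunk(b'COMM', comm) + make_chunk(b'SSND', ssnd)
sample = b'FORM' + struct.pack('>L', len(body)) + body
kind, chunks = parse_form(sample)
```
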
The following chunk types are recognized.
  FVER
      <version number of AIFF-C defining document> (AIFF-C only).
  MARK
      <# of markers> (2 bytes)
      list of markers:
          <marker ID> (2 bytes, must be > 0)
          <position> (4 bytes)
          <marker name> ("pstring")
  COMM
      <# of channels> (2 bytes)
      <# of sound frames> (4 bytes)
      <size of the samples> (2 bytes)
      <sampling frequency> (10 bytes, IEEE 80-bit extended
          floating point)
      in AIFF-C files only:
      <compression type> (4 bytes)
      <human-readable version of compression type> ("pstring")
  SSND
      <offset> (4 bytes, not used by this program)
      <blocksize> (4 bytes, not used by this program)
      <sound data>
A pstring consists of 1 byte length, a string of characters, and 0 or 1
byte pad to make the total length even.
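
A minimal sketch of the pstring layout just described, written so that it runs on both Python 2 and Python 3 byte strings. The names `write_pstring` and `read_pstring` are illustrative only and are not part of the aifc module.

```python
import io
import struct


def write_pstring(stream, text):
    # One length byte, the characters, then a pad byte when needed.
    if len(text) > 255:
        raise ValueError('string exceeds maximum pstring length')
    stream.write(struct.pack('B', len(text)))
    stream.write(text)
    # Length byte + text is odd when len(text) is even, so pad to an even total.
    if len(text) % 2 == 0:
        stream.write(b'\x00')


def read_pstring(stream):
    length = ord(stream.read(1))
    data = stream.read(length)
    if length % 2 == 0:
        stream.read(1)  # skip the pad byte
    return data


buf = io.BytesIO()
write_pstring(buf, b'mark')  # even length: padded
write_pstring(buf, b'abc')   # odd length: no pad
encoded = buf.getvalue()
```
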
Usage.
Reading AIFF files:
f = aifc.open(file, 'r')
where file is either the name of a file or an open file pointer.
The open file pointer must have methods read(), seek(), and close().
In some types of audio files, if the setpos() method is not used,
the seek() method is not necessary.
This returns an instance of a class with the following public methods:
getnchannels() -- returns number of audio channels (1 for
mono, 2 for stereo)
getsampwidth() -- returns sample width in bytes
getframerate() -- returns sampling frequency
getnframes() -- returns number of audio frames
getcomptype() -- returns compression type ('NONE' for AIFF files)
getcompname() -- returns human-readable version of
compression type ('not compressed' for AIFF files)
getparams() -- returns a tuple consisting of all of the
above in the above order
getmarkers() -- get the list of marks in the audio file or None
if there are no marks
getmark(id) -- get mark with the specified id (raises an error
if the mark does not exist)
readframes(n) -- returns at most n frames of audio
rewind() -- rewind to the beginning of the audio stream
setpos(pos) -- seek to the specified position
tell() -- return the current position
close() -- close the instance (make it unusable)
The position returned by tell(), the position given to setpos() and
the position of marks are all compatible and have nothing to do with
the actual position in the file.
The close() method is called automatically when the class instance
is destroyed.
Writing AIFF files:
f = aifc.open(file, 'w')
where file is either the name of a file or an open file pointer.
The open file pointer must have methods write(), tell(), seek(), and
close().
This returns an instance of a class with the following public methods:
aiff() -- create an AIFF file (AIFF-C default)
aifc() -- create an AIFF-C file
setnchannels(n) -- set the number of channels
setsampwidth(n) -- set the sample width
setframerate(n) -- set the frame rate
setnframes(n) -- set the number of frames
setcomptype(type, name)
-- set the compression type and the
human-readable compression type
setparams(tuple)
-- set all parameters at once
setmark(id, pos, name)
-- add specified mark to the list of marks
tell() -- return current position in output file (useful
in combination with setmark())
writeframesraw(data)
-- write audio frames without patching up the
file header
writeframes(data)
-- write audio frames and patch up the file header
close() -- patch up the file header and close the
output file
You should set the parameters before the first writeframesraw or
writeframes. The total number of frames does not need to be set,
but when it is set to the correct value, the header does not have to
be patched up.
It is best to first set all parameters, including the compression type
where one is needed, and then write audio frames using writeframesraw.
When all frames have been written, either call writeframes('') or
close() to patch up the sizes in the header.
Marks can be added anytime. If there are any marks, you must call
close() after all frames have been written.
The close() method is called automatically when the class instance
is destroyed.
When a file is opened with the extension '.aiff', an AIFF file is
written, otherwise an AIFF-C file is written. This default can be
changed by calling aiff() or aifc() before the first writeframes or
writeframesraw.
"""

import struct
import __builtin__

__all__ = ["Error","open","openfp"]

class Error(Exception):
    pass

_AIFC_version = 0xA2805140L     # Version 1 of AIFF-C

def _read_long(file):
    try:
        return struct.unpack('>l', file.read(4))[0]
    except struct.error:
        raise EOFError

def _read_ulong(file):
    try:
        return struct.unpack('>L', file.read(4))[0]
    except struct.error:
        raise EOFError

def _read_short(file):
    try:
        return struct.unpack('>h', file.read(2))[0]
    except struct.error:
        raise EOFError

def _read_ushort(file):
    try:
        return struct.unpack('>H', file.read(2))[0]
    except struct.error:
        raise EOFError

def _read_string(file):
    length = ord(file.read(1))
    if length == 0:
        data = ''
    else:
        data = file.read(length)
    if length & 1 == 0:
        dummy = file.read(1)
    return data

_HUGE_VAL = 1.79769313486231e+308 # See <limits.h>

def _read_float(f): # 10 bytes
    expon = _read_short(f) # 2 bytes
    sign = 1
    if expon < 0:
        sign = -1
        expon = expon + 0x8000
    himant = _read_ulong(f) # 4 bytes
    lomant = _read_ulong(f) # 4 bytes
    if expon == himant == lomant == 0:
        f = 0.0
    elif expon == 0x7FFF:
        f = _HUGE_VAL
    else:
        expon = expon - 16383
        f = (himant * 0x100000000L + lomant) * pow(2.0, expon - 63)
    return sign * f
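
As an illustration of the decoding done by _read_float, here is a standalone sketch that takes the 10 bytes directly instead of reading from a file object. The function name `decode_extended` is hypothetical; the arithmetic mirrors the function above, written without the Python 2 long-integer suffix so that it also runs on Python 3.

```python
import struct

# Same stand-in for infinity that the module above uses.
HUGE_VAL = 1.79769313486231e+308


def decode_extended(data):
    # 2-byte signed exponent word, then two 4-byte mantissa halves, big endian.
    expon, himant, lomant = struct.unpack('>hLL', data)
    sign = 1
    if expon < 0:
        # The top bit of the exponent word is the sign of the value.
        sign = -1
        expon = expon + 0x8000
    if expon == himant == lomant == 0:
        value = 0.0
    elif expon == 0x7FFF:
        value = HUGE_VAL
    else:
        # Remove the exponent bias (16383) and scale the 64-bit mantissa.
        value = (himant * 0x100000000 + lomant) * pow(2.0, expon - 16383 - 63)
    return sign * value


# 44100 Hz, a common sampling frequency, in 80-bit extended form.
rate = decode_extended(b'\x40\x0e\xac\x44' + b'\x00' * 6)
```
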

def _write_short(f, x):
    f.write(struct.pack('>h', x))

def _write_ushort(f, x):
    f.write(struct.pack('>H', x))

def _write_long(f, x):
    f.write(struct.pack('>l', x))

def _write_ulong(f, x):
    f.write(struct.pack('>L', x))

def _write_string(f, s):
    if len(s) > 255:
        raise ValueError("string exceeds maximum pstring length")
    f.write(struct.pack('B', len(s)))
    f.write(s)
    if len(s) & 1 == 0:
        f.write(chr(0))

def _write_float(f, x):
    import math
    if x < 0:
        sign = 0x8000
        x = x * -1
    else:
        sign = 0
    if x == 0:
        expon = 0
        himant = 0
        lomant = 0
    else:
        fmant, expon = math.frexp(x)
        if expon > 16384 or fmant >= 1 or fmant != fmant: # Infinity or NaN
            expon = sign|0x7FFF
            himant = 0
            lomant = 0
        else:                   # Finite
            expon = expon + 16382
            if expon < 0:           # denormalized
                fmant = math.ldexp(fmant, expon)
                expon = 0
            expon = expon | sign
            fmant = math.ldexp(fmant, 32)
            fsmant = math.floor(fmant)
            himant = long(fsmant)
            fmant = math.ldexp(fmant - fsmant, 32)
            fsmant = math.floor(fmant)
            lomant = long(fsmant)
    _write_ushort(f, expon)
    _write_ulong(f, himant)
    _write_ulong(f, lomant)

from chunk import Chunk

class Aifc_read:
    # Variables used in this class:
    #
    # These variables are available to the user through appropriate
    # methods of this class:
    # _file -- the open file with methods read(), close(), and seek()
    #       set through the __init__() method
    # _nchannels -- the number of audio channels
    #       available through the getnchannels() method
    # _nframes -- the number of audio frames
    #       available through the getnframes() method
    # _sampwidth -- the number of bytes per audio sample
    #       available through the getsampwidth() method
    # _framerate -- the sampling frequency
    #       available through the getframerate() method
    # _comptype -- the AIFF-C compression type ('NONE' if AIFF)
    #       available through the getcomptype() method
    # _compname -- the human-readable AIFF-C compression type
    #       available through the getcompname() method
    # _markers -- the marks in the audio file
    #       available through the getmarkers() and getmark()
    #       methods
    # _soundpos -- the position in the audio stream
    #       available through the tell() method, set through the
    #       setpos() method
    #
    # These variables are used internally only:
    # _version -- the AIFF-C version number
    # _decomp -- the decompressor from builtin module cl
    # _comm_chunk_read -- 1 iff the COMM chunk has been read
    # _aifc -- 1 iff reading an AIFF-C file
    # _ssnd_seek_needed -- 1 iff positioned correctly in audio
    #       file for readframes()
    # _ssnd_chunk -- instantiation of a chunk class for the SSND chunk
    # _framesize -- size of one frame in the file

    _file = None  # Set here since __del__ checks it

    def initfp(self, file):
        self._version = 0
        self._decomp = None
        self._convert = None
        self._markers = []
        self._soundpos = 0
        self._file = file
        chunk = Chunk(file)
        if chunk.getname() != 'FORM':
            raise Error, 'file does not start with FORM id'
        formdata = chunk.read(4)
        if formdata == 'AIFF':
            self._aifc = 0
        elif formdata == 'AIFC':
            self._aifc = 1
        else:
            raise Error, 'not an AIFF or AIFF-C file'
        self._comm_chunk_read = 0
        self._ssnd_chunk = None
        while 1:
            self._ssnd_seek_needed = 1
            try:
                chunk = Chunk(self._file)
            except EOFError:
                break
            chunkname = chunk.getname()
            if chunkname == 'COMM':
                self._read_comm_chunk(chunk)
                self._comm_chunk_read = 1
            elif chunkname == 'SSND':
                self._ssnd_chunk = chunk
                dummy = chunk.read(8)
                self._ssnd_seek_needed = 0
            elif chunkname == 'FVER':
                self._version = _read_ulong(chunk)
            elif chunkname == 'MARK':
                self._readmark(chunk)
            chunk.skip()
        if not self._comm_chunk_read or not self._ssnd_chunk:
            raise Error, 'COMM chunk and/or SSND chunk missing'
        if self._aifc and self._decomp:
            import cl
            params = [cl.ORIGINAL_FORMAT, 0,
                      cl.BITS_PER_COMPONENT, self._sampwidth * 8,
                      cl.FRAME_RATE, self._framerate]
            if self._nchannels == 1:
                params[1] = cl.MONO
            elif self._nchannels == 2:
                params[1] = cl.STEREO_INTERLEAVED
            else:
                raise Error, 'cannot compress more than 2 channels'
            self._decomp.SetParams(params)

    def __init__(self, f):
        if isinstance(f, basestring):
            f = __builtin__.open(f, 'rb')
            try:
                self.initfp(f)
            except:
                f.close()
                raise
        else:
            # assume it is an open file object already
            self.initfp(f)

    #
    # User visible methods.
    #
    def getfp(self):
        return self._file

    def rewind(self):
        self._ssnd_seek_needed = 1
        self._soundpos = 0

    def close(self):
        decomp = self._decomp
        try:
            if decomp:
                self._decomp = None
                decomp.CloseDecompressor()
        finally:
            self._file.close()

    def tell(self):
        return self._soundpos

    def getnchannels(self):
        return self._nchannels

    def getnframes(self):
        return self._nframes

    def getsampwidth(self):
        return self._sampwidth

    def getframerate(self):
        return self._framerate

    def getcomptype(self):
        return self._comptype

    def getcompname(self):
        return self._compname

##  def getversion(self):
##      return self._version

    def getparams(self):
        return self.getnchannels(), self.getsampwidth(), \
              self.getframerate(), self.getnframes(), \
              self.getcomptype(), self.getcompname()

    def getmarkers(self):
        if len(self._markers) == 0:
            return None
        return self._markers

    def getmark(self, id):
        for marker in self._markers:
            if id == marker[0]:
                return marker
        raise Error, 'marker %r does not exist' % (id,)

    def setpos(self, pos):
        if pos < 0 or pos > self._nframes:
            raise Error, 'position not in range'
        self._soundpos = pos
        self._ssnd_seek_needed = 1

    def readframes(self, nframes):
        if self._ssnd_seek_needed:
            self._ssnd_chunk.seek(0)
            dummy = self._ssnd_chunk.read(8)
            pos = self._soundpos * self._framesize
            if pos:
                self._ssnd_chunk.seek(pos + 8)
            self._ssnd_seek_needed = 0
        if nframes == 0:
            return ''
        data = self._ssnd_chunk.read(nframes * self._framesize)
        if self._convert and data:
            data = self._convert(data)
        self._soundpos = self._soundpos + len(data) // (self._nchannels * self._sampwidth)
        return data

    #
    # Internal methods.
    #

    def _decomp_data(self, data):
        import cl
        dummy = self._decomp.SetParam(cl.FRAME_BUFFER_SIZE,
                                      len(data) * 2)
        return self._decomp.Decompress(len(data) // self._nchannels,
                                       data)

    def _ulaw2lin(self, data):
        import audioop
        return audioop.ulaw2lin(data, 2)

    def _adpcm2lin(self, data):
        import audioop
        if not hasattr(self, '_adpcmstate'):
            # first time
            self._adpcmstate = None
        data, self._adpcmstate = audioop.adpcm2lin(data, 2,
                                                   self._adpcmstate)
        return data

    def _read_comm_chunk(self, chunk):
        self._nchannels = _read_short(chunk)
        self._nframes = _read_long(chunk)
        self._sampwidth = (_read_short(chunk) + 7) // 8
        self._framerate = int(_read_float(chunk))
        self._framesize = self._nchannels * self._sampwidth
        if self._aifc:
            #DEBUG: SGI's soundeditor produces a bad size :-(
            kludge = 0
            if chunk.chunksize == 18:
                kludge = 1
                print 'Warning: bad COMM chunk size'
                chunk.chunksize = 23
            #DEBUG end
            self._comptype = chunk.read(4)
            #DEBUG start
            if kludge:
                length = ord(chunk.file.read(1))
                if length & 1 == 0:
                    length = length + 1
                chunk.chunksize = chunk.chunksize + length
                chunk.file.seek(-1, 1)
            #DEBUG end
            self._compname = _read_string(chunk)
            if self._comptype != 'NONE':
                if self._comptype == 'G722':
                    try:
                        import audioop
                    except ImportError:
                        pass
                    else:
                        self._convert = self._adpcm2lin
                        self._sampwidth = 2
                        return
                # for ULAW and ALAW try Compression Library
                try:
                    import cl
                except ImportError:
                    if self._comptype in ('ULAW', 'ulaw'):
                        try:
                            import audioop
                            self._convert = self._ulaw2lin
                            self._sampwidth = 2
                            return
                        except ImportError:
                            pass
                    raise Error, 'cannot read compressed AIFF-C files'
                if self._comptype in ('ULAW', 'ulaw'):
                    scheme = cl.G711_ULAW
                elif self._comptype in ('ALAW', 'alaw'):
                    scheme = cl.G711_ALAW
                else:
                    raise Error, 'unsupported compression type'
                self._decomp = cl.OpenDecompressor(scheme)
                self._convert = self._decomp_data
                self._sampwidth = 2
        else:
            self._comptype = 'NONE'
            self._compname = 'not compressed'

    def _readmark(self, chunk):
        nmarkers = _read_short(chunk)
        # Some files appear to contain invalid counts.
        # Cope with this by testing for EOF.
        try:
            for i in range(nmarkers):
                id = _read_short(chunk)
                pos = _read_long(chunk)
                name = _read_string(chunk)
                if pos or name:
                    # some files appear to have
                    # dummy markers consisting of
                    # a position 0 and name ''
                    self._markers.append((id, pos, name))
        except EOFError:
            print 'Warning: MARK chunk contains only',
            print len(self._markers),
            if len(self._markers) == 1: print 'marker',
            else: print 'markers',
            print 'instead of', nmarkers

class Aifc_write:
    # Variables used in this class:
    #
    # These variables are user settable through appropriate methods
    # of this class:
    # _file -- the open file with methods write(), close(), tell(), seek()
    #       set through the __init__() method
    # _comptype -- the AIFF-C compression type ('NONE' in AIFF)
    #       set through the setcomptype() or setparams() method
    # _compname -- the human-readable AIFF-C compression type
    #       set through the setcomptype() or setparams() method
    # _nchannels -- the number of audio channels
    #       set through the setnchannels() or setparams() method
    # _sampwidth -- the number of bytes per audio sample
    #       set through the setsampwidth() or setparams() method
    # _framerate -- the sampling frequency
    #       set through the setframerate() or setparams() method
    # _nframes -- the number of audio frames written to the header
    #       set through the setnframes() or setparams() method
    # _aifc -- whether we're writing an AIFF-C file or an AIFF file
    #       set through the aifc() method, reset through the
    #       aiff() method
    #
    # These variables are used internally only:
    # _version -- the AIFF-C version number
    # _comp -- the compressor from builtin module cl
    # _nframeswritten -- the number of audio frames actually written
    # _datalength -- the size of the audio samples written to the header
    # _datawritten -- the size of the audio samples actually written

    _file = None  # Set here since __del__ checks it

    def __init__(self, f):
        if isinstance(f, basestring):
            filename = f
            f = __builtin__.open(f, 'wb')
        else:
            # else, assume it is an open file object already
            filename = '???'
        self.initfp(f)
        if filename[-5:] == '.aiff':
            self._aifc = 0
        else:
            self._aifc = 1

    def initfp(self, file):
        self._file = file
        self._version = _AIFC_version
        self._comptype = 'NONE'
        self._compname = 'not compressed'
        self._comp = None
        self._convert = None
        self._nchannels = 0
        self._sampwidth = 0
        self._framerate = 0
        self._nframes = 0
        self._nframeswritten = 0
        self._datawritten = 0
        self._datalength = 0
        self._markers = []
        self._marklength = 0
        self._aifc = 1      # AIFF-C is default

    def __del__(self):
        if self._file:
            self.close()

    #
    # User visible methods.
    #
    def aiff(self):
        if self._nframeswritten:
            raise Error, 'cannot change parameters after starting to write'
        self._aifc = 0

    def aifc(self):
        if self._nframeswritten:
            raise Error, 'cannot change parameters after starting to write'
        self._aifc = 1

    def setnchannels(self, nchannels):
        if self._nframeswritten:
            raise Error, 'cannot change parameters after starting to write'
        if nchannels < 1:
            raise Error, 'bad # of channels'
        self._nchannels = nchannels

    def getnchannels(self):
        if not self._nchannels:
            raise Error, 'number of channels not set'
        return self._nchannels

    def setsampwidth(self, sampwidth):
        if self._nframeswritten:
            raise Error, 'cannot change parameters after starting to write'
        if sampwidth < 1 or sampwidth > 4:
            raise Error, 'bad sample width'
        self._sampwidth = sampwidth

    def getsampwidth(self):
        if not self._sampwidth:
            raise Error, 'sample width not set'
        return self._sampwidth

    def setframerate(self, framerate):
        if self._nframeswritten:
            raise Error, 'cannot change parameters after starting to write'
        if framerate <= 0:
            raise Error, 'bad frame rate'
        self._framerate = framerate

    def getframerate(self):
        if not self._framerate:
            raise Error, 'frame rate not set'
        return self._framerate

    def setnframes(self, nframes):
        if self._nframeswritten:
            raise Error, 'cannot change parameters after starting to write'
        self._nframes = nframes

    def getnframes(self):
        return self._nframeswritten

    def setcomptype(self, comptype, compname):
        if self._nframeswritten:
            raise Error, 'cannot change parameters after starting to write'
        if comptype not in ('NONE', 'ULAW', 'ulaw', 'ALAW', 'alaw', 'G722'):
            raise Error, 'unsupported compression type'
        self._comptype = comptype
        self._compname = compname

    def getcomptype(self):
        return self._comptype

    def getcompname(self):
        return self._compname

##  def setversion(self, version):
##      if self._nframeswritten:
##          raise Error, 'cannot change parameters after starting to write'
##      self._version = version

    def setparams(self, info):
        nchannels, sampwidth, framerate, nframes, comptype, compname = info
        if self._nframeswritten:
            raise Error, 'cannot change parameters after starting to write'
        if comptype not in ('NONE', 'ULAW', 'ulaw', 'ALAW', 'alaw', 'G722'):
            raise Error, 'unsupported compression type'
        self.setnchannels(nchannels)
        self.setsampwidth(sampwidth)
        self.setframerate(framerate)
        self.setnframes(nframes)
        self.setcomptype(comptype, compname)

    def getparams(self):
        if not self._nchannels or not self._sampwidth or not self._framerate:
            raise Error, 'not all parameters set'
        return self._nchannels, self._sampwidth, self._framerate, \
              self._nframes, self._comptype, self._compname

    def setmark(self, id, pos, name):
        if id <= 0:
            raise Error, 'marker ID must be > 0'
        if pos < 0:
            raise Error, 'marker position must be >= 0'
        if type(name) != type(''):
            raise Error, 'marker name must be a string'
        for i in range(len(self._markers)):
            if id == self._markers[i][0]:
self._markers[i] = id, pos, name
return
self._markers.append((id, pos, name))
def getmark(self, id):
for marker in self._markers:
if id == marker[0]:
return marker
raise Error, 'marker %r does not exist' % (id,)
def getmarkers(self):
if len(self._markers) == 0:
return None
return self._markers
def tell(self):
return self._nframeswritten
def writeframesraw(self, data):
self._ensure_header_written(len(data))
nframes = len(data) // (self._sampwidth * self._nchannels)
if self._convert:
data = self._convert(data)
self._file.write(data)
self._nframeswritten = self._nframeswritten + nframes
self._datawritten = self._datawritten + len(data)
def writeframes(self, data):
self.writeframesraw(data)
if self._nframeswritten != self._nframes or \
self._datalength != self._datawritten:
self._patchheader()
def close(self):
if self._file is None:
return
try:
self._ensure_header_written(0)
if self._datawritten & 1:
# quick pad to even size
self._file.write(chr(0))
self._datawritten = self._datawritten + 1
self._writemarkers()
if self._nframeswritten != self._nframes or \
self._datalength != self._datawritten or \
self._marklength:
self._patchheader()
if self._comp:
self._comp.CloseCompressor()
self._comp = None
finally:
# Prevent ref cycles
self._convert = None
f = self._file
self._file = None
f.close()
#
# Internal methods.
#
def _comp_data(self, data):
import cl
dummy = self._comp.SetParam(cl.FRAME_BUFFER_SIZE, len(data))
dummy = self._comp.SetParam(cl.COMPRESSED_BUFFER_SIZE, len(data))
return self._comp.Compress(self._nframes, data)
def _lin2ulaw(self, data):
import audioop
return audioop.lin2ulaw(data, 2)
def _lin2adpcm(self, data):
import audioop
if not hasattr(self, '_adpcmstate'):
self._adpcmstate = None
data, self._adpcmstate = audioop.lin2adpcm(data, 2,
self._adpcmstate)
return data
def _ensure_header_written(self, datasize):
if not self._nframeswritten:
if self._comptype in ('ULAW', 'ulaw', 'ALAW', 'alaw'):
if not self._sampwidth:
self._sampwidth = 2
if self._sampwidth != 2:
raise Error, 'sample width must be 2 when compressing with ULAW or ALAW'
if self._comptype == 'G722':
if not self._sampwidth:
self._sampwidth = 2
if self._sampwidth != 2:
raise Error, 'sample width must be 2 when compressing with G7.22 (ADPCM)'
if not self._nchannels:
raise Error, '# channels not specified'
if not self._sampwidth:
raise Error, 'sample width not specified'
if not self._framerate:
raise Error, 'sampling rate not specified'
self._write_header(datasize)
def _init_compression(self):
if self._comptype == 'G722':
self._convert = self._lin2adpcm
return
try:
import cl
except ImportError:
if self._comptype in ('ULAW', 'ulaw'):
try:
import audioop
self._convert = self._lin2ulaw
return
except ImportError:
pass
raise Error, 'cannot write compressed AIFF-C files'
if self._comptype in ('ULAW', 'ulaw'):
scheme = cl.G711_ULAW
elif self._comptype in ('ALAW', 'alaw'):
scheme = cl.G711_ALAW
else:
raise Error, 'unsupported compression type'
self._comp = cl.OpenCompressor(scheme)
params = [cl.ORIGINAL_FORMAT, 0,
cl.BITS_PER_COMPONENT, self._sampwidth * 8,
cl.FRAME_RATE, self._framerate,
cl.FRAME_BUFFER_SIZE, 100,
cl.COMPRESSED_BUFFER_SIZE, 100]
if self._nchannels == 1:
params[1] = cl.MONO
elif self._nchannels == 2:
params[1] = cl.STEREO_INTERLEAVED
else:
raise Error, 'cannot compress more than 2 channels'
self._comp.SetParams(params)
# the compressor produces a header which we ignore
dummy = self._comp.Compress(0, '')
self._convert = self._comp_data
def _write_header(self, initlength):
if self._aifc and self._comptype != 'NONE':
self._init_compression()
self._file.write('FORM')
if not self._nframes:
self._nframes = initlength // (self._nchannels * self._sampwidth)
self._datalength = self._nframes * self._nchannels * self._sampwidth
if self._datalength & 1:
self._datalength = self._datalength + 1
if self._aifc:
if self._comptype in ('ULAW', 'ulaw', 'ALAW', 'alaw'):
self._datalength = self._datalength // 2
if self._datalength & 1:
self._datalength = self._datalength + 1
elif self._comptype == 'G722':
self._datalength = (self._datalength + 3) // 4
if self._datalength & 1:
self._datalength = self._datalength + 1
try:
self._form_length_pos = self._file.tell()
except (AttributeError, IOError):
self._form_length_pos = None
commlength = self._write_form_length(self._datalength)
if self._aifc:
self._file.write('AIFC')
self._file.write('FVER')
_write_ulong(self._file, 4)
_write_ulong(self._file, self._version)
else:
self._file.write('AIFF')
self._file.write('COMM')
_write_ulong(self._file, commlength)
_write_short(self._file, self._nchannels)
if self._form_length_pos is not None:
self._nframes_pos = self._file.tell()
_write_ulong(self._file, self._nframes)
if self._comptype in ('ULAW', 'ulaw', 'ALAW', 'alaw', 'G722'):
_write_short(self._file, 8)
else:
_write_short(self._file, self._sampwidth * 8)
_write_float(self._file, self._framerate)
if self._aifc:
self._file.write(self._comptype)
_write_string(self._file, self._compname)
self._file.write('SSND')
if self._form_length_pos is not None:
self._ssnd_length_pos = self._file.tell()
_write_ulong(self._file, self._datalength + 8)
_write_ulong(self._file, 0)
_write_ulong(self._file, 0)
def _write_form_length(self, datalength):
if self._aifc:
commlength = 18 + 5 + len(self._compname)
if commlength & 1:
commlength = commlength + 1
verslength = 12
else:
commlength = 18
verslength = 0
_write_ulong(self._file, 4 + verslength + self._marklength + \
8 + commlength + 16 + datalength)
return commlength
def _patchheader(self):
curpos = self._file.tell()
if self._datawritten & 1:
datalength = self._datawritten + 1
self._file.write(chr(0))
else:
datalength = self._datawritten
if datalength == self._datalength and \
self._nframes == self._nframeswritten and \
self._marklength == 0:
self._file.seek(curpos, 0)
return
self._file.seek(self._form_length_pos, 0)
dummy = self._write_form_length(datalength)
self._file.seek(self._nframes_pos, 0)
_write_ulong(self._file, self._nframeswritten)
self._file.seek(self._ssnd_length_pos, 0)
_write_ulong(self._file, datalength + 8)
self._file.seek(curpos, 0)
self._nframes = self._nframeswritten
self._datalength = datalength
def _writemarkers(self):
if len(self._markers) == 0:
return
self._file.write('MARK')
length = 2
for marker in self._markers:
id, pos, name = marker
length = length + len(name) + 1 + 6
if len(name) & 1 == 0:
length = length + 1
_write_ulong(self._file, length)
self._marklength = length + 8
_write_short(self._file, len(self._markers))
for marker in self._markers:
id, pos, name = marker
_write_short(self._file, id)
_write_ulong(self._file, pos)
_write_string(self._file, name)
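# The length arithmetic in _writemarkers() above can be sketched on its own.
# This is a minimal illustration, not part of the module: mark_chunk_length
# is a hypothetical helper, and it assumes each marker name is written as a
# length byte plus the name, padded to an even size (which is what the
# len(name) & 1 test above accounts for).

```python
def mark_chunk_length(markers):
    """Return (payload_length, total_length) for a MARK chunk.

    markers is a list of (id, pos, name) tuples, as kept in _markers.
    Unlike _writemarkers(), this sketch does not special-case an
    empty list.
    """
    length = 2  # 16-bit marker count
    for _id, _pos, name in markers:
        # 2-byte id + 4-byte position + 1 length byte + the name itself
        length += len(name) + 1 + 6
        # length byte + even-length name is odd, so one pad byte follows
        if len(name) % 2 == 0:
            length += 1
    # the total adds the 4-byte chunk ID and the 4-byte chunk size field
    return length, length + 8
```

# The second value corresponds to what the method stores in _marklength,
# which _write_form_length() later adds to the FORM size.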
def open(f, mode=None):
if mode is None:
if hasattr(f, 'mode'):
mode = f.mode
else:
mode = 'rb'
if mode in ('r', 'rb'):
return Aifc_read(f)
elif mode in ('w', 'wb'):
return Aifc_write(f)
else:
raise Error, "mode must be 'r', 'rb', 'w', or 'wb'"
openfp = open # B/W compatibility
if __name__ == '__main__':
import sys
if not sys.argv[1:]:
sys.argv.append('/usr/demos/data/audio/bach.aiff')
fn = sys.argv[1]
f = open(fn, 'r')
try:
print "Reading", fn
print "nchannels =", f.getnchannels()
print "nframes =", f.getnframes()
print "sampwidth =", f.getsampwidth()
print "framerate =", f.getframerate()
print "comptype =", f.getcomptype()
print "compname =", f.getcompname()
if sys.argv[2:]:
gn = sys.argv[2]
print "Writing", gn
g = open(gn, 'w')
try:
g.setparams(f.getparams())
while 1:
data = f.readframes(1024)
if not data:
break
g.writeframes(data)
finally:
g.close()
print "Done."
finally:
f.close()
import webbrowser
webbrowser.open("http://xkcd.com/353/")
"""Generic interface to all dbm clones.

Instead of

        import dbm
        d = dbm.open(file, 'w', 0666)

use

        import anydbm
        d = anydbm.open(file, 'w')

The returned object is a dbhash, gdbm, dbm or dumbdbm object,
dependent on the type of database being opened (determined by whichdb
module) in the case of an existing dbm. If the dbm does not exist and
the create or new flag ('c' or 'n') was specified, the dbm type will
be determined by the availability of the modules (tested in the above
order).

It has the following interface (key and data are strings):

        d[key] = data   # store data at key (may override data at
                        # existing key)
        data = d[key]   # retrieve data at key (raise KeyError if no
                        # such key)
        del d[key]      # delete data stored at key (raises KeyError
                        # if no such key)
        flag = key in d # true if the key exists
        list = d.keys() # return a list of all existing keys (slow!)

Future versions may change the order in which implementations are
tested for existence, and add interfaces to other dbm-like
implementations.
"""
class error(Exception):
pass
_names = ['dbhash', 'gdbm', 'dbm', 'dumbdbm']
_errors = [error]
_defaultmod = None
for _name in _names:
try:
_mod = __import__(_name)
except ImportError:
continue
if not _defaultmod:
_defaultmod = _mod
_errors.append(_mod.error)
if not _defaultmod:
raise ImportError, "no dbm clone found; tried %s" % _names
error = tuple(_errors)
def open(file, flag='r', mode=0666):
"""Open or create database at path given by *file*.
Optional argument *flag* can be 'r' (default) for read-only access, 'w'
for read-write access of an existing database, 'c' for read-write access
to a new or existing database, and 'n' for read-write access to a new
database.
Note: 'r' and 'w' fail if the database doesn't exist; 'c' creates it
only if it doesn't exist; and 'n' always creates a new database.
"""
# guess the type of an existing database
from whichdb import whichdb
result=whichdb(file)
if result is None:
# db doesn't exist
if 'c' in flag or 'n' in flag:
# file doesn't exist and the new
# flag was used so use default type
mod = _defaultmod
else:
raise error, "need 'c' or 'n' flag to open new db"
elif result == "":
# db type cannot be determined
raise error, "db type could not be determined"
else:
mod = __import__(result)
return mod.open(file, flag, mode)
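# The "try each backend in order, first one found becomes the default"
# pattern used at the top of this module can be sketched in isolation.
# This is an illustrative sketch only: first_importable is a hypothetical
# helper, and it uses importlib.import_module where the module above calls
# __import__ directly.

```python
import importlib


def first_importable(names):
    """Return (default_module, available_modules) for candidate names.

    Candidates are tried in the order given; the first one that imports
    becomes the default, mirroring how _defaultmod is chosen.
    """
    default = None
    available = []
    for name in names:
        try:
            mod = importlib.import_module(name)
        except ImportError:
            continue  # backend not installed, try the next one
        if default is None:
            default = mod
        available.append(mod)
    if default is None:
        raise ImportError("no candidate module found; tried %s" % (names,))
    return default, available
```

# With ['dbhash', 'gdbm', 'dbm', 'dumbdbm'] as input this would pick
# whichever of those imports first on the running interpreter.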
# Author: Steven J. Bethard.
"""Command-line parsing library

This module is an optparse-inspired command-line parsing library that:

    - handles both optional and positional arguments
    - produces highly informative usage messages
    - supports parsers that dispatch to sub-parsers

The following is a simple usage example that sums integers from the
command-line and writes the result to a file::

    parser = argparse.ArgumentParser(
        description='sum the integers at the command line')
    parser.add_argument(
        'integers', metavar='int', nargs='+', type=int,
        help='an integer to be summed')
    parser.add_argument(
        '--log', default=sys.stdout, type=argparse.FileType('w'),
        help='the file where the sum should be written')
    args = parser.parse_args()
    args.log.write('%s' % sum(args.integers))
    args.log.close()

The module contains the following public classes:

    - ArgumentParser -- The main entry point for command-line parsing. As the
        example above shows, the add_argument() method is used to populate
        the parser with actions for optional and positional arguments. Then
        the parse_args() method is invoked to convert the args at the
        command-line into an object with attributes.

    - ArgumentError -- The exception raised by ArgumentParser objects when
        there are errors with the parser's actions. Errors raised while
        parsing the command-line are caught by ArgumentParser and emitted
        as command-line messages.

    - FileType -- A factory for defining types of files to be created. As the
        example above shows, instances of FileType are typically passed as
        the type= argument of add_argument() calls.

    - Action -- The base class for parser actions. Typically actions are
        selected by passing strings like 'store_true' or 'append_const' to
        the action= argument of add_argument(). However, for greater
        customization of ArgumentParser actions, subclasses of Action may
        be defined and passed as the action= argument.

    - HelpFormatter, RawDescriptionHelpFormatter, RawTextHelpFormatter,
        ArgumentDefaultsHelpFormatter -- Formatter classes which
        may be passed as the formatter_class= argument to the
        ArgumentParser constructor. HelpFormatter is the default,
        RawDescriptionHelpFormatter and RawTextHelpFormatter tell the parser
        not to change the formatting for help text, and
        ArgumentDefaultsHelpFormatter adds information about argument defaults
        to the help.

All other classes in this module are considered implementation details.
(Also note that HelpFormatter and RawDescriptionHelpFormatter are only
considered public as object names -- the API of the formatter objects is
still considered an implementation detail.)
"""
__version__ = '1.1'
__all__ = [
'ArgumentParser',
'ArgumentError',
'ArgumentTypeError',
'FileType',
'HelpFormatter',
'ArgumentDefaultsHelpFormatter',
'RawDescriptionHelpFormatter',
'RawTextHelpFormatter',
'Namespace',
'Action',
'ONE_OR_MORE',
'OPTIONAL',
'PARSER',
'REMAINDER',
'SUPPRESS',
'ZERO_OR_MORE',
]
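# The usage example from the module docstring can be exercised without a
# real command line by passing an explicit argument list to parse_args().
# A minimal sketch: build_parser, sum_from_argv and the prog name
# 'sum-demo' are illustrative names, and the --log option is left out to
# keep the example free of file handling.

```python
import argparse


def build_parser():
    """Build the integer-summing parser from the docstring example."""
    parser = argparse.ArgumentParser(
        prog='sum-demo',
        description='sum the integers at the command line')
    parser.add_argument(
        'integers', metavar='int', nargs='+', type=int,
        help='an integer to be summed')
    return parser


def sum_from_argv(argv):
    """Parse argv (a list of strings) and return the sum of the integers."""
    args = build_parser().parse_args(argv)
    return sum(args.integers)
```

# For instance, sum_from_argv(['1', '2', '3']) parses three positionals
# through type=int and adds them up.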
import collections as _collections
import copy as _copy
import os as _os
import re as _re
import sys as _sys
import textwrap as _textwrap
from gettext import gettext as _
def _callable(obj):
return hasattr(obj, '__call__') or hasattr(obj, '__bases__')
SUPPRESS = '==SUPPRESS=='
OPTIONAL = '?'
ZERO_OR_MORE = '*'
ONE_OR_MORE = '+'
PARSER = 'A...'
REMAINDER = '...'
_UNRECOGNIZED_ARGS_ATTR = '_unrecognized_args'
# =============================
# Utility functions and classes
# =============================
class _AttributeHolder(object):
"""Abstract base class that provides __repr__.
The __repr__ method returns a string in the format::
ClassName(attr=name, attr=name, ...)
The attributes are determined either by a class-level attribute,
'_kwarg_names', or by inspecting the instance __dict__.
"""
def __repr__(self):
type_name = type(self).__name__
arg_strings = []
for arg in self._get_args():
arg_strings.append(repr(arg))
for name, value in self._get_kwargs():
arg_strings.append('%s=%r' % (name, value))
return '%s(%s)' % (type_name, ', '.join(arg_strings))
def _get_kwargs(self):
return sorted(self.__dict__.items())
def _get_args(self):
return []
def _ensure_value(namespace, name, value):
if getattr(namespace, name, None) is None:
setattr(namespace, name, value)
return getattr(namespace, name)
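# _ensure_value() sets an attribute only when it is missing or None, then
# returns whatever is stored. The sketch below repeats its three-line body
# under the hypothetical name ensure_value so it runs on its own; it shows
# why a second call returns the same list object, which is the behaviour
# an accumulate-into-a-list action would presumably depend on.

```python
import argparse


def ensure_value(namespace, name, value):
    # same logic as _ensure_value above, repeated for a standalone demo
    if getattr(namespace, name, None) is None:
        setattr(namespace, name, value)
    return getattr(namespace, name)


ns = argparse.Namespace()
items = ensure_value(ns, 'items', [])   # attribute missing: default stored
items.append('a')
same = ensure_value(ns, 'items', [])    # attribute present: default ignored
```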
# ===============
# Formatting Help
# ===============
class HelpFormatter(object):
"""Formatter for generating usage messages and argument help strings.
Only the name of this class is considered a public API. All the methods
provided by the class are considered an implementation detail.
"""
def __init__(self,
prog,
indent_increment=2,
max_help_position=24,
width=None):
# default setting for width
if width is None:
try:
width = int(_os.environ['COLUMNS'])
except (KeyError, ValueError):
width = 80
width -= 2
self._prog = prog
self._indent_increment = indent_increment
self._max_help_position = min(max_help_position,
max(width - 20, indent_increment * 2))
self._width = width
self._current_indent = 0
self._level = 0
self._action_max_length = 0
self._root_section = self._Section(self, None)
self._current_section = self._root_section
self._whitespace_matcher = _re.compile(r'\s+')
self._long_break_matcher = _re.compile(r'\n\n\n+')
# ===============================
# Section and indentation methods
# ===============================
def _indent(self):
self._current_indent += self._indent_increment
self._level += 1
def _dedent(self):
self._current_indent -= self._indent_increment
assert self._current_indent >= 0, 'Indent decreased below 0.'
self._level -= 1
class _Section(object):
def __init__(self, formatter, parent, heading=None):
self.formatter = formatter
self.parent = parent
self.heading = heading
self.items = []
def format_help(self):
# format the indented section
if self.parent is not None:
self.formatter._indent()
join = self.formatter._join_parts
for func, args in self.items:
func(*args)
item_help = join([func(*args) for func, args in self.items])
if self.parent is not None:
self.formatter._dedent()
# return nothing if the section was empty
if not item_help:
return ''
# add the heading if the section was non-empty
if self.heading is not SUPPRESS and self.heading is not None:
current_indent = self.formatter._current_indent
heading = '%*s%s:\n' % (current_indent, '', self.heading)
else:
heading = ''
# join the section-initial newline, the heading and the help
return join(['\n', heading, item_help, '\n'])
def _add_item(self, func, args):
self._current_section.items.append((func, args))
# ========================
# Message building methods
# ========================
def start_section(self, heading):
self._indent()
section = self._Section(self, self._current_section, heading)
self._add_item(section.format_help, [])
self._current_section = section
def end_section(self):
self._current_section = self._current_section.parent
self._dedent()
def add_text(self, text):
if text is not SUPPRESS and text is not None:
self._add_item(self._format_text, [text])
def add_usage(self, usage, actions, groups, prefix=None):
if usage is not SUPPRESS:
args = usage, actions, groups, prefix
self._add_item(self._format_usage, args)
def add_argument(self, action):
if action.help is not SUPPRESS:
# find all invocations
get_invocation = self._format_action_invocation
invocations = [get_invocation(action)]
for subaction in self._iter_indented_subactions(action):
invocations.append(get_invocation(subaction))
# update the maximum item length
invocation_length = max([len(s) for s in invocations])
action_length = invocation_length + self._current_indent
self._action_max_length = max(self._action_max_length,
action_length)
# add the item to the list
self._add_item(self._format_action, [action])
def add_arguments(self, actions):
for action in actions:
self.add_argument(action)
# =======================
# Help-formatting methods
# =======================
def format_help(self):
help = self._root_section.format_help()
if help:
help = self._long_break_matcher.sub('\n\n', help)
help = help.strip('\n') + '\n'
return help
def _join_parts(self, part_strings):
return ''.join([part
for part in part_strings
if part and part is not SUPPRESS])
def _format_usage(self, usage, actions, groups, prefix):
if prefix is None:
prefix = _('usage: ')
# if usage is specified, use that
if usage is not None:
usage = usage % dict(prog=self._prog)
# if no optionals or positionals are available, usage is just prog
elif usage is None and not actions:
usage = '%(prog)s' % dict(prog=self._prog)
# if optionals and positionals are available, calculate usage
elif usage is None:
prog = '%(prog)s' % dict(prog=self._prog)
# split optionals from positionals
optionals = []
positionals = []
for action in actions:
if action.option_strings:
optionals.append(action)
else:
positionals.append(action)
# build full usage string
format = self._format_actions_usage
action_usage = format(optionals + positionals, groups)
usage = ' '.join([s for s in [prog, action_usage] if s])
# wrap the usage parts if it's too long
text_width = self._width - self._current_indent
if len(prefix) + len(usage) > text_width:
# break usage into wrappable parts
part_regexp = (
r'\(.*?\)+(?=\s|$)|'
r'\[.*?\]+(?=\s|$)|'
r'\S+'
)
opt_usage = format(optionals, groups)
pos_usage = format(positionals, groups)
opt_parts = _re.findall(part_regexp, opt_usage)
pos_parts = _re.findall(part_regexp, pos_usage)
assert ' '.join(opt_parts) == opt_usage
assert ' '.join(pos_parts) == pos_usage
# helper for wrapping lines
def get_lines(parts, indent, prefix=None):
lines = []
line = []
if prefix is not None:
line_len = len(prefix) - 1
else:
line_len = len(indent) - 1
for part in parts:
if line_len + 1 + len(part) > text_width and line:
lines.append(indent + ' '.join(line))
line = []
line_len = len(indent) - 1
line.append(part)
line_len += len(part) + 1
if line:
lines.append(indent + ' '.join(line))
if prefix is not None:
lines[0] = lines[0][len(indent):]
return lines
# if prog is short, follow it with optionals or positionals
if len(prefix) + len(prog) <= 0.75 * text_width:
indent = ' ' * (len(prefix) + len(prog) + 1)
if opt_parts:
lines = get_lines([prog] + opt_parts, indent, prefix)
lines.extend(get_lines(pos_parts, indent))
elif pos_parts:
lines = get_lines([prog] + pos_parts, indent, prefix)
else:
lines = [prog]
# if prog is long, put it on its own line
else:
indent = ' ' * len(prefix)
parts = opt_parts + pos_parts
lines = get_lines(parts, indent)
if len(lines) > 1:
lines = []
lines.extend(get_lines(opt_parts, indent))
lines.extend(get_lines(pos_parts, indent))
lines = [prog] + lines
# join lines into usage
usage = '\n'.join(lines)
# prefix with 'usage:'
return '%s%s\n\n' % (prefix, usage)
def _format_actions_usage(self, actions, groups):
# find group indices and identify actions in groups
group_actions = set()
inserts = {}
for group in groups:
try:
start = actions.index(group._group_actions[0])
except ValueError:
continue
else:
end = start + len(group._group_actions)
if actions[start:end] == group._group_actions:
for action in group._group_actions:
group_actions.add(action)
if not group.required:
if start in inserts:
inserts[start] += ' ['
else:
inserts[start] = '['
inserts[end] = ']'
else:
if start in inserts:
inserts[start] += ' ('
else:
inserts[start] = '('
inserts[end] = ')'
for i in range(start + 1, end):
inserts[i] = '|'
# collect all actions format strings
parts = []
for i, action in enumerate(actions):
# suppressed arguments are marked with None
# remove | separators for suppressed arguments
if action.help is SUPPRESS:
parts.append(None)
if inserts.get(i) == '|':
inserts.pop(i)
elif inserts.get(i + 1) == '|':
inserts.pop(i + 1)
# produce all arg strings
elif not action.option_strings:
part = self._format_args(action, action.dest)
# if it's in a group, strip the outer []
if action in group_actions:
if part[0] == '[' and part[-1] == ']':
part = part[1:-1]
# add the action string to the list
parts.append(part)
# produce the first way to invoke the option in brackets
else:
option_string = action.option_strings[0]
# if the Optional doesn't take a value, format is:
# -s or --long
if action.nargs == 0:
part = '%s' % option_string
# if the Optional takes a value, format is:
# -s ARGS or --long ARGS
else:
default = action.dest.upper()
args_string = self._format_args(action, default)
part = '%s %s' % (option_string, args_string)
# make it look optional if it's not required or in a group
if not action.required and action not in group_actions:
part = '[%s]' % part
# add the action string to the list
parts.append(part)
# insert things at the necessary indices
for i in sorted(inserts, reverse=True):
parts[i:i] = [inserts[i]]
# join all the action items with spaces
text = ' '.join([item for item in parts if item is not None])
# clean up separators for mutually exclusive groups
open = r'[\[(]'
close = r'[\])]'
text = _re.sub(r'(%s) ' % open, r'\1', text)
text = _re.sub(r' (%s)' % close, r'\1', text)
text = _re.sub(r'%s *%s' % (open, close), r'', text)
text = _re.sub(r'\(([^|]*)\)', r'\1', text)
text = text.strip()
# return the text
return text
def _format_text(self, text):
if '%(prog)' in text:
text = text % dict(prog=self._prog)
text_width = max(self._width - self._current_indent, 11)
indent = ' ' * self._current_indent
return self._fill_text(text, text_width, indent) + '\n\n'
def _format_action(self, action):
# determine the required width and the entry label
help_position = min(self._action_max_length + 2,
self._max_help_position)
help_width = max(self._width - help_position, 11)
action_width = help_position - self._current_indent - 2
action_header = self._format_action_invocation(action)
# no help; start on same line and add a final newline
if not action.help:
tup = self._current_indent, '', action_header
action_header = '%*s%s\n' % tup
# short action name; start on the same line and pad two spaces
elif len(action_header) <= action_width:
tup = self._current_indent, '', action_width, action_header
action_header = '%*s%-*s ' % tup
indent_first = 0
# long action name; start on the next line
else:
tup = self._current_indent, '', action_header
action_header = '%*s%s\n' % tup
indent_first = help_position
# collect the pieces of the action help
parts = [action_header]
# if there was help for the action, add lines of help text
if action.help:
help_text = self._expand_help(action)
help_lines = self._split_lines(help_text, help_width)
parts.append('%*s%s\n' % (indent_first, '', help_lines[0]))
for line in help_lines[1:]:
parts.append('%*s%s\n' % (help_position, '', line))
# or add a newline if the description doesn't end with one
elif not action_header.endswith('\n'):
parts.append('\n')
# if there are any sub-actions, add their help as well
for subaction in self._iter_indented_subactions(action):
parts.append(self._format_action(subaction))
# return a single string
return self._join_parts(parts)
def _format_action_invocation(self, action):
if not action.option_strings:
metavar, = self._metavar_formatter(action, action.dest)(1)
return metavar
else:
parts = []
# if the Optional doesn't take a value, format is:
# -s, --long
if action.nargs == 0:
parts.extend(action.option_strings)
# if the Optional takes a value, format is:
# -s ARGS, --long ARGS
else:
default = action.dest.upper()
args_string = self._format_args(action, default)
for option_string in action.option_strings:
parts.append('%s %s' % (option_string, args_string))
return ', '.join(parts)
def _metavar_formatter(self, action, default_metavar):
if action.metavar is not None:
result = action.metavar
elif action.choices is not None:
choice_strs = [str(choice) for choice in action.choices]
result = '{%s}' % ','.join(choice_strs)
else:
result = default_metavar
def format(tuple_size):
if isinstance(result, tuple):
return result
else:
return (result, ) * tuple_size
return format
def _format_args(self, action, default_metavar):
get_metavar = self._metavar_formatter(action, default_metavar)
if action.nargs is None:
result = '%s' % get_metavar(1)
elif action.nargs == OPTIONAL:
result = '[%s]' % get_metavar(1)
elif action.nargs == ZERO_OR_MORE:
result = '[%s [%s ...]]' % get_metavar(2)
elif action.nargs == ONE_OR_MORE:
result = '%s [%s ...]' % get_metavar(2)
elif action.nargs == REMAINDER:
result = '...'
elif action.nargs == PARSER:
result = '%s ...' % get_metavar(1)
else:
formats = ['%s' for _ in range(action.nargs)]
result = ' '.join(formats) % get_metavar(action.nargs)
return result
def _expand_help(self, action):
params = dict(vars(action), prog=self._prog)
for name in list(params):
if params[name] is SUPPRESS:
del params[name]
for name in list(params):
if hasattr(params[name], '__name__'):
params[name] = params[name].__name__
if params.get('choices') is not None:
choices_str = ', '.join([str(c) for c in params['choices']])
params['choices'] = choices_str
return self._get_help_string(action) % params
def _iter_indented_subactions(self, action):
try:
get_subactions = action._get_subactions
except AttributeError:
pass
else:
self._indent()
for subaction in get_subactions():
yield subaction
self._dedent()
def _split_lines(self, text, width):
text = self._whitespace_matcher.sub(' ', text).strip()
return _textwrap.wrap(text, width)
def _fill_text(self, text, width, indent):
text = self._whitespace_matcher.sub(' ', text).strip()
return _textwrap.fill(text, width, initial_indent=indent,
subsequent_indent=indent)
def _get_help_string(self, action):
return action.help
class RawDescriptionHelpFormatter(HelpFormatter):
"""Help message formatter which retains any formatting in descriptions.
Only the name of this class is considered a public API. All the methods
provided by the class are considered an implementation detail.
"""
def _fill_text(self, text, width, indent):
return ''.join([indent + line for line in text.splitlines(True)])
class RawTextHelpFormatter(RawDescriptionHelpFormatter):
"""Help message formatter which retains formatting of all help text.
Only the name of this class is considered a public API. All the methods
provided by the class are considered an implementation detail.
"""
def _split_lines(self, text, width):
return text.splitlines()
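# The raw formatters above replace the re-wrapping helpers, so line breaks
# written by the author survive. A minimal sketch of the effect on a
# description, assuming the public argparse module; the helper name and
# the prog name 'demo' are illustrative.

```python
import argparse


def help_with_raw_description():
    """Return help text whose description keeps its own line breaks."""
    parser = argparse.ArgumentParser(
        prog='demo',
        description='first line\n  indented second line',
        formatter_class=argparse.RawDescriptionHelpFormatter)
    return parser.format_help()
```

# With the default HelpFormatter the same description would be collapsed
# into a single re-wrapped paragraph by _fill_text().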
class ArgumentDefaultsHelpFormatter(HelpFormatter):
"""Help message formatter which adds default values to argument help.
Only the name of this class is considered a public API. All the methods
provided by the class are considered an implementation detail.
"""
def _get_help_string(self, action):
help = action.help
if '%(default)' not in action.help:
if action.default is not SUPPRESS:
defaulting_nargs = [OPTIONAL, ZERO_OR_MORE]
if action.option_strings or action.nargs in defaulting_nargs:
help += ' (default: %(default)s)'
return help
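# A short sketch of what _get_help_string() above does in practice: when an
# option has a default and its help text does not already mention
# %(default), the formatter appends it. The helper name, the prog name
# 'demo' and the --count option are illustrative.

```python
import argparse


def help_with_defaults():
    """Return help text in which each option's default is appended."""
    parser = argparse.ArgumentParser(
        prog='demo',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--count', type=int, default=3, help='how many')
    return parser.format_help()
```

# The exact wrapping of the output depends on the terminal width, so any
# check on it is best done after collapsing whitespace.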
# =====================
# Options and Arguments
# =====================
def _get_action_name(argument):
if argument is None:
return None
elif argument.option_strings:
return '/'.join(argument.option_strings)
elif argument.metavar not in (None, SUPPRESS):
return argument.metavar
elif argument.dest not in (None, SUPPRESS):
return argument.dest
else:
return None
class ArgumentError(Exception):
"""An error from creating or using an argument (optional or positional).
The string value of this exception is the message, augmented with
information about the argument that caused it.
"""
def __init__(self, argument, message):
self.argument_name = _get_action_name(argument)
self.message = message
def __str__(self):
if self.argument_name is None:
format = '%(message)s'
else:
format = 'argument %(argument_name)s: %(message)s'
return format % dict(message=self.message,
argument_name=self.argument_name)
class ArgumentTypeError(Exception):
"""An error from trying to convert a command line string to a type."""
pass
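# ArgumentTypeError is meant to be raised from a type= callable. A minimal
# sketch, assuming the public argparse module; positive_int and
# parse_count are hypothetical names used only for illustration.

```python
import argparse


def positive_int(text):
    """Convert text to an int, rejecting zero and negative values."""
    value = int(text)
    if value <= 0:
        raise argparse.ArgumentTypeError(
            '%r is not a positive integer' % text)
    return value


def parse_count(argv):
    """Parse a single positional 'count' using positive_int as its type."""
    parser = argparse.ArgumentParser(prog='demo')
    parser.add_argument('count', type=positive_int)
    return parser.parse_args(argv).count
```

# When the converter raises, the parser turns the exception into a usage
# message on stderr and exits, so callers see SystemExit rather than
# ArgumentTypeError.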
# ==============
# Action classes
# ==============
class Action(_AttributeHolder):
"""Information about how to convert command line strings to Python objects.

Action objects are used by an ArgumentParser to represent the information
needed to parse a single argument from one or more strings from the
command line. The keyword arguments to the Action constructor are also
all attributes of Action instances.

Keyword Arguments:

    - option_strings -- A list of command-line option strings which
        should be associated with this action.

    - dest -- The name of the attribute to hold the created object(s)

    - nargs -- The number of command-line arguments that should be
        consumed. By default, one argument will be consumed and a single
        value will be produced. Other values include:
            - N (an integer) consumes N arguments (and produces a list)
            - '?' consumes zero or one arguments
            - '*' consumes zero or more arguments (and produces a list)
            - '+' consumes one or more arguments (and produces a list)
        Note that the difference between the default and nargs=1 is that
        with the default, a single value will be produced, while with
        nargs=1, a list containing a single value will be produced.

    - const -- The value to be produced if the option is specified and the
        option uses an action that takes no values.

    - default -- The value to be produced if the option is not specified.

    - type -- A callable that accepts a single string argument, and
        returns the converted value. The standard Python types str, int,
        float, and complex are useful examples of such callables. If None,
        str is used.

    - choices -- A container of values that should be allowed. If not None,
        after a command-line argument has been converted to the appropriate
        type, an exception will be raised if it is not a member of this
        collection.

    - required -- True if the action must always be specified at the
        command line. This is only meaningful for optional command-line
        arguments.

    - help -- The help string describing the argument.

    - metavar -- The name to be used for the option's argument with the
        help string. If None, the 'dest' value will be used as the name.
"""
def __init__(self,
option_strings,
dest,
nargs=None,
const=None,
default=None,
type=None,
choices=None,
required=False,
help=None,
metavar=None):
self.option_strings = option_strings
self.dest = dest
self.nargs = nargs
self.const = const
self.default = default
self.type = type
self.choices = choices
self.required = required
self.help = help
self.metavar = metavar
def _get_kwargs(self):
names = [
'option_strings',
'dest',
'nargs',
'const',
'default',
'type',
'choices',
'help',
'metavar',
]
return [(name, getattr(self, name)) for name in names]
def __call__(self, parser, namespace, values, option_string=None):
raise NotImplementedError(_('.__call__() not defined'))
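# A minimal sketch of subclassing Action and of the nargs rules described in
# the docstring above, assuming the standard argparse module. The UpperCase
# class and the option names are hypothetical. It shows a custom __call__,
# the list produced by nargs=1, and how const and default interact with
# nargs='?'.

```python
import argparse


class UpperCase(argparse.Action):
    """Hypothetical action that stores the upper-cased value."""

    def __call__(self, parser, namespace, values, option_string=None):
        # With the default nargs, values is a single string.
        setattr(namespace, self.dest, values.upper())


parser = argparse.ArgumentParser(prog="demo")
parser.add_argument("--name", action=UpperCase)
parser.add_argument("--one", nargs=1)  # produces a one-element list
parser.add_argument("--maybe", nargs="?", const="flagged", default="absent")

args = parser.parse_args(["--name", "ada", "--one", "x", "--maybe"])
empty = parser.parse_args([])

print(args.name, args.one, args.maybe)  # ADA ['x'] flagged
print(empty.one, empty.maybe)           # None absent
```
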
class _StoreAction(Action):
def __init__(self,
option_strings,
dest,
nargs=None,
const=None,
default=None,
type=None,
choices=None,
required=False,
help=None,
metavar=None):
if nargs == 0:
raise ValueError('nargs for store actions must be > 0; if you '
'have nothing to store, actions such as store '
'true or store const may be more appropriate')
if const is not None and nargs != OPTIONAL:
raise ValueError('nargs must be %r to supply const' % OPTIONAL)
super(_StoreAction, self).__init__(
option_strings=option_strings,
dest=dest,
nargs=nargs,
const=const,
default=default,
type=type,
choices=choices,
required=required,
help=help,
metavar=metavar)
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, values)
class _StoreConstAction(Action):
def __init__(self,
option_strings,
dest,
const,
default=None,
required=False,
help=None,
metavar=None):
super(_StoreConstAction, self).__init__(
option_strings=option_strings,
dest=dest,
nargs=0,
const=const,
default=default,
required=required,
help=help)
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, self.const)
class _StoreTrueAction(_StoreConstAction):
def __init__(self,
option_strings,
dest,
default=False,
required=False,
help=None):
super(_StoreTrueAction, self).__init__(
option_strings=option_strings,
dest=dest,
const=True,
default=default,
required=required,
help=help)
class _StoreFalseAction(_StoreConstAction):
def __init__(self,
option_strings,
dest,
default=True,
required=False,
help=None):
super(_StoreFalseAction, self).__init__(
option_strings=option_strings,
dest=dest,
const=False,
default=default,
required=required,
help=help)
class _AppendAction(Action):
def __init__(self,
option_strings,
dest,
nargs=None,
const=None,
default=None,
type=None,
choices=None,
required=False,
help=None,
metavar=None):
if nargs == 0:
raise ValueError('nargs for append actions must be > 0; if arg '
'strings are not supplying the value to append, '
'the append const action may be more appropriate')
if const is not None and nargs != OPTIONAL:
raise ValueError('nargs must be %r to supply const' % OPTIONAL)
super(_AppendAction, self).__init__(
option_strings=option_strings,
dest=dest,
nargs=nargs,
const=const,
default=default,
type=type,
choices=choices,
required=required,
help=help,
metavar=metavar)
def __call__(self, parser, namespace, values, option_string=None):
items = _copy.copy(_ensure_value(namespace, self.dest, []))
items.append(values)
setattr(namespace, self.dest, items)
class _AppendConstAction(Action):
def __init__(self,
option_strings,
dest,
const,
default=None,
required=False,
help=None,
metavar=None):
super(_AppendConstAction, self).__init__(
option_strings=option_strings,
dest=dest,
nargs=0,
const=const,
default=default,
required=required,
help=help,
metavar=metavar)
def __call__(self, parser, namespace, values, option_string=None):
items = _copy.copy(_ensure_value(namespace, self.dest, []))
items.append(self.const)
setattr(namespace, self.dest, items)
class _CountAction(Action):
def __init__(self,
option_strings,
dest,
default=None,
required=False,
help=None):
super(_CountAction, self).__init__(
option_strings=option_strings,
dest=dest,
nargs=0,
default=default,
required=required,
help=help)
def __call__(self, parser, namespace, values, option_string=None):
new_count = _ensure_value(namespace, self.dest, 0) + 1
setattr(namespace, self.dest, new_count)
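# A brief sketch of the built-in actions defined above, selected through
# their registered names ('count', 'append', 'store_true', 'store_const'),
# assuming the standard argparse module. The option names are illustrative
# only.

```python
import argparse

parser = argparse.ArgumentParser(prog="demo")
parser.add_argument("-v", "--verbose", action="count", default=0)
parser.add_argument("-I", "--include", action="append")
parser.add_argument("--quiet", action="store_true")
parser.add_argument("--mode", action="store_const", const="fast",
                    default="safe")

args = parser.parse_args(["-vv", "-I", "a", "-I", "b", "--mode"])
defaults = parser.parse_args([])

print(args.verbose, args.include, args.quiet, args.mode)
# 2 ['a', 'b'] False fast
```

# None of these actions except 'append' consumes a value from the command
# line, which is why their constructors pass nargs=0 to Action.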
class _HelpAction(Action):
def __init__(self,
option_strings,
dest=SUPPRESS,
default=SUPPRESS,
help=None):
super(_HelpAction, self).__init__(
option_strings=option_strings,
dest=dest,
default=default,
nargs=0,
help=help)
def __call__(self, parser, namespace, values, option_string=None):
parser.print_help()
parser.exit()
class _VersionAction(Action):
def __init__(self,
option_strings,
version=None,
dest=SUPPRESS,
default=SUPPRESS,
help="show program's version number and exit"):
super(_VersionAction, self).__init__(
option_strings=option_strings,
dest=dest,
default=default,
nargs=0,
help=help)
self.version = version
def __call__(self, parser, namespace, values, option_string=None):
version = self.version
if version is None:
version = parser.version
formatter = parser._get_formatter()
formatter.add_text(version)
parser.exit(message=formatter.format_help())
class _SubParsersAction(Action):

    class _ChoicesPseudoAction(Action):

        def __init__(self, name, help):
            sup = super(_SubParsersAction._ChoicesPseudoAction, self)
            sup.__init__(option_strings=[], dest=name, help=help)

    def __init__(self,
                 option_strings,
                 prog,
                 parser_class,
                 dest=SUPPRESS,
                 help=None,
                 metavar=None):

        self._prog_prefix = prog
        self._parser_class = parser_class
        self._name_parser_map = _collections.OrderedDict()
        self._choices_actions = []

        super(_SubParsersAction, self).__init__(
            option_strings=option_strings,
            dest=dest,
            nargs=PARSER,
            choices=self._name_parser_map,
            help=help,
            metavar=metavar)

    def add_parser(self, name, **kwargs):
        # set prog from the existing prefix
        if kwargs.get('prog') is None:
            kwargs['prog'] = '%s %s' % (self._prog_prefix, name)

        # create a pseudo-action to hold the choice help
        if 'help' in kwargs:
            help = kwargs.pop('help')
            choice_action = self._ChoicesPseudoAction(name, help)
            self._choices_actions.append(choice_action)

        # create the parser and add it to the map
        parser = self._parser_class(**kwargs)
        self._name_parser_map[name] = parser
        return parser

    def _get_subactions(self):
        return self._choices_actions

    def __call__(self, parser, namespace, values, option_string=None):
        parser_name = values[0]
        arg_strings = values[1:]

        # set the parser name if requested
        if self.dest is not SUPPRESS:
            setattr(namespace, self.dest, parser_name)

        # select the parser
        try:
            parser = self._name_parser_map[parser_name]
        except KeyError:
            tup = parser_name, ', '.join(self._name_parser_map)
            msg = _('unknown parser %r (choices: %s)') % tup
            raise ArgumentError(self, msg)

        # parse all the remaining options into the namespace
        # store any unrecognized options on the object, so that the top
        # level parser can decide what to do with them

        # In case this subparser defines new defaults, we parse them
        # in a new namespace object and then update the original
        # namespace for the relevant parts.
        subnamespace, arg_strings = parser.parse_known_args(arg_strings, None)
        for key, value in vars(subnamespace).items():
            setattr(namespace, key, value)

        if arg_strings:
            vars(namespace).setdefault(_UNRECOGNIZED_ARGS_ATTR, [])
            getattr(namespace, _UNRECOGNIZED_ARGS_ATTR).extend(arg_strings)

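# A minimal sketch of the sub-command mechanism implemented above, assuming
# the standard argparse module. The "tool" program and its "add" and
# "remove" commands are hypothetical. It shows the chosen parser name being
# stored under dest, the sub-parser's attributes being copied into the
# top-level namespace, and unrecognized strings being handed back by
# parse_known_args().

```python
import argparse

parser = argparse.ArgumentParser(prog="tool")
parser.add_argument("--verbose", action="store_true")
subparsers = parser.add_subparsers(dest="command")

add = subparsers.add_parser("add", help="add an item")
add.add_argument("name")
add.set_defaults(handler="do_add")  # a default defined only by this sub-parser

remove = subparsers.add_parser("remove", help="remove an item")
remove.add_argument("name")
remove.add_argument("--force", action="store_true")

args = parser.parse_args(["--verbose", "remove", "x", "--force"])

# Strings the sub-parser does not recognize are returned to the caller.
known, extras = parser.parse_known_args(["add", "x", "--extra"])
```
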
# ==============
# Type classes
# ==============
class FileType(object):
    """Factory for creating file object types

    Instances of FileType are typically passed as type= arguments to the
    ArgumentParser add_argument() method.

    Keyword Arguments:
        - mode -- A string indicating how the file is to be opened. Accepts the
            same values as the builtin open() function.
        - bufsize -- The file's desired buffer size. Accepts the same values as
            the builtin open() function.
    """

    def __init__(self, mode='r', bufsize=-1):
        self._mode = mode
        self._bufsize = bufsize

    def __call__(self, string):
        # the special argument "-" means sys.std{in,out}
        if string == '-':
            if 'r' in self._mode:
                return _sys.stdin
            elif 'w' in self._mode:
                return _sys.stdout
            else:
                msg = _('argument "-" with mode %r') % self._mode
                raise ValueError(msg)

        # all other arguments are used as file names
        try:
            return open(string, self._mode, self._bufsize)
        except IOError as e:
            message = _("can't open '%s': %s")
            raise ArgumentTypeError(message % (string, e))

    def __repr__(self):
        args = self._mode, self._bufsize
        args_str = ', '.join(repr(arg) for arg in args if arg != -1)
        return '%s(%s)' % (type(self).__name__, args_str)

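# A small sketch of calling FileType instances directly, assuming the
# standard argparse module. It shows the special "-" argument mapping to the
# standard streams, a normal file being opened, and a missing file being
# reported as ArgumentTypeError. The temporary file is created only for the
# demonstration. Assumption: some newer CPython releases warn that FileType
# is deprecated, so warnings are silenced while the instances are built.

```python
import argparse
import os
import sys
import tempfile
import warnings

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    reader = argparse.FileType("r")
    writer = argparse.FileType("w")

stdin_obj = reader("-")    # same object as sys.stdin
stdout_obj = writer("-")   # same object as sys.stdout

# Create a throwaway file, read it back through the FileType instance.
fd, path = tempfile.mkstemp(suffix=".txt")
with os.fdopen(fd, "w") as handle:
    handle.write("hello\n")
with reader(path) as opened:
    content = opened.read()
os.remove(path)

# The file no longer exists, so opening it is reported as a type error.
error_text = None
try:
    reader(path)
except argparse.ArgumentTypeError as exc:
    error_text = str(exc)
```
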
# ===========================
# Optional and Positional Parsing
# ===========================
class Namespace(_AttributeHolder):
    """Simple object for storing attributes.

    Implements equality by attribute names and values, and provides a simple
    string representation.
    """

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])

    __hash__ = None

    def __eq__(self, other):
        if not isinstance(other, Namespace):
            return NotImplemented
        return vars(self) == vars(other)

    def __ne__(self, other):
        if not isinstance(other, Namespace):
            return NotImplemented
        return not (self == other)

    def __contains__(self, key):
        return key in self.__dict__

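# A short sketch of Namespace behaviour, assuming the standard argparse
# module. The attribute and option names are illustrative only. It shows
# equality by attribute values, the "in" test, and that a value already
# present on a namespace passed to parse_args() is not replaced by the
# action's default.

```python
import argparse

first = argparse.Namespace(x=1, y="two")
second = argparse.Namespace(y="two", x=1)

same = first == second   # equality compares attribute names and values
has_x = "x" in first     # __contains__ checks attribute names
has_z = "z" in first

parser = argparse.ArgumentParser(prog="demo")
parser.add_argument("--level", type=int, default=1)

# Defaults are only filled in for attributes the namespace does not have yet.
preset = argparse.Namespace(level=5)
kept = parser.parse_args([], namespace=preset)
overridden = parser.parse_args(["--level", "9"],
                               namespace=argparse.Namespace(level=5))
```
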
class _ActionsContainer(object):
def __init__(self,
description,
prefix_chars,
argument_default,
conflict_handler):
super(_ActionsContainer, self).__init__()
self.description = description
self.argument_default = argument_default
self.prefix_chars = prefix_chars
self.conflict_handler = conflict_handler
# set up registries
self._registries = {}
# register actions
self.register('action', None, _StoreAction)
self.register('action', 'store', _StoreAction)
self.register('action', 'store_const', _StoreConstAction)
self.register('action', 'store_true', _StoreTrueAction)
self.register('action', 'store_false', _StoreFalseAction)
self.register('action', 'append', _AppendAction)
self.register('action', 'append_const', _AppendConstAction)
self.register('action', 'count', _CountAction)
self.register('action', 'help', _HelpAction)
self.register('action', 'version', _VersionAction)
self.register('action', 'parsers', _SubParsersAction)
# raise an exception if the conflict handler is invalid
self._get_handler()
# action storage
self._actions = []
self._option_string_actions = {}
# groups
self._action_groups = []
self._mutually_exclusive_groups = []
# defaults storage
self._defaults = {}
# determines whether an "option" looks like a negative number
self._negative_number_matcher = _re.compile(r'^-\d+$|^-\d*\.\d+$')
# whether or not there are any optionals that look like negative
# numbers -- uses a list so it can be shared and edited
self._has_negative_number_optionals = []
# ====================
# Registration methods
# ====================
def register(self, registry_name, value, object):
registry = self._registries.setdefault(registry_name, {})
registry[value] = object
def _registry_get(self, registry_name, value, default=None):
return self._registries[registry_name].get(value, default)
# ==================================
# Namespace default accessor methods
# ==================================
def set_defaults(self, **kwargs):
self._defaults.update(kwargs)
# if these defaults match any existing arguments, replace
# the previous default on the object with the new one
for action in self._actions:
if action.dest in kwargs:
action.default = kwargs[action.dest]
def get_default(self, dest):
for action in self._actions:
if action.dest == dest and action.default is not None:
return action.default
return self._defaults.get(dest, None)
# =======================
# Adding argument actions
# =======================
def add_argument(self, *args, **kwargs):
"""
add_argument(dest, ..., name=value, ...)
add_argument(option_string, option_string, ..., name=value, ...)
"""
# if no positional args are supplied or only one is supplied and
# it doesn't look like an option string, parse a positional
# argument
chars = self.prefix_chars
if not args or len(args) == 1 and args[0][0] not in chars:
if args and 'dest' in kwargs:
raise ValueError('dest supplied twice for positional argument')
kwargs = self._get_positional_kwargs(*args, **kwargs)
# otherwise, we're adding an optional argument
else:
kwargs = self._get_optional_kwargs(*args, **kwargs)
# if no default was supplied, use the parser-level default
if 'default' not in kwargs:
dest = kwargs['dest']
if dest in self._defaults:
kwargs['default'] = self._defaults[dest]
elif self.argument_default is not None:
kwargs['default'] = self.argument_default
# create the action object, and add it to the parser
action_class = self._pop_action_class(kwargs)
if not _callable(action_class):
raise ValueError('unknown action "%s"' % (action_class,))
action = action_class(**kwargs)
# raise an error if the action type is not callable
type_func = self._registry_get('type', action.type, action.type)
if not _callable(type_func):
raise ValueError('%r is not callable' % (type_func,))
# raise an error if the metavar does not match the type
if hasattr(self, "_get_formatter"):
try:
self._get_formatter()._format_args(action, None)
except TypeError:
raise ValueError("length of metavar tuple does not match nargs")
return self._add_action(action)
def add_argument_group(self, *args, **kwargs):
group = _ArgumentGroup(self, *args, **kwargs)
self._action_groups.append(group)
return group
def add_mutually_exclusive_group(self, **kwargs):
group = _MutuallyExclusiveGroup(self, **kwargs)
self._mutually_exclusive_groups.append(group)
return group
def _add_action(self, action):
# resolve any conflicts
self._check_conflict(action)
# add to actions list
self._actions.append(action)
action.container = self
# index the action by any option strings it has
for option_string in action.option_strings:
self._option_string_actions[option_string] = action
# set the flag if any option strings look like negative numbers
for option_string in action.option_strings:
if self._negative_number_matcher.match(option_string):
if not self._has_negative_number_optionals:
self._has_negative_number_optionals.append(True)
# return the created action
return action
def _remove_action(self, action):
self._actions.remove(action)
def _add_container_actions(self, container):
# collect groups by titles
title_group_map = {}
for group in self._action_groups:
if group.title in title_group_map:
msg = _('cannot merge actions - two groups are named %r')
raise ValueError(msg % (group.title))
title_group_map[group.title] = group
# map each action to its group
group_map = {}
for group in container._action_groups:
# if a group with the title exists, use that, otherwise
# create a new group matching the container's group
if group.title not in title_group_map:
title_group_map[group.title] = self.add_argument_group(
title=group.title,
description=group.description,
conflict_handler=group.conflict_handler)
# map the actions to their new group
for action in group._group_actions:
group_map[action] = title_group_map[group.title]
# add container's mutually exclusive groups
# NOTE: if add_mutually_exclusive_group ever gains title= and
# description= then this code will need to be expanded as above
for group in container._mutually_exclusive_groups:
mutex_group = self.add_mutually_exclusive_group(
required=group.required)
# map the actions to their new mutex group
for action in group._group_actions:
group_map[action] = mutex_group
# add all actions to this container or their group
for action in container._actions:
group_map.get(action, self)._add_action(action)
def _get_positional_kwargs(self, dest, **kwargs):
# make sure required is not specified
if 'required' in kwargs:
msg = _("'required' is an invalid argument for positionals")
raise TypeError(msg)
# mark positional arguments as required if at least one is
# always required
if kwargs.get('nargs') not in [OPTIONAL, ZERO_OR_MORE]:
kwargs['required'] = True
if kwargs.get('nargs') == ZERO_OR_MORE and 'default' not in kwargs:
kwargs['required'] = True
# return the keyword arguments with no option strings
return dict(kwargs, dest=dest, option_strings=[])
def _get_optional_kwargs(self, *args, **kwargs):
# determine short and long option strings
option_strings = []
long_option_strings = []
for option_string in args:
# error on strings that don't start with an appropriate prefix
if not option_string[0] in self.prefix_chars:
msg = _('invalid option string %r: '
'must start with a character %r')
tup = option_string, self.prefix_chars
raise ValueError(msg % tup)
# strings starting with two prefix characters are long options
option_strings.append(option_string)
if option_string[0] in self.prefix_chars:
if len(option_string) > 1:
if option_string[1] in self.prefix_chars:
long_option_strings.append(option_string)
# infer destination, '--foo-bar' -> 'foo_bar' and '-x' -> 'x'
dest = kwargs.pop('dest', None)
if dest is None:
if long_option_strings:
dest_option_string = long_option_strings[0]
else:
dest_option_string = option_strings[0]
dest = dest_option_string.lstrip(self.prefix_chars)
if not dest:
msg = _('dest= is required for options like %r')
raise ValueError(msg % option_string)
dest = dest.replace('-', '_')
# return the updated keyword arguments
return dict(kwargs, dest=dest, option_strings=option_strings)
def _pop_action_class(self, kwargs, default=None):
action = kwargs.pop('action', default)
return self._registry_get('action', action, action)
def _get_handler(self):
# determine function from conflict handler string
handler_func_name = '_handle_conflict_%s' % self.conflict_handler
try:
return getattr(self, handler_func_name)
except AttributeError:
msg = _('invalid conflict_resolution value: %r')
raise ValueError(msg % self.conflict_handler)
def _check_conflict(self, action):
# find all options that conflict with this option
confl_optionals = []
for option_string in action.option_strings:
if option_string in self._option_string_actions:
confl_optional = self._option_string_actions[option_string]
confl_optionals.append((option_string, confl_optional))
# resolve any conflicts
if confl_optionals:
conflict_handler = self._get_handler()
conflict_handler(action, confl_optionals)
def _handle_conflict_error(self, action, conflicting_actions):
message = _('conflicting option string(s): %s')
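# use a distinct loop name: in Python 2 a list comprehension leaks its
# loop variable, which would otherwise rebind the 'action' parameter
conflict_string = ', '.join([option_string
for option_string, conflicting_action
in conflicting_actions])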
raise ArgumentError(action, message % conflict_string)
def _handle_conflict_resolve(self, action, conflicting_actions):
# remove all conflicting options
for option_string, action in conflicting_actions:
# remove the conflicting option
action.option_strings.remove(option_string)
self._option_string_actions.pop(option_string, None)
# if the option now has no option string, remove it from the
# container holding it
if not action.option_strings:
action.container._remove_action(action)
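# A minimal sketch of the two conflict handlers above, assuming the standard
# argparse module. The --level option is a hypothetical name. With the
# default 'error' handler a repeated option string raises ArgumentError at
# add_argument() time. With 'resolve' the earlier definition loses the
# option string and the later one wins.

```python
import argparse

strict = argparse.ArgumentParser(prog="demo")
strict.add_argument("--level", default="a")

conflict_message = None
try:
    strict.add_argument("--level", default="b")
except argparse.ArgumentError as exc:
    conflict_message = str(exc)

lenient = argparse.ArgumentParser(prog="demo", conflict_handler="resolve")
lenient.add_argument("--level", default="a")
lenient.add_argument("--level", default="b")  # replaces the first definition

resolved_default = lenient.parse_args([]).level
resolved_value = lenient.parse_args(["--level", "c"]).level
```
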
class _ArgumentGroup(_ActionsContainer):
def __init__(self, container, title=None, description=None, **kwargs):
# add any missing keyword arguments by checking the container
update = kwargs.setdefault
update('conflict_handler', container.conflict_handler)
update('prefix_chars', container.prefix_chars)
update('argument_default', container.argument_default)
super_init = super(_ArgumentGroup, self).__init__
super_init(description=description, **kwargs)
# group attributes
self.title = title
self._group_actions = []
# share most attributes with the container
self._registries = container._registries
self._actions = container._actions
self._option_string_actions = container._option_string_actions
self._defaults = container._defaults
self._has_negative_number_optionals = \
container._has_negative_number_optionals
self._mutually_exclusive_groups = container._mutually_exclusive_groups
def _add_action(self, action):
action = super(_ArgumentGroup, self)._add_action(action)
self._group_actions.append(action)
return action
def _remove_action(self, action):
super(_ArgumentGroup, self)._remove_action(action)
self._group_actions.remove(action)
class _MutuallyExclusiveGroup(_ArgumentGroup):

    def __init__(self, container, required=False):
        super(_MutuallyExclusiveGroup, self).__init__(container)
        self.required = required
        self._container = container

    def _add_action(self, action):
        if action.required:
            msg = _('mutually exclusive arguments must be optional')
            raise ValueError(msg)
        action = self._container._add_action(action)
        self._group_actions.append(action)
        return action

    def _remove_action(self, action):
        self._container._remove_action(action)
        self._group_actions.remove(action)

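# A brief sketch of a mutually exclusive group, assuming the standard
# argparse module. The --json, --xml and --yaml options and the
# exit_code_for helper are hypothetical. It shows that one member alone
# parses normally, that two members together end in a usage error, and that
# a required argument is rejected when it is added to the group.

```python
import argparse

parser = argparse.ArgumentParser(prog="demo")
group = parser.add_mutually_exclusive_group()
group.add_argument("--json", action="store_true")
group.add_argument("--xml", action="store_true")

args = parser.parse_args(["--json"])


def exit_code_for(argv):
    # parser.error() prints a usage message and raises SystemExit.
    try:
        parser.parse_args(argv)
    except SystemExit as exc:
        return exc.code
    return 0


both = exit_code_for(["--json", "--xml"])

required_error = None
try:
    group.add_argument("--yaml", required=True)
except ValueError as exc:
    required_error = str(exc)
```
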
class ArgumentParser(_AttributeHolder, _ActionsContainer):
"""Object for parsing command line strings into Python objects.
Keyword Arguments:
- prog -- The name of the program (default: sys.argv[0])
- usage -- A usage message (default: auto-generated from arguments)
- description -- A description of what the program does
- epilog -- Text following the argument descriptions
- parents -- Parsers whose arguments should be copied into this one
- formatter_class -- HelpFormatter class for printing help messages
- prefix_chars -- Characters that prefix optional arguments
- fromfile_prefix_chars -- Characters that prefix files containing
additional arguments
- argument_default -- The default value for all arguments
- conflict_handler -- String indicating how to handle conflicts
- add_help -- Add a -h/--help option
"""
def __init__(self,
prog=None,
usage=None,
description=None,
epilog=None,
version=None,
parents=[],
formatter_class=HelpFormatter,
prefix_chars='-',
fromfile_prefix_chars=None,
argument_default=None,
conflict_handler='error',
add_help=True):
if version is not None:
import warnings
warnings.warn(
"""The "version" argument to ArgumentParser is deprecated. """
"""Please use """
""""add_argument(..., action='version', version="N", ...)" """
"""instead""", DeprecationWarning)
superinit = super(ArgumentParser, self).__init__
superinit(description=description,
prefix_chars=prefix_chars,
argument_default=argument_default,
conflict_handler=conflict_handler)
# default setting for prog
if prog is None:
prog = _os.path.basename(_sys.argv[0])
self.prog = prog
self.usage = usage
self.epilog = epilog
self.version = version
self.formatter_class = formatter_class
self.fromfile_prefix_chars = fromfile_prefix_chars
self.add_help = add_help
add_group = self.add_argument_group
self._positionals = add_group(_('positional arguments'))
self._optionals = add_group(_('optional arguments'))
self._subparsers = None
# register types
def identity(string):
return string
self.register('type', None, identity)
# add help and version arguments if necessary
# (using explicit default to override global argument_default)
default_prefix = '-' if '-' in prefix_chars else prefix_chars[0]
if self.add_help:
self.add_argument(
default_prefix+'h', default_prefix*2+'help',
action='help', default=SUPPRESS,
help=_('show this help message and exit'))
if self.version:
self.add_argument(
default_prefix+'v', default_prefix*2+'version',
action='version', default=SUPPRESS,
version=self.version,
help=_("show program's version number and exit"))
# add parent arguments and defaults
for parent in parents:
self._add_container_actions(parent)
try:
defaults = parent._defaults
except AttributeError:
pass
else:
self._defaults.update(defaults)
# =======================
# Pretty __repr__ methods
# =======================
def _get_kwargs(self):
names = [
'prog',
'usage',
'description',
'version',
'formatter_class',
'conflict_handler',
'add_help',
]
return [(name, getattr(self, name)) for name in names]
# ==================================
# Optional/Positional adding methods
# ==================================
def add_subparsers(self, **kwargs):
if self._subparsers is not None:
self.error(_('cannot have multiple subparser arguments'))
# add the parser class to the arguments if it's not present
kwargs.setdefault('parser_class', type(self))
if 'title' in kwargs or 'description' in kwargs:
title = _(kwargs.pop('title', 'subcommands'))
description = _(kwargs.pop('description', None))
self._subparsers = self.add_argument_group(title, description)
else:
self._subparsers = self._positionals
# prog defaults to the usage message of this parser, skipping
# optional arguments and with no "usage:" prefix
if kwargs.get('prog') is None:
formatter = self._get_formatter()
positionals = self._get_positional_actions()
groups = self._mutually_exclusive_groups
formatter.add_usage(self.usage, positionals, groups, '')
kwargs['prog'] = formatter.format_help().strip()
# create the parsers action and add it to the positionals list
parsers_class = self._pop_action_class(kwargs, 'parsers')
action = parsers_class(option_strings=[], **kwargs)
self._subparsers._add_action(action)
# return the created parsers action
return action
def _add_action(self, action):
if action.option_strings:
self._optionals._add_action(action)
else:
self._positionals._add_action(action)
return action
def _get_optional_actions(self):
return [action
for action in self._actions
if action.option_strings]
def _get_positional_actions(self):
return [action
for action in self._actions
if not action.option_strings]
# =====================================
# Command line argument parsing methods
# =====================================
def parse_args(self, args=None, namespace=None):
args, argv = self.parse_known_args(args, namespace)
if argv:
msg = _('unrecognized arguments: %s')
self.error(msg % ' '.join(argv))
return args
def parse_known_args(self, args=None, namespace=None):
if args is None:
# args default to the system args
args = _sys.argv[1:]
else:
# make sure that args are mutable
args = list(args)
# default Namespace built from parser defaults
if namespace is None:
namespace = Namespace()
# add any action defaults that aren't present
for action in self._actions:
if action.dest is not SUPPRESS:
if not hasattr(namespace, action.dest):
if action.default is not SUPPRESS:
setattr(namespace, action.dest, action.default)
# add any parser defaults that aren't present
for dest in self._defaults:
if not hasattr(namespace, dest):
setattr(namespace, dest, self._defaults[dest])
# parse the arguments and exit if there are any errors
try:
namespace, args = self._parse_known_args(args, namespace)
if hasattr(namespace, _UNRECOGNIZED_ARGS_ATTR):
args.extend(getattr(namespace, _UNRECOGNIZED_ARGS_ATTR))
delattr(namespace, _UNRECOGNIZED_ARGS_ATTR)
return namespace, args
except ArgumentError:
err = _sys.exc_info()[1]
self.error(str(err))
def _parse_known_args(self, arg_strings, namespace):
# replace arg strings that are file references
if self.fromfile_prefix_chars is not None:
arg_strings = self._read_args_from_files(arg_strings)
# map all mutually exclusive arguments to the other arguments
# they can't occur with
action_conflicts = {}
for mutex_group in self._mutually_exclusive_groups:
group_actions = mutex_group._group_actions
for i, mutex_action in enumerate(mutex_group._group_actions):
conflicts = action_conflicts.setdefault(mutex_action, [])
conflicts.extend(group_actions[:i])
conflicts.extend(group_actions[i + 1:])
# find all option indices, and determine the arg_string_pattern
# which has an 'O' if there is an option at an index,
# an 'A' if there is an argument, or a '-' if there is a '--'
option_string_indices = {}
arg_string_pattern_parts = []
arg_strings_iter = iter(arg_strings)
for i, arg_string in enumerate(arg_strings_iter):
# all args after -- are non-options
if arg_string == '--':
arg_string_pattern_parts.append('-')
for arg_string in arg_strings_iter:
arg_string_pattern_parts.append('A')
# otherwise, add the arg to the arg strings
# and note the index if it was an option
else:
option_tuple = self._parse_optional(arg_string)
if option_tuple is None:
pattern = 'A'
else:
option_string_indices[i] = option_tuple
pattern = 'O'
arg_string_pattern_parts.append(pattern)
# join the pieces together to form the pattern
arg_strings_pattern = ''.join(arg_string_pattern_parts)
# converts arg strings to the appropriate values and then takes the action
seen_actions = set()
seen_non_default_actions = set()
def take_action(action, argument_strings, option_string=None):
seen_actions.add(action)
argument_values = self._get_values(action, argument_strings)
# error if this argument is not allowed with other previously
# seen arguments, assuming that actions that use the default
# value don't really count as "present"
if argument_values is not action.default:
seen_non_default_actions.add(action)
for conflict_action in action_conflicts.get(action, []):
if conflict_action in seen_non_default_actions:
msg = _('not allowed with argument %s')
action_name = _get_action_name(conflict_action)
raise ArgumentError(action, msg % action_name)
# take the action if we didn't receive a SUPPRESS value
# (e.g. from a default)
if argument_values is not SUPPRESS:
action(self, namespace, argument_values, option_string)
# function to convert arg_strings into an optional action
def consume_optional(start_index):
# get the optional identified at this index
option_tuple = option_string_indices[start_index]
action, option_string, explicit_arg = option_tuple
# identify additional optionals in the same arg string
# (e.g. -xyz is the same as -x -y -z if no args are required)
match_argument = self._match_argument
action_tuples = []
while True:
# if we found no optional action, skip it
if action is None:
extras.append(arg_strings[start_index])
return start_index + 1
# if there is an explicit argument, try to match the
# optional's string arguments to only this
if explicit_arg is not None:
arg_count = match_argument(action, 'A')
# if the action is a single-dash option and takes no
# arguments, try to parse more single-dash options out
# of the tail of the option string
chars = self.prefix_chars
if arg_count == 0 and option_string[1] not in chars:
action_tuples.append((action, [], option_string))
char = option_string[0]
option_string = char + explicit_arg[0]
new_explicit_arg = explicit_arg[1:] or None
optionals_map = self._option_string_actions
if option_string in optionals_map:
action = optionals_map[option_string]
explicit_arg = new_explicit_arg
else:
msg = _('ignored explicit argument %r')
raise ArgumentError(action, msg % explicit_arg)
# if the action expects exactly one argument, we've
# successfully matched the option; exit the loop
elif arg_count == 1:
stop = start_index + 1
args = [explicit_arg]
action_tuples.append((action, args, option_string))
break
# error if a double-dash option did not use the
# explicit argument
else:
msg = _('ignored explicit argument %r')
raise ArgumentError(action, msg % explicit_arg)
# if there is no explicit argument, try to match the
# optional's string arguments with the following strings
# if successful, exit the loop
else:
start = start_index + 1
selected_patterns = arg_strings_pattern[start:]
arg_count = match_argument(action, selected_patterns)
stop = start + arg_count
args = arg_strings[start:stop]
action_tuples.append((action, args, option_string))
break
# add the Optional to the list and return the index at which
# the Optional's string args stopped
assert action_tuples
for action, args, option_string in action_tuples:
take_action(action, args, option_string)
return stop
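The clustering rule handled by consume_optional() above (-xyz behaves like -x -y -z when none of the flags takes an argument) can be observed from the public API. The following is a minimal sketch, not part of the module; the parser name "demo" and the three flags are made up for illustration.

```python
import argparse

# Hypothetical parser with three flags that take no arguments.
parser = argparse.ArgumentParser(prog="demo")
for flag in ("-x", "-y", "-z"):
    parser.add_argument(flag, action="store_true")

# One clustered argument string versus three separate ones.
clustered = parser.parse_args(["-xyz"])
separate = parser.parse_args(["-x", "-y", "-z"])

print(vars(clustered) == vars(separate))
```

Both calls should yield the same namespace, because the tail of "-xyz" is re-parsed as further single-dash options.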
        # the list of Positionals left to be parsed; this is modified
        # by consume_positionals()
        positionals = self._get_positional_actions()

        # function to convert arg_strings into positional actions
        def consume_positionals(start_index):
            # match as many Positionals as possible
            match_partial = self._match_arguments_partial
            selected_pattern = arg_strings_pattern[start_index:]
            arg_counts = match_partial(positionals, selected_pattern)

            # slice off the appropriate arg strings for each Positional
            # and add the Positional and its args to the list
            for action, arg_count in zip(positionals, arg_counts):
                args = arg_strings[start_index: start_index + arg_count]
                start_index += arg_count
                take_action(action, args)

            # slice off the Positionals that we just parsed and return the
            # index at which the Positionals' string args stopped
            positionals[:] = positionals[len(arg_counts):]
            return start_index

        # consume Positionals and Optionals alternately, until we have
        # passed the last option string
        extras = []
        start_index = 0
        if option_string_indices:
            max_option_string_index = max(option_string_indices)
        else:
            max_option_string_index = -1
        while start_index <= max_option_string_index:

            # consume any Positionals preceding the next option
            next_option_string_index = min([
                index
                for index in option_string_indices
                if index >= start_index])
            if start_index != next_option_string_index:
                positionals_end_index = consume_positionals(start_index)

                # only try to parse the next optional if we didn't consume
                # the option string during the positionals parsing
                if positionals_end_index > start_index:
                    start_index = positionals_end_index
                    continue
                else:
                    start_index = positionals_end_index

            # if we consumed all the positionals we could and we're not
            # at the index of an option string, there were extra arguments
            if start_index not in option_string_indices:
                strings = arg_strings[start_index:next_option_string_index]
                extras.extend(strings)
                start_index = next_option_string_index

            # consume the next optional and any arguments for it
            start_index = consume_optional(start_index)

        # consume any positionals following the last Optional
        stop_index = consume_positionals(start_index)

        # if we didn't consume all the argument strings, there were extras
        extras.extend(arg_strings[stop_index:])

        # if we didn't use all the Positional objects, there were too few
        # arg strings supplied.
        if positionals:
            self.error(_('too few arguments'))

        # make sure all required actions were present, and convert defaults.
        for action in self._actions:
            if action not in seen_actions:
                if action.required:
                    name = _get_action_name(action)
                    self.error(_('argument %s is required') % name)
                else:
                    # Convert action default now instead of doing it before
                    # parsing arguments to avoid calling convert functions
                    # twice (which may fail) if the argument was given, but
                    # only if it was defined already in the namespace
                    if (action.default is not None and
                            isinstance(action.default, basestring) and
                            hasattr(namespace, action.dest) and
                            action.default is getattr(namespace, action.dest)):
                        setattr(namespace, action.dest,
                                self._get_value(action, action.default))

        # make sure all required groups had one option present
        for group in self._mutually_exclusive_groups:
            if group.required:
                for action in group._group_actions:
                    if action in seen_non_default_actions:
                        break

                # if no actions were used, report the error
                else:
                    names = [_get_action_name(action)
                             for action in group._group_actions
                             if action.help is not SUPPRESS]
                    msg = _('one of the arguments %s is required')
                    self.error(msg % ' '.join(names))

        # return the updated namespace and the extra arguments
        return namespace, extras
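The (namespace, extras) pair returned above is what the public parse_known_args() method hands back. A small sketch of how unrecognized strings end up in extras follows; the option names and argument values are invented for the example.

```python
import argparse

# Hypothetical parser: one typed option and one positional.
parser = argparse.ArgumentParser(prog="demo")
parser.add_argument("--level", type=int, default=0)
parser.add_argument("name")

# "--unknown" is not registered, and "extra" has no positional left to fill,
# so both are collected as extras instead of causing an error.
namespace, extras = parser.parse_known_args(
    ["--level", "3", "alice", "--unknown", "extra"])

print(namespace.level, namespace.name, extras)
```

By contrast, parse_args() reports leftover strings such as these as an error.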
    def _read_args_from_files(self, arg_strings):
        # expand arguments referencing files
        new_arg_strings = []
        for arg_string in arg_strings:

            # for regular arguments, just add them back into the list
            if not arg_string or arg_string[0] not in self.fromfile_prefix_chars:
                new_arg_strings.append(arg_string)

            # replace arguments referencing files with the file content
            else:
                try:
                    args_file = open(arg_string[1:])
                    try:
                        arg_strings = []
                        for arg_line in args_file.read().splitlines():
                            for arg in self.convert_arg_line_to_args(arg_line):
                                arg_strings.append(arg)
                        arg_strings = self._read_args_from_files(arg_strings)
                        new_arg_strings.extend(arg_strings)
                    finally:
                        args_file.close()
                except IOError:
                    err = _sys.exc_info()[1]
                    self.error(str(err))

        # return the modified argument list
        return new_arg_strings

    def convert_arg_line_to_args(self, arg_line):
        return [arg_line]
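Because convert_arg_line_to_args() returns the whole line as a single argument, an args file needs one argument per line by default. Overriding the method changes that. The sketch below assumes a hypothetical subclass WordSplittingParser and a helper parse_from_file that writes a temporary file; neither exists in the module.

```python
import argparse
import os
import tempfile


class WordSplittingParser(argparse.ArgumentParser):
    """Hypothetical subclass: each whitespace-separated word is one argument."""

    def convert_arg_line_to_args(self, arg_line):
        return arg_line.split()


def parse_from_file(text):
    """Write text to a temporary args file and parse it via '@file' expansion."""
    parser = WordSplittingParser(prog="demo", fromfile_prefix_chars="@")
    parser.add_argument("--name")
    parser.add_argument("--count", type=int)
    handle = tempfile.NamedTemporaryFile("w", suffix=".args", delete=False)
    try:
        handle.write(text)
        handle.close()
        return parser.parse_args(["@" + handle.name])
    finally:
        os.remove(handle.name)


args = parse_from_file("--name alice\n--count 2\n")
print(args.name, args.count)
```

With the default implementation, the line "--name alice" would be passed on as one string containing a space rather than as an option followed by its value.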
    def _match_argument(self, action, arg_strings_pattern):
        # match the pattern for this action to the arg strings
        nargs_pattern = self._get_nargs_pattern(action)
        match = _re.match(nargs_pattern, arg_strings_pattern)

        # raise an exception if we weren't able to find a match
        if match is None:
            nargs_errors = {
                None: _('expected one argument'),
                OPTIONAL: _('expected at most one argument'),
                ONE_OR_MORE: _('expected at least one argument'),
            }
            default = _('expected %s argument(s)') % action.nargs
            msg = nargs_errors.get(action.nargs, default)
            raise ArgumentError(action, msg)

        # return the number of arguments matched
        return len(match.group(1))

    def _match_arguments_partial(self, actions, arg_strings_pattern):
        # progressively shorten the actions list by slicing off the
        # final actions until we find a match
        result = []
        for i in range(len(actions), 0, -1):
            actions_slice = actions[:i]
            pattern = ''.join([self._get_nargs_pattern(action)
                               for action in actions_slice])
            match = _re.match(pattern, arg_strings_pattern)
            if match is not None:
                result.extend([len(string) for string in match.groups()])
                break

        # return the list of arg string counts
        return result
    def _parse_optional(self, arg_string):
        # if it's an empty string, it was meant to be a positional
        if not arg_string:
            return None

        # if it doesn't start with a prefix, it was meant to be positional
        if not arg_string[0] in self.prefix_chars:
            return None

        # if the option string is present in the parser, return the action
        if arg_string in self._option_string_actions:
            action = self._option_string_actions[arg_string]
            return action, arg_string, None

        # if it's just a single character, it was meant to be positional
        if len(arg_string) == 1:
            return None

        # if the option string before the "=" is present, return the action
        if '=' in arg_string:
            option_string, explicit_arg = arg_string.split('=', 1)
            if option_string in self._option_string_actions:
                action = self._option_string_actions[option_string]
                return action, option_string, explicit_arg

        # search through all possible prefixes of the option string
        # and all actions in the parser for possible interpretations
        option_tuples = self._get_option_tuples(arg_string)

        # if multiple actions match, the option string was ambiguous
        if len(option_tuples) > 1:
            options = ', '.join([option_string
                for action, option_string, explicit_arg in option_tuples])
            tup = arg_string, options
            self.error(_('ambiguous option: %s could match %s') % tup)

        # if exactly one action matched, this segmentation is good,
        # so return the parsed action
        elif len(option_tuples) == 1:
            option_tuple, = option_tuples
            return option_tuple

        # if it was not found as an option, but it looks like a negative
        # number, it was meant to be positional
        # unless there are negative-number-like options
        if self._negative_number_matcher.match(arg_string):
            if not self._has_negative_number_optionals:
                return None

        # if it contains a space, it was meant to be a positional
        if ' ' in arg_string:
            return None

        # it was meant to be an optional but there is no such option
        # in this parser (though it might be a valid option in a subparser)
        return None, arg_string, None

    def _get_option_tuples(self, option_string):
        result = []

        # option strings starting with two prefix characters are only
        # split at the '='
        chars = self.prefix_chars
        if option_string[0] in chars and option_string[1] in chars:
            if '=' in option_string:
                option_prefix, explicit_arg = option_string.split('=', 1)
            else:
                option_prefix = option_string
                explicit_arg = None
            for option_string in self._option_string_actions:
                if option_string.startswith(option_prefix):
                    action = self._option_string_actions[option_string]
                    tup = action, option_string, explicit_arg
                    result.append(tup)

        # single character options can be concatenated with their arguments
        # but multiple character options always have to have their argument
        # separate
        elif option_string[0] in chars and option_string[1] not in chars:
            option_prefix = option_string
            explicit_arg = None
            short_option_prefix = option_string[:2]
            short_explicit_arg = option_string[2:]

            for option_string in self._option_string_actions:
                if option_string == short_option_prefix:
                    action = self._option_string_actions[option_string]
                    tup = action, option_string, short_explicit_arg
                    result.append(tup)
                elif option_string.startswith(option_prefix):
                    action = self._option_string_actions[option_string]
                    tup = action, option_string, explicit_arg
                    result.append(tup)

        # shouldn't ever get here
        else:
            self.error(_('unexpected option string: %s') % option_string)

        # return the collected option tuples
        return result
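The prefix search in _get_option_tuples() is what lets a caller abbreviate long options, and the length check in _parse_optional() is what rejects an abbreviation that fits more than one option. A minimal sketch with two invented options that share the prefix "--ver":

```python
import argparse

# Hypothetical options chosen so that they share the prefix "--ver".
parser = argparse.ArgumentParser(prog="demo")
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--version-tag")

# "--verb" matches only --verbose; "--version=1.2" is split at "=" and the
# part before it matches only --version-tag.
args = parser.parse_args(["--verb", "--version=1.2"])

print(args.verbose, args.version_tag)
```

Passing just "--ver" would match both options, so the parser should report it as ambiguous and exit with status 2.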
    def _get_nargs_pattern(self, action):
        # in all examples below, we have to allow for '--' args
        # which are represented as '-' in the pattern
        nargs = action.nargs

        # the default (None) is assumed to be a single argument
        if nargs is None:
            nargs_pattern = '(-*A-*)'

        # allow zero or one arguments
        elif nargs == OPTIONAL:
            nargs_pattern = '(-*A?-*)'

        # allow zero or more arguments
        elif nargs == ZERO_OR_MORE:
            nargs_pattern = '(-*[A-]*)'

        # allow one or more arguments
        elif nargs == ONE_OR_MORE:
            nargs_pattern = '(-*A[A-]*)'

        # allow any number of options or arguments
        elif nargs == REMAINDER:
            nargs_pattern = '([-AO]*)'

        # allow one argument followed by any number of options or arguments
        elif nargs == PARSER:
            nargs_pattern = '(-*A[-AO]*)'

        # all others should be integers
        else:
            nargs_pattern = '(-*%s-*)' % '-*'.join('A' * nargs)

        # if this is an optional action, -- is not allowed
        if action.option_strings:
            nargs_pattern = nargs_pattern.replace('-*', '')
            nargs_pattern = nargs_pattern.replace('-', '')

        # return the pattern
        return nargs_pattern
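The patterns above are matched against a string in which every command-line argument is summarized by one character: 'A' for a plain argument, 'O' for an option string and '-' for a literal '--'. The sketch below replays that matching outside the parser. The table NARGS_PATTERNS and the helper count_matched are illustrative names, and only the four positional patterns keyed by None, '?', '*' and '+' are copied.

```python
import re

# Positional patterns copied from _get_nargs_pattern above.
NARGS_PATTERNS = {
    None: "(-*A-*)",
    "?": "(-*A?-*)",
    "*": "(-*[A-]*)",
    "+": "(-*A[A-]*)",
}


def count_matched(nargs, arg_strings_pattern):
    """Return how many argument slots the nargs pattern consumes, or None."""
    match = re.match(NARGS_PATTERNS[nargs], arg_strings_pattern)
    if match is None:
        return None
    return len(match.group(1))


# "AAOA" stands for: argument, argument, option, argument.
print(count_matched("+", "AAOA"))   # stops at the option
print(count_matched(None, "AAOA"))  # exactly one argument
print(count_matched("+", "OA"))     # nothing to consume before the option
```

The length of the matched group is the number of argument strings to slice off, which is how _match_argument() uses it.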
    # ========================
    # Value conversion methods
    # ========================
    def _get_values(self, action, arg_strings):
        # for everything but PARSER, REMAINDER args, strip out first '--'
        if action.nargs not in [PARSER, REMAINDER]:
            try:
                arg_strings.remove('--')
            except ValueError:
                pass

        # optional argument produces a default when not present
        if not arg_strings and action.nargs == OPTIONAL:
            if action.option_strings:
                value = action.const
            else:
                value = action.default
            if isinstance(value, basestring):
                value = self._get_value(action, value)
                self._check_value(action, value)

        # when nargs='*' on a positional, if there were no command-line
        # args, use the default if it is anything other than None
        elif (not arg_strings and action.nargs == ZERO_OR_MORE and
              not action.option_strings):
            if action.default is not None:
                value = action.default
            else:
                value = arg_strings
            self._check_value(action, value)

        # single argument or optional argument produces a single value
        elif len(arg_strings) == 1 and action.nargs in [None, OPTIONAL]:
            arg_string, = arg_strings
            value = self._get_value(action, arg_string)
            self._check_value(action, value)

        # REMAINDER arguments convert all values, checking none
        elif action.nargs == REMAINDER:
            value = [self._get_value(action, v) for v in arg_strings]

        # PARSER arguments convert all values, but check only the first
        elif action.nargs == PARSER:
            value = [self._get_value(action, v) for v in arg_strings]
            self._check_value(action, value[0])

        # all other types of nargs produce a list
        else:
            value = [self._get_value(action, v) for v in arg_strings]
            for v in value:
                self._check_value(action, v)

        # return the converted value
        return value

    def _get_value(self, action, arg_string):
        type_func = self._registry_get('type', action.type, action.type)
        if not _callable(type_func):
            msg = _('%r is not callable')
            raise ArgumentError(action, msg % type_func)

        # convert the value to the appropriate type
        try:
            result = type_func(arg_string)

        # ArgumentTypeErrors indicate errors
        except ArgumentTypeError:
            name = getattr(action.type, '__name__', repr(action.type))
            msg = str(_sys.exc_info()[1])
            raise ArgumentError(action, msg)

        # TypeErrors or ValueErrors also indicate errors
        except (TypeError, ValueError):
            name = getattr(action.type, '__name__', repr(action.type))
            msg = _('invalid %s value: %r')
            raise ArgumentError(action, msg % (name, arg_string))

        # return the converted value
        return result
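_get_value() treats the two kinds of failure differently: an ArgumentTypeError carries its own message, while a TypeError or ValueError is reported with a generic "invalid ... value" message built from the type function's name. A sketch of a custom type function that uses both paths follows; positive_int and the --jobs option are made-up names.

```python
import argparse


def positive_int(text):
    """Hypothetical type function accepting only integers greater than zero."""
    value = int(text)  # a ValueError here is reported with the generic message
    if value <= 0:
        # an ArgumentTypeError message is shown to the user as written
        raise argparse.ArgumentTypeError("%r is not a positive integer" % text)
    return value


parser = argparse.ArgumentParser(prog="demo")
parser.add_argument("--jobs", type=positive_int, default=1)

args = parser.parse_args(["--jobs", "4"])
print(args.jobs)
```

Passing "--jobs 0" or "--jobs many" should make the parser print a usage message and exit with status 2.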
    def _check_value(self, action, value):
        # converted value must be one of the choices (if specified)
        if action.choices is not None and value not in action.choices:
            tup = value, ', '.join(map(repr, action.choices))
            msg = _('invalid choice: %r (choose from %s)') % tup
            raise ArgumentError(action, msg)

    # =======================
    # Help-formatting methods
    # =======================
    def format_usage(self):
        formatter = self._get_formatter()
        formatter.add_usage(self.usage, self._actions,
                            self._mutually_exclusive_groups)
        return formatter.format_help()

    def format_help(self):
        formatter = self._get_formatter()

        # usage
        formatter.add_usage(self.usage, self._actions,
                            self._mutually_exclusive_groups)

        # description
        formatter.add_text(self.description)

        # positionals, optionals and user-defined groups
        for action_group in self._action_groups:
            formatter.start_section(action_group.title)
            formatter.add_text(action_group.description)
            formatter.add_arguments(action_group._group_actions)
            formatter.end_section()

        # epilog
        formatter.add_text(self.epilog)

        # determine help from format above
        return formatter.format_help()

    def format_version(self):
        import warnings
        warnings.warn(
            'The format_version method is deprecated -- the "version" '
            'argument to ArgumentParser is no longer supported.',
            DeprecationWarning)
        formatter = self._get_formatter()
        formatter.add_text(self.version)
        return formatter.format_help()

    def _get_formatter(self):
        return self.formatter_class(prog=self.prog)

    # =====================
    # Help-printing methods
    # =====================
    def print_usage(self, file=None):
        if file is None:
            file = _sys.stdout
        self._print_message(self.format_usage(), file)

    def print_help(self, file=None):
        if file is None:
            file = _sys.stdout
        self._print_message(self.format_help(), file)

    def print_version(self, file=None):
        import warnings
        warnings.warn(
            'The print_version method is deprecated -- the "version" '
            'argument to ArgumentParser is no longer supported.',
            DeprecationWarning)
        self._print_message(self.format_version(), file)

    def _print_message(self, message, file=None):
        if message:
            if file is None:
                file = _sys.stderr
            file.write(message)

    # ===============
    # Exiting methods
    # ===============
    def exit(self, status=0, message=None):
        if message:
            self._print_message(message, _sys.stderr)
        _sys.exit(status)

    def error(self, message):
        """error(message: string)

        Prints a usage message incorporating the message to stderr and
        exits.

        If you override this in a subclass, it should not return -- it
        should either exit or raise an exception.
        """
        self.print_usage(_sys.stderr)
        self.exit(2, _('%s: error: %s\n') % (self.prog, message))
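The error() docstring above allows a subclass to raise instead of exiting, which is convenient when the parser is embedded in a larger program or exercised in tests. A minimal sketch, assuming a hypothetical exception type ParseFailure and subclass RaisingParser:

```python
import argparse


class ParseFailure(Exception):
    """Hypothetical exception raised in place of exiting the process."""


class RaisingParser(argparse.ArgumentParser):
    def error(self, message):
        # Must not return: raise rather than print usage and exit.
        raise ParseFailure(message)


parser = RaisingParser(prog="demo")
parser.add_argument("--count", type=int)

try:
    parser.parse_args(["--count", "many"])
except ParseFailure as exc:
    failure = str(exc)
else:
    failure = None

print(failure)
```

The caller can then decide how to report the problem instead of having the interpreter terminated with status 2.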
# -*- coding: utf-8 -*-
"""
    ast
    ~~~

    The `ast` module helps Python applications to process trees of the Python
    abstract syntax grammar.  The abstract syntax itself might change with
    each Python release; this module helps to find out programmatically what
    the current grammar looks like and allows modifications of it.

    An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
    a flag to the `compile()` builtin function or by using the `parse()`
    function from this module.  The result will be a tree of objects whose
    classes all inherit from `ast.AST`.

    A modified abstract syntax tree can be compiled into a Python code object
    using the built-in `compile()` function.

    Additionally, various helper functions are provided that make working with
    the trees simpler.  The main intention of the helper functions and this
    module in general is to provide an easy-to-use interface for libraries
    that work tightly with the Python syntax (template engines for example).

    :copyright: Copyright 2008 by Armin Ronacher.
    :license: Python License.
"""
from _ast import *
from _ast import __version__


def parse(source, filename='<unknown>', mode='exec'):
    """
    Parse the source into an AST node.
    Equivalent to compile(source, filename, mode, PyCF_ONLY_AST).
    """
    return compile(source, filename, mode, PyCF_ONLY_AST)
def literal_eval(node_or_string):
    """
    Safely evaluate an expression node or a string containing a Python
    expression.  The string or node provided may only consist of the following
    Python literal structures: strings, numbers, tuples, lists, dicts, booleans,
    and None.
    """
    _safe_names = {'None': None, 'True': True, 'False': False}
    if isinstance(node_or_string, basestring):
        node_or_string = parse(node_or_string, mode='eval')
    if isinstance(node_or_string, Expression):
        node_or_string = node_or_string.body
    def _convert(node):
        if isinstance(node, Str):
            return node.s
        elif isinstance(node, Num):
            return node.n
        elif isinstance(node, Tuple):
            return tuple(map(_convert, node.elts))
        elif isinstance(node, List):
            return list(map(_convert, node.elts))
        elif isinstance(node, Dict):
            return dict((_convert(k), _convert(v)) for k, v
                        in zip(node.keys, node.values))
        elif isinstance(node, Name):
            if node.id in _safe_names:
                return _safe_names[node.id]
        elif isinstance(node, BinOp) and \
             isinstance(node.op, (Add, Sub)) and \
             isinstance(node.right, Num) and \
             isinstance(node.right.n, complex) and \
             isinstance(node.left, Num) and \
             isinstance(node.left.n, (int, long, float)):
            left = node.left.n
            right = node.right.n
            if isinstance(node.op, Add):
                return left + right
            else:
                return left - right
        raise ValueError('malformed string')
    return _convert(node_or_string)
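As a usage sketch for literal_eval(): it accepts nested literal structures and rejects anything that would need evaluation, such as a function call. The helper is_literal and the sample strings are illustrative only.

```python
import ast

# A nested literal structure is converted to the equivalent Python objects.
config = ast.literal_eval("{'retries': 3, 'hosts': ['a', 'b'], 'debug': False}")


def is_literal(text):
    """Return True if text can be evaluated by ast.literal_eval."""
    try:
        ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return False
    return True


print(config["hosts"])
print(is_literal("1 + 2j"))            # complex literal, handled by the BinOp case
print(is_literal("__import__('os')"))  # a call is not a literal
```

The complex-number case works because of the special BinOp branch in _convert(), which only accepts a real number plus or minus an imaginary one.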
def dump(node, annotate_fields=True, include_attributes=False):
    """
    Return a formatted dump of the tree in *node*.  This is mainly useful for
    debugging purposes.  The returned string will show the names and the values
    for fields.  This makes the code impossible to evaluate, so if evaluation is
    wanted *annotate_fields* must be set to False.  Attributes such as line
    numbers and column offsets are not dumped by default.  If this is wanted,
    *include_attributes* can be set to True.
    """
    def _format(node):
        if isinstance(node, AST):
            fields = [(a, _format(b)) for a, b in iter_fields(node)]
            rv = '%s(%s' % (node.__class__.__name__, ', '.join(
                ('%s=%s' % field for field in fields)
                if annotate_fields else
                (b for a, b in fields)
            ))
            if include_attributes and node._attributes:
                rv += fields and ', ' or ' '
                rv += ', '.join('%s=%s' % (a, _format(getattr(node, a)))
                                for a in node._attributes)
            return rv + ')'
        elif isinstance(node, list):
            return '[%s]' % ', '.join(_format(x) for x in node)
        return repr(node)
    if not isinstance(node, AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    return _format(node)


def copy_location(new_node, old_node):
    """
    Copy source location (`lineno` and `col_offset` attributes) from
    *old_node* to *new_node* if possible, and return *new_node*.
    """
    for attr in 'lineno', 'col_offset':
        if attr in old_node._attributes and attr in new_node._attributes \
           and hasattr(old_node, attr):
            setattr(new_node, attr, getattr(old_node, attr))
    return new_node


def fix_missing_locations(node):
    """
    When you compile a node tree with compile(), the compiler expects lineno and
    col_offset attributes for every node that supports them.  This is rather
    tedious to fill in for generated nodes, so this helper adds these attributes
    recursively where not already set, by setting them to the values of the
    parent node.  It works recursively starting at *node*.
    """
    def _fix(node, lineno, col_offset):
        if 'lineno' in node._attributes:
            if not hasattr(node, 'lineno'):
                node.lineno = lineno
            else:
                lineno = node.lineno
        if 'col_offset' in node._attributes:
            if not hasattr(node, 'col_offset'):
                node.col_offset = col_offset
            else:
                col_offset = node.col_offset
        for child in iter_child_nodes(node):
            _fix(child, lineno, col_offset)
    _fix(node, 1, 0)
    return node


def increment_lineno(node, n=1):
    """
    Increment the line number of each node in the tree starting at *node* by *n*.
    This is useful to "move code" to a different location in a file.
    """
    for child in walk(node):
        if 'lineno' in child._attributes:
            child.lineno = getattr(child, 'lineno', 0) + n
    return node


def iter_fields(node):
    """
    Yield a tuple of ``(fieldname, value)`` for each field in ``node._fields``
    that is present on *node*.
    """
    for field in node._fields:
        try:
            yield field, getattr(node, field)
        except AttributeError:
            pass


def iter_child_nodes(node):
    """
    Yield all direct child nodes of *node*, that is, all fields that are nodes
    and all items of fields that are lists of nodes.
    """
    for name, field in iter_fields(node):
        if isinstance(field, AST):
            yield field
        elif isinstance(field, list):
            for item in field:
                if isinstance(item, AST):
                    yield item


def get_docstring(node, clean=True):
    """
    Return the docstring for the given node or None if no docstring can
    be found.  If the node provided does not have docstrings a TypeError
    will be raised.
    """
    if not isinstance(node, (FunctionDef, ClassDef, Module)):
        raise TypeError("%r can't have docstrings" % node.__class__.__name__)
    if node.body and isinstance(node.body[0], Expr) and \
       isinstance(node.body[0].value, Str):
        if clean:
            import inspect
            return inspect.cleandoc(node.body[0].value.s)
        return node.body[0].value.s


def walk(node):
    """
    Recursively yield all descendant nodes in the tree starting at *node*
    (including *node* itself), in no specified order.  This is useful if you
    only want to modify nodes in place and don't care about the context.
    """
    from collections import deque
    todo = deque([node])
    while todo:
        node = todo.popleft()
        todo.extend(iter_child_nodes(node))
        yield node
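Since walk() yields every node without regard to context, it suits simple "find all nodes of type X" queries. A short sketch that collects the identifiers used in a piece of source follows; names_used is a hypothetical helper and the sample statement is arbitrary.

```python
import ast


def names_used(source):
    """Return the sorted, de-duplicated identifiers appearing in source."""
    tree = ast.parse(source)
    found = set(node.id for node in ast.walk(tree) if isinstance(node, ast.Name))
    return sorted(found)


print(names_used("total = price * quantity + tax"))
```

The result is sorted explicitly because walk() makes no promise about the order in which nodes are yielded.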
class NodeVisitor(object):
    """
    A node visitor base class that walks the abstract syntax tree and calls a
    visitor function for every node found.  This function may return a value
    which is forwarded by the `visit` method.

    This class is meant to be subclassed, with the subclass adding visitor
    methods.

    By default the visitor functions for the nodes are ``'visit_'`` +
    class name of the node.  So a `TryFinally` node visit function would
    be `visit_TryFinally`.  This behavior can be changed by overriding
    the `visit` method.  If no visitor function exists for a node
    (return value `None`) the `generic_visit` visitor is used instead.

    Don't use the `NodeVisitor` if you want to apply changes to nodes during
    traversal.  For this a special visitor exists (`NodeTransformer`) that
    allows modifications.
    """

    def visit(self, node):
        """Visit a node."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        for field, value in iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, AST):
                        self.visit(item)
            elif isinstance(value, AST):
                self.visit(value)
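A subclass only needs to define visit_<ClassName> methods for the node types it cares about; everything else falls through to generic_visit(). The sketch below counts calls by function name. CallCounter and the sample source are invented for the example.

```python
import ast


class CallCounter(ast.NodeVisitor):
    """Hypothetical visitor that counts calls to plainly named functions."""

    def __init__(self):
        self.calls = {}

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name):
            name = node.func.id
            self.calls[name] = self.calls.get(name, 0) + 1
        # Without this, calls nested in the arguments would be skipped.
        self.generic_visit(node)


counter = CallCounter()
counter.visit(ast.parse("f(g(x)); f(g(y))"))
print(counter.calls)
```

Calling generic_visit() from within a visitor method is what keeps the traversal going below the node being handled.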
class NodeTransformer(NodeVisitor):
    """
    A :class:`NodeVisitor` subclass that walks the abstract syntax tree and
    allows modification of nodes.

    The `NodeTransformer` will walk the AST and use the return value of the
    visitor methods to replace or remove the old node.  If the return value of
    the visitor method is ``None``, the node will be removed from its location,
    otherwise it is replaced with the return value.  The return value may be the
    original node in which case no replacement takes place.

    Here is an example transformer that rewrites all occurrences of name lookups
    (``foo``) to ``data['foo']``::

       class RewriteName(NodeTransformer):

           def visit_Name(self, node):
               return copy_location(Subscript(
                   value=Name(id='data', ctx=Load()),
                   slice=Index(value=Str(s=node.id)),
                   ctx=node.ctx
               ), node)

    Keep in mind that if the node you're operating on has child nodes you must
    either transform the child nodes yourself or call the :meth:`generic_visit`
    method for the node first.

    For nodes that were part of a collection of statements (that applies to all
    statement nodes), the visitor may also return a list of nodes rather than
    just a single node.

    Usually you use the transformer like this::

       node = YourTransformer().visit(node)
    """

    def generic_visit(self, node):
        for field, old_value in iter_fields(node):
            old_value = getattr(node, field, None)
            if isinstance(old_value, list):
                new_values = []
                for value in old_value:
                    if isinstance(value, AST):
                        value = self.visit(value)
                        if value is None:
                            continue
                        elif not isinstance(value, AST):
                            new_values.extend(value)
                            continue
                    new_values.append(value)
                old_value[:] = new_values
            elif isinstance(old_value, AST):
                new_node = self.visit(old_value)
                if new_node is None:
                    delattr(node, field)
                else:
                    setattr(node, field, new_node)
        return node
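The transform, fix-up and compile cycle described in the docstrings above can be sketched end to end. The example renames a variable inside an expression and then evaluates the result; RenameVariable is a hypothetical class, and a simple rename is used instead of the docstring's RewriteName so that only node types common to old and new interpreters are needed.

```python
import ast


class RenameVariable(ast.NodeTransformer):
    """Hypothetical transformer replacing one identifier with another."""

    def __init__(self, old, new):
        self.old = old
        self.new = new

    def visit_Name(self, node):
        if node.id == self.old:
            replacement = ast.Name(id=self.new, ctx=node.ctx)
            return ast.copy_location(replacement, node)
        return node  # returning the original node means "no change"


tree = ast.parse("x * 2 + 1", mode="eval")
tree = RenameVariable("x", "y").visit(tree)
ast.fix_missing_locations(tree)  # fill in any location info still missing

result = eval(compile(tree, "<ast>", "eval"), {"y": 20})
print(result)
```

Name nodes have no child statements or expressions, so visit_Name does not need to call generic_visit() here.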
# -*- Mode: Python; tab-width: 4 -*-
# Id: asynchat.py,v 2.26 2000/09/07 22:29:26 rushing Exp
# Author: Sam Rushing
# ======================================================================
# Copyright 1996 by Sam Rushing
#
# All Rights Reserved
#
# Permission to use, copy, modify, and distribute this software and
# its documentation for any purpose and without fee is hereby
# granted, provided that the above copyright notice appear in all
# copies and that both that copyright notice and this permission
# notice appear in supporting documentation, and that the name of Sam
# Rushing not be used in advertising or publicity pertaining to
# distribution of the software without specific, written prior
# permission.
#
# SAM RUSHING DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN
# NO EVENT SHALL SAM RUSHING BE LIABLE FOR ANY SPECIAL, INDIRECT OR
# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
# ======================================================================
r"""A class supporting chat-style (command/response) protocols.

This class adds support for 'chat' style protocols - where one side
sends a 'command', and the other sends a response (examples would be
the common internet protocols - smtp, nntp, ftp, etc.).

The handle_read() method looks at the input stream for the current
'terminator' (usually '\r\n' for single-line responses, '\r\n.\r\n'
for multi-line output), calling self.found_terminator() on its
receipt.

For example:
Say you build an async nntp client using this class.  At the start
of the connection, you'll have self.terminator set to '\r\n', in
order to process the single-line greeting.  Just before issuing a
'LIST' command you'll set it to '\r\n.\r\n'.  The output of the LIST
command will be accumulated (using your own 'collect_incoming_data'
method) up to the terminator, and then control will be returned to
you - by calling your self.found_terminator() method.
"""
import asyncore
import errno
import socket
from collections import deque
from sys import py3kwarning
from warnings import filterwarnings, catch_warnings
_BLOCKING_IO_ERRORS = (errno.EAGAIN, errno.EALREADY, errno.EINPROGRESS,
                       errno.EWOULDBLOCK)


class async_chat (asyncore.dispatcher):
    """This is an abstract class.  You must derive from this class, and add
    the two methods collect_incoming_data() and found_terminator()"""

    # these are overridable defaults
    ac_in_buffer_size = 4096
    ac_out_buffer_size = 4096

    def __init__ (self, sock=None, map=None):
        # for string terminator matching
        self.ac_in_buffer = ''

        # we use a list here rather than cStringIO for a few reasons...
        # del lst[:] is faster than sio.truncate(0)
        # lst = [] is faster than sio.truncate(0)
        # cStringIO will be gaining unicode support in py3k, which
        # will negatively affect the performance of bytes compared to
        # a ''.join() equivalent
        self.incoming = []

        # we toss the use of the "simple producer" and replace it with
        # a pure deque, which the original fifo was a wrapping of
        self.producer_fifo = deque()
        asyncore.dispatcher.__init__ (self, sock, map)

    def collect_incoming_data(self, data):
        raise NotImplementedError("must be implemented in subclass")

    def _collect_incoming_data(self, data):
        self.incoming.append(data)

    def _get_data(self):
        d = ''.join(self.incoming)
        del self.incoming[:]
        return d

    def found_terminator(self):
        raise NotImplementedError("must be implemented in subclass")

    def set_terminator (self, term):
        "Set the input delimiter.  Can be a fixed string of any length, an integer, or None"
        self.terminator = term

    def get_terminator (self):
        return self.terminator

    # grab some more data from the socket,
    # throw it to the collector method,
    # check for the terminator,
    # if found, transition to the next state.

    def handle_read (self):

        try:
            data = self.recv (self.ac_in_buffer_size)
        except socket.error, why:
            if why.args[0] in _BLOCKING_IO_ERRORS:
                return
            self.handle_error()
            return

        self.ac_in_buffer = self.ac_in_buffer + data

        # Continue to search for self.terminator in self.ac_in_buffer,
        # while calling self.collect_incoming_data.  The while loop
        # is necessary because we might read several data+terminator
        # combos with a single recv(4096).

        while self.ac_in_buffer:
            lb = len(self.ac_in_buffer)
            terminator = self.get_terminator()
            if not terminator:
                # no terminator, collect it all
                self.collect_incoming_data (self.ac_in_buffer)
                self.ac_in_buffer = ''
            elif isinstance(terminator, (int, long)):
                # numeric terminator
                n = terminator
                if lb < n:
                    self.collect_incoming_data (self.ac_in_buffer)
                    self.ac_in_buffer = ''
                    self.terminator = self.terminator - lb
                else:
                    self.collect_incoming_data (self.ac_in_buffer[:n])
                    self.ac_in_buffer = self.ac_in_buffer[n:]
                    self.terminator = 0
                    self.found_terminator()
            else:
                # 3 cases:
                # 1) end of buffer matches terminator exactly:
                #    collect data, transition
                # 2) end of buffer matches some prefix:
                #    collect data to the prefix
                # 3) end of buffer does not match any prefix:
                #    collect data
                terminator_len = len(terminator)
                index = self.ac_in_buffer.find(terminator)
                if index != -1:
                    # we found the terminator
                    if index > 0:
                        # don't bother reporting the empty string (source of subtle bugs)
                        self.collect_incoming_data (self.ac_in_buffer[:index])
                    self.ac_in_buffer = self.ac_in_buffer[index+terminator_len:]
                    # This does the Right Thing if the terminator is changed here.
                    self.found_terminator()
                else:
                    # check for a prefix of the terminator
                    index = find_prefix_at_end (self.ac_in_buffer, terminator)
                    if index:
                        if index != lb:
                            # we found a prefix, collect up to the prefix
                            self.collect_incoming_data (self.ac_in_buffer[:-index])
                            self.ac_in_buffer = self.ac_in_buffer[-index:]
                        break
                    else:
                        # no prefix, collect it all
                        self.collect_incoming_data (self.ac_in_buffer)
                        self.ac_in_buffer = ''

    def handle_write (self):
        self.initiate_send()

    def handle_close (self):
        self.close()

    def push (self, data):
        sabs = self.ac_out_buffer_size
        if len(data) > sabs:
            for i in xrange(0, len(data), sabs):
                self.producer_fifo.append(data[i:i+sabs])
        else:
            self.producer_fifo.append(data)
        self.initiate_send()

    def push_with_producer (self, producer):
        self.producer_fifo.append(producer)
        self.initiate_send()

    def readable (self):
        "predicate for inclusion in the readable for select()"
        # cannot use the old predicate, it violates the claim of the
        # set_terminator method.

        # return (len(self.ac_in_buffer) <= self.ac_in_buffer_size)
        return 1

    def writable (self):
        "predicate for inclusion in the writable for select()"
        return self.producer_fifo or (not self.connected)

    def close_when_done (self):
        "automatically close this channel once the outgoing queue is empty"
        self.producer_fifo.append(None)

    def initiate_send(self):
        while self.producer_fifo and self.connected:
            first = self.producer_fifo[0]
            # handle empty string/buffer or None entry
            if not first:
                del self.producer_fifo[0]
                if first is None:
                    self.handle_close()
                    return

            # handle classic producer behavior
            obs = self.ac_out_buffer_size
            try:
                with catch_warnings():
                    if py3kwarning:
                        filterwarnings("ignore", ".*buffer", DeprecationWarning)
                    data = buffer(first, 0, obs)
            except TypeError:
                data = first.more()
                if data:
                    self.producer_fifo.appendleft(data)
                else:
                    del self.producer_fifo[0]
                continue

            # send the data
            try:
                num_sent = self.send(data)
            except socket.error:
                self.handle_error()
                return

            if num_sent:
                if num_sent < len(data) or obs < len(first):
                    self.producer_fifo[0] = first[num_sent:]
                else:
                    del self.producer_fifo[0]
            # we tried to send some actual data
            return

    def discard_buffers (self):
        # Emergencies only!
        self.ac_in_buffer = ''
        del self.incoming[:]
        self.producer_fifo.clear()
class simple_producer:
def __init__ (self, data, buffer_size=512):
self.data = data
self.buffer_size = buffer_size
def more (self):
if len (self.data) > self.buffer_size:
result = self.data[:self.buffer_size]
self.data = self.data[self.buffer_size:]
return result
else:
result = self.data
self.data = ''
return result
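# A minimal sketch of the producer protocol used above: more() hands back
# successive chunks and finally an empty string. SimpleProducer and drain()
# are illustrative names, not part of the module, and the 4-byte buffer size
# is an assumption chosen to keep the output short.

```python
class SimpleProducer(object):
    """Hands out `data` in pieces of at most `buffer_size` bytes."""

    def __init__(self, data, buffer_size=512):
        self.data = data
        self.buffer_size = buffer_size

    def more(self):
        # Slice off the next chunk; an empty result signals exhaustion.
        chunk = self.data[:self.buffer_size]
        self.data = self.data[self.buffer_size:]
        return chunk


def drain(producer):
    """Call more() until the producer reports that it is empty."""
    chunks = []
    while True:
        chunk = producer.more()
        if not chunk:
            break
        chunks.append(chunk)
    return chunks


chunks = drain(SimpleProducer(b"abcdefghij", buffer_size=4))
```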
class fifo:
def __init__ (self, list=None):
if not list:
self.list = deque()
else:
self.list = deque(list)
def __len__ (self):
return len(self.list)
def is_empty (self):
return not self.list
def first (self):
return self.list[0]
def push (self, data):
self.list.append(data)
def pop (self):
if self.list:
return (1, self.list.popleft())
else:
return (0, None)
# Given 'haystack', see if any prefix of 'needle' is at its end. This
# assumes an exact match has already been checked. Return the number of
# characters matched.
# for example:
# f_p_a_e ("qwerty\r", "\r\n") => 1
# f_p_a_e ("qwertydkjf", "\r\n") => 0
# f_p_a_e ("qwerty\r\n", "\r\n") => <undefined>
# this could maybe be made faster with a computed regex?
# [answer: no; circa Python-2.0, Jan 2001]
# new python: 28961/s
# old python: 18307/s
# re: 12820/s
# regex: 14035/s
def find_prefix_at_end (haystack, needle):
    l = len(needle) - 1
    while l and not haystack.endswith(needle[:l]):
        l -= 1
    return l
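# A short usage sketch for the prefix check above. The helper is restated
# under the same name only so the snippet is self-contained; the
# "--boundary" terminator is an illustrative assumption.

```python
def find_prefix_at_end(haystack, needle):
    # Length of the longest proper prefix of `needle` that ends `haystack`.
    l = len(needle) - 1
    while l and not haystack.endswith(needle[:l]):
        l -= 1
    return l


# The buffer ends with "\r", the first character of the terminator.
partial = find_prefix_at_end("qwerty\r", "\r\n")
# No prefix of the terminator appears at the end of the buffer.
nothing = find_prefix_at_end("qwertydkjf", "\r\n")
# A longer terminator: the buffer ends with "--bo".
longer = find_prefix_at_end("data--bo", "--boundary")
```

# The caller can then hold back that many trailing characters until more
# data arrives, which is what the handle_read logic does with the result.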
# -*- Mode: Python -*-
# Id: asyncore.py,v 2.51 2000/09/07 22:29:26 rushing Exp
# Author: Sam Rushing <[email protected]>
# ======================================================================
# Copyright 1996 by Sam Rushing
#
# All Rights Reserved
#
# Permission to use, copy, modify, and distribute this software and
# its documentation for any purpose and without fee is hereby
# granted, provided that the above copyright notice appear in all
# copies and that both that copyright notice and this permission
# notice appear in supporting documentation, and that the name of Sam
# Rushing not be used in advertising or publicity pertaining to
# distribution of the software without specific, written prior
# permission.
#
# SAM RUSHING DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN
# NO EVENT SHALL SAM RUSHING BE LIABLE FOR ANY SPECIAL, INDIRECT OR
# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
# ======================================================================
"""Basic infrastructure for asynchronous socket service clients and servers.
There are only two ways to have a program on a single processor do "more
than one thing at a time". Multi-threaded programming is the simplest and
most popular way to do it, but there is another very different technique,
that lets you have nearly all the advantages of multi-threading, without
actually using multiple threads. It's really only practical if your program
is largely I/O bound. If your program is CPU bound, then pre-emptive
scheduled threads are probably what you really need. Network servers are
rarely CPU-bound, however.
If your operating system supports the select() system call in its I/O
library (and nearly all do), then you can use it to juggle multiple
communication channels at once; doing other work while your I/O is taking
place in the "background." Although this strategy can seem strange and
complex, especially at first, it is in many ways easier to understand and
control than multi-threaded programming. The module documented here solves
many of the difficult problems for you, making the task of building
sophisticated high-performance network servers and clients a snap.
"""
import select
import socket
import sys
import time
import warnings
import os
from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, EINVAL, \
ENOTCONN, ESHUTDOWN, EINTR, EISCONN, EBADF, ECONNABORTED, EPIPE, EAGAIN, \
errorcode
_DISCONNECTED = frozenset((ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE,
EBADF))
try:
socket_map
except NameError:
socket_map = {}
def _strerror(err):
try:
return os.strerror(err)
except (ValueError, OverflowError, NameError):
if err in errorcode:
return errorcode[err]
return "Unknown error %s" %err
class ExitNow(Exception):
pass
_reraised_exceptions = (ExitNow, KeyboardInterrupt, SystemExit)
def read(obj):
try:
obj.handle_read_event()
except _reraised_exceptions:
raise
except:
obj.handle_error()
def write(obj):
try:
obj.handle_write_event()
except _reraised_exceptions:
raise
except:
obj.handle_error()
def _exception(obj):
try:
obj.handle_expt_event()
except _reraised_exceptions:
raise
except:
obj.handle_error()
def readwrite(obj, flags):
try:
if flags & select.POLLIN:
obj.handle_read_event()
if flags & select.POLLOUT:
obj.handle_write_event()
if flags & select.POLLPRI:
obj.handle_expt_event()
if flags & (select.POLLHUP | select.POLLERR | select.POLLNVAL):
obj.handle_close()
except socket.error, e:
if e.args[0] not in _DISCONNECTED:
obj.handle_error()
else:
obj.handle_close()
except _reraised_exceptions:
raise
except:
obj.handle_error()
def poll(timeout=0.0, map=None):
if map is None:
map = socket_map
if map:
r = []; w = []; e = []
for fd, obj in map.items():
is_r = obj.readable()
is_w = obj.writable()
if is_r:
r.append(fd)
# accepting sockets should not be writable
if is_w and not obj.accepting:
w.append(fd)
if is_r or is_w:
e.append(fd)
if [] == r == w == e:
time.sleep(timeout)
return
try:
r, w, e = select.select(r, w, e, timeout)
except select.error, err:
if err.args[0] != EINTR:
raise
else:
return
for fd in r:
obj = map.get(fd)
if obj is None:
continue
read(obj)
for fd in w:
obj = map.get(fd)
if obj is None:
continue
write(obj)
for fd in e:
obj = map.get(fd)
if obj is None:
continue
_exception(obj)
def poll2(timeout=0.0, map=None):
# Use the poll() support added to the select module in Python 2.0
if map is None:
map = socket_map
if timeout is not None:
# timeout is in milliseconds
timeout = int(timeout*1000)
pollster = select.poll()
if map:
for fd, obj in map.items():
flags = 0
if obj.readable():
flags |= select.POLLIN | select.POLLPRI
# accepting sockets should not be writable
if obj.writable() and not obj.accepting:
flags |= select.POLLOUT
if flags:
# Only check for exceptions if object was either readable
# or writable.
flags |= select.POLLERR | select.POLLHUP | select.POLLNVAL
pollster.register(fd, flags)
try:
r = pollster.poll(timeout)
except select.error, err:
if err.args[0] != EINTR:
raise
r = []
for fd, flags in r:
obj = map.get(fd)
if obj is None:
continue
readwrite(obj, flags)
poll3 = poll2 # Alias for backward compatibility
def loop(timeout=30.0, use_poll=False, map=None, count=None):
    if map is None:
        map = socket_map
    if use_poll and hasattr(select, 'poll'):
        poll_fun = poll2
    else:
        poll_fun = poll
    if count is None:
        while map:
            poll_fun(timeout, map)
    else:
        while map and count > 0:
            poll_fun(timeout, map)
            count = count - 1
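# A minimal sketch of the select()-based idea behind poll() and loop():
# ask the OS which channels are ready, then dispatch to a handler.
# poll_once and the `channels` mapping are hypothetical names for this
# illustration, and the snippet assumes socket.socketpair() is available
# on the platform.

```python
import select
import socket


def poll_once(channels, timeout=0.0):
    """`channels` maps a socket to a callable invoked when it is readable."""
    readable, _, _ = select.select(list(channels), [], [], timeout)
    for sock in readable:
        channels[sock](sock)
    return len(readable)


left, right = socket.socketpair()
received = []
channels = {right: lambda s: received.append(s.recv(1024))}

idle = poll_once(channels)               # nothing has been sent yet
left.sendall(b"ping")
ready = poll_once(channels, timeout=1.0) # `right` is now readable

left.close()
right.close()
```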
class dispatcher:
debug = False
connected = False
accepting = False
connecting = False
closing = False
addr = None
ignore_log_types = frozenset(['warning'])
def __init__(self, sock=None, map=None):
if map is None:
self._map = socket_map
else:
self._map = map
self._fileno = None
if sock:
# Set to nonblocking just to make sure for cases where we
# get a socket from a blocking source.
sock.setblocking(0)
self.set_socket(sock, map)
self.connected = True
# The constructor no longer requires that the socket
# passed be connected.
try:
self.addr = sock.getpeername()
except socket.error, err:
if err.args[0] in (ENOTCONN, EINVAL):
# To handle the case where we got an unconnected
# socket.
self.connected = False
else:
# The socket is broken in some unknown way, alert
# the user and remove it from the map (to prevent
# polling of broken sockets).
self.del_channel(map)
raise
else:
self.socket = None
def __repr__(self):
status = [self.__class__.__module__+"."+self.__class__.__name__]
if self.accepting and self.addr:
status.append('listening')
elif self.connected:
status.append('connected')
if self.addr is not None:
try:
status.append('%s:%d' % self.addr)
except TypeError:
status.append(repr(self.addr))
return '<%s at %#x>' % (' '.join(status), id(self))
__str__ = __repr__
def add_channel(self, map=None):
#self.log_info('adding channel %s' % self)
if map is None:
map = self._map
map[self._fileno] = self
def del_channel(self, map=None):
fd = self._fileno
if map is None:
map = self._map
if fd in map:
#self.log_info('closing channel %d:%s' % (fd, self))
del map[fd]
self._fileno = None
def create_socket(self, family, type):
self.family_and_type = family, type
sock = socket.socket(family, type)
sock.setblocking(0)
self.set_socket(sock)
def set_socket(self, sock, map=None):
self.socket = sock
## self.__dict__['socket'] = sock
self._fileno = sock.fileno()
self.add_channel(map)
def set_reuse_addr(self):
# try to re-use a server port if possible
try:
self.socket.setsockopt(
socket.SOL_SOCKET, socket.SO_REUSEADDR,
self.socket.getsockopt(socket.SOL_SOCKET,
socket.SO_REUSEADDR) | 1
)
except socket.error:
pass
# ==================================================
# predicates for select()
# these are used as filters for the lists of sockets
# to pass to select().
# ==================================================
def readable(self):
return True
def writable(self):
return True
# ==================================================
# socket object methods.
# ==================================================
def listen(self, num):
self.accepting = True
if os.name == 'nt' and num > 5:
num = 5
return self.socket.listen(num)
def bind(self, addr):
self.addr = addr
return self.socket.bind(addr)
def connect(self, address):
self.connected = False
self.connecting = True
err = self.socket.connect_ex(address)
if err in (EINPROGRESS, EALREADY, EWOULDBLOCK) \
or err == EINVAL and os.name in ('nt', 'ce'):
self.addr = address
return
if err in (0, EISCONN):
self.addr = address
self.handle_connect_event()
else:
raise socket.error(err, errorcode[err])
def accept(self):
# XXX can return either an address pair or None
try:
conn, addr = self.socket.accept()
except TypeError:
return None
except socket.error as why:
if why.args[0] in (EWOULDBLOCK, ECONNABORTED, EAGAIN):
return None
else:
raise
else:
return conn, addr
def send(self, data):
try:
result = self.socket.send(data)
return result
except socket.error, why:
if why.args[0] == EWOULDBLOCK:
return 0
elif why.args[0] in _DISCONNECTED:
self.handle_close()
return 0
else:
raise
def recv(self, buffer_size):
try:
data = self.socket.recv(buffer_size)
if not data:
# a closed connection is indicated by signaling
# a read condition, and having recv() return 0.
self.handle_close()
return ''
else:
return data
except socket.error, why:
# winsock sometimes raises ENOTCONN
if why.args[0] in _DISCONNECTED:
self.handle_close()
return ''
else:
raise
def close(self):
self.connected = False
self.accepting = False
self.connecting = False
self.del_channel()
try:
self.socket.close()
except socket.error, why:
if why.args[0] not in (ENOTCONN, EBADF):
raise
# cheap inheritance, used to pass all other attribute
# references to the underlying socket object.
def __getattr__(self, attr):
try:
retattr = getattr(self.socket, attr)
except AttributeError:
raise AttributeError("%s instance has no attribute '%s'"
%(self.__class__.__name__, attr))
else:
msg = "%(me)s.%(attr)s is deprecated. Use %(me)s.socket.%(attr)s " \
"instead." % {'me': self.__class__.__name__, 'attr':attr}
warnings.warn(msg, DeprecationWarning, stacklevel=2)
return retattr
# log and log_info may be overridden to provide more sophisticated
# logging and warning methods. In general, log is for 'hit' logging
# and 'log_info' is for informational, warning and error logging.
def log(self, message):
sys.stderr.write('log: %s\n' % str(message))
def log_info(self, message, type='info'):
if type not in self.ignore_log_types:
print '%s: %s' % (type, message)
def handle_read_event(self):
if self.accepting:
# accepting sockets are never connected, they "spawn" new
# sockets that are connected
self.handle_accept()
elif not self.connected:
if self.connecting:
self.handle_connect_event()
self.handle_read()
else:
self.handle_read()
def handle_connect_event(self):
err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
if err != 0:
raise socket.error(err, _strerror(err))
self.handle_connect()
self.connected = True
self.connecting = False
def handle_write_event(self):
if self.accepting:
# Accepting sockets shouldn't get a write event.
# We will pretend it didn't happen.
return
if not self.connected:
if self.connecting:
self.handle_connect_event()
self.handle_write()
def handle_expt_event(self):
# handle_expt_event() is called if there might be an error on the
# socket, or if there is OOB data
# check for the error condition first
err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
if err != 0:
# we can get here when select.select() says that there is an
# exceptional condition on the socket
# since there is an error, we'll go ahead and close the socket
# like we would in a subclassed handle_read() that received no
# data
self.handle_close()
else:
self.handle_expt()
def handle_error(self):
nil, t, v, tbinfo = compact_traceback()
# sometimes a user repr method will crash.
try:
self_repr = repr(self)
except:
self_repr = '<__repr__(self) failed for object at %0x>' % id(self)
self.log_info(
'uncaptured python exception, closing channel %s (%s:%s %s)' % (
self_repr,
t,
v,
tbinfo
),
'error'
)
self.handle_close()
def handle_expt(self):
self.log_info('unhandled incoming priority event', 'warning')
def handle_read(self):
self.log_info('unhandled read event', 'warning')
def handle_write(self):
self.log_info('unhandled write event', 'warning')
def handle_connect(self):
self.log_info('unhandled connect event', 'warning')
def handle_accept(self):
self.log_info('unhandled accept event', 'warning')
def handle_close(self):
self.log_info('unhandled close event', 'warning')
self.close()
# ---------------------------------------------------------------------------
# adds simple buffered output capability, useful for simple clients.
# [for more sophisticated usage use asynchat.async_chat]
# ---------------------------------------------------------------------------
class dispatcher_with_send(dispatcher):
def __init__(self, sock=None, map=None):
dispatcher.__init__(self, sock, map)
self.out_buffer = ''
def initiate_send(self):
num_sent = 0
num_sent = dispatcher.send(self, self.out_buffer[:512])
self.out_buffer = self.out_buffer[num_sent:]
def handle_write(self):
self.initiate_send()
def writable(self):
return (not self.connected) or len(self.out_buffer)
def send(self, data):
if self.debug:
self.log_info('sending %s' % repr(data))
self.out_buffer = self.out_buffer + data
self.initiate_send()
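# A minimal sketch of the buffered-send pattern used by dispatcher_with_send:
# append to an output buffer, send what the transport accepts, keep the rest.
# BufferedSender and slow_send are hypothetical stand-ins (no real socket is
# involved); slow_send accepting 3 bytes per call is an assumption made to
# force partial writes.

```python
class BufferedSender(object):
    def __init__(self, raw_send, chunk_size=512):
        self.raw_send = raw_send      # callable returning bytes accepted
        self.chunk_size = chunk_size
        self.out_buffer = b""

    def writable(self):
        # Only ask to be polled for writing while data is pending.
        return len(self.out_buffer) > 0

    def initiate_send(self):
        num_sent = self.raw_send(self.out_buffer[:self.chunk_size])
        self.out_buffer = self.out_buffer[num_sent:]

    def send(self, data):
        self.out_buffer += data
        self.initiate_send()


wire = []


def slow_send(data):
    # Pretend the transport takes at most 3 bytes per call.
    accepted = data[:3]
    wire.append(accepted)
    return len(accepted)


sender = BufferedSender(slow_send)
sender.send(b"hello world")
while sender.writable():          # stands in for repeated write events
    sender.initiate_send()
```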
# ---------------------------------------------------------------------------
# used for debugging.
# ---------------------------------------------------------------------------
def compact_traceback():
    t, v, tb = sys.exc_info()
    tbinfo = []
    if not tb: # Must have a traceback
        raise AssertionError("traceback does not exist")
    while tb:
        tbinfo.append((
            tb.tb_frame.f_code.co_filename,
            tb.tb_frame.f_code.co_name,
            str(tb.tb_lineno)
            ))
        tb = tb.tb_next
    # just to be safe
    del tb
    file, function, line = tbinfo[-1]
    info = ' '.join(['[%s|%s|%s]' % x for x in tbinfo])
    return (file, function, line), t, v, info
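# A self-contained sketch of how the traceback summary above can be used.
# compact_traceback is restated so the snippet runs on its own; inner() and
# outer() are hypothetical functions that exist only to produce a
# multi-frame traceback.

```python
import sys


def compact_traceback():
    t, v, tb = sys.exc_info()
    if tb is None:
        raise AssertionError("traceback does not exist")
    tbinfo = []
    while tb:
        code = tb.tb_frame.f_code
        tbinfo.append((code.co_filename, code.co_name, str(tb.tb_lineno)))
        tb = tb.tb_next
    # The last entry is the frame in which the exception was raised.
    filename, function, line = tbinfo[-1]
    info = ' '.join(['[%s|%s|%s]' % x for x in tbinfo])
    return (filename, function, line), t, v, info


def inner():
    raise ValueError("boom")


def outer():
    inner()


try:
    outer()
except ValueError:
    location, exc_type, exc_value, info = compact_traceback()
```

# `location` names the innermost frame, while `info` lists every frame on
# one line, which is the form handle_error() passes to log_info().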
def close_all(map=None, ignore_all=False):
if map is None:
map = socket_map
for x in map.values():
try:
x.close()
except OSError, x:
if x.args[0] == EBADF:
pass
elif not ignore_all:
raise
except _reraised_exceptions:
raise
except:
if not ignore_all:
raise
map.clear()
# Asynchronous File I/O:
#
# After a little research (reading man pages on various unixen, and
# digging through the linux kernel), I've determined that select()
# isn't meant for doing asynchronous file i/o.
# Heartening, though - reading linux/mm/filemap.c shows that linux
# supports asynchronous read-ahead. So _MOST_ of the time, the data
# will be sitting in memory for us already when we go to read it.
#
# What other OS's (besides NT) support async file i/o? [VMS?]
#
# Regardless, this is useful for pipes, and stdin/stdout...
if os.name == 'posix':
import fcntl
class file_wrapper:
# Here we override just enough to make a file
# look like a socket for the purposes of asyncore.
# The passed fd is automatically os.dup()'d
def __init__(self, fd):
self.fd = os.dup(fd)
def recv(self, *args):
return os.read(self.fd, *args)
def send(self, *args):
return os.write(self.fd, *args)
def getsockopt(self, level, optname, buflen=None):
if (level == socket.SOL_SOCKET and
optname == socket.SO_ERROR and
not buflen):
return 0
raise NotImplementedError("Only asyncore specific behaviour "
"implemented.")
read = recv
write = send
def close(self):
if self.fd < 0:
return
fd = self.fd
self.fd = -1
os.close(fd)
def fileno(self):
return self.fd
class file_dispatcher(dispatcher):
def __init__(self, fd, map=None):
dispatcher.__init__(self, None, map)
self.connected = True
try:
fd = fd.fileno()
except AttributeError:
pass
self.set_file(fd)
# set it to non-blocking mode
flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
flags = flags | os.O_NONBLOCK
fcntl.fcntl(fd, fcntl.F_SETFL, flags)
def set_file(self, fd):
self.socket = file_wrapper(fd)
self._fileno = self.socket.fileno()
self.add_channel()
"""
atexit.py - allow programmer to define multiple exit functions to be executed
upon normal program termination.
One public function, register, is defined.
"""
__all__ = ["register"]
import sys
_exithandlers = []
def _run_exitfuncs():
    """run any registered exit functions

    _exithandlers is traversed in reverse order so functions are executed
    last in, first out.
    """
    exc_info = None
    while _exithandlers:
        func, targs, kargs = _exithandlers.pop()
        try:
            func(*targs, **kargs)
        except SystemExit:
            exc_info = sys.exc_info()
        except:
            import traceback
            print >> sys.stderr, "Error in atexit._run_exitfuncs:"
            traceback.print_exc()
            exc_info = sys.exc_info()
    if exc_info is not None:
        raise exc_info[0], exc_info[1], exc_info[2]
def register(func, *targs, **kargs):
    """register a function to be executed upon normal program termination

        func - function to be called at exit
        targs - optional arguments to pass to func
        kargs - optional keyword arguments to pass to func

        func is returned to facilitate usage as a decorator.
    """
    _exithandlers.append((func, targs, kargs))
    return func
if hasattr(sys, "exitfunc"):
# Assume it's another registered exit function - append it to our list
register(sys.exitfunc)
sys.exitfunc = _run_exitfuncs
if __name__ == "__main__":
def x1():
print "running x1"
def x2(n):
print "running x2(%r)" % (n,)
def x3(n, kwd=None):
print "running x3(%r, kwd=%r)" % (n, kwd)
register(x1)
register(x2, 12)
register(x3, 5, "bar")
register(x3, "no kwd args")
"""Classes for manipulating audio devices (currently only for Sun and SGI)"""
from warnings import warnpy3k
warnpy3k("the audiodev module has been removed in Python 3.0", stacklevel=2)
del warnpy3k
__all__ = ["error","AudioDev"]
class error(Exception):
pass
class Play_Audio_sgi:
# Private instance variables
## if 0: access frameratelist, nchannelslist, sampwidthlist, oldparams, \
## params, config, inited_outrate, inited_width, \
## inited_nchannels, port, converter, classinited: private
classinited = 0
frameratelist = nchannelslist = sampwidthlist = None
def initclass(self):
import AL
self.frameratelist = [
(48000, AL.RATE_48000),
(44100, AL.RATE_44100),
(32000, AL.RATE_32000),
(22050, AL.RATE_22050),
(16000, AL.RATE_16000),
(11025, AL.RATE_11025),
( 8000, AL.RATE_8000),
]
self.nchannelslist = [
(1, AL.MONO),
(2, AL.STEREO),
(4, AL.QUADRO),
]
self.sampwidthlist = [
(1, AL.SAMPLE_8),
(2, AL.SAMPLE_16),
(3, AL.SAMPLE_24),
]
self.classinited = 1
def __init__(self):
import al, AL
if not self.classinited:
self.initclass()
self.oldparams = []
self.params = [AL.OUTPUT_RATE, 0]
self.config = al.newconfig()
self.inited_outrate = 0
self.inited_width = 0
self.inited_nchannels = 0
self.converter = None
self.port = None
return
def __del__(self):
if self.port:
self.stop()
if self.oldparams:
import al, AL
al.setparams(AL.DEFAULT_DEVICE, self.oldparams)
self.oldparams = []
def wait(self):
if not self.port:
return
import time
while self.port.getfilled() > 0:
time.sleep(0.1)
self.stop()
def stop(self):
if self.port:
self.port.closeport()
self.port = None
if self.oldparams:
import al, AL
al.setparams(AL.DEFAULT_DEVICE, self.oldparams)
self.oldparams = []
def setoutrate(self, rate):
for (raw, cooked) in self.frameratelist:
if rate == raw:
self.params[1] = cooked
self.inited_outrate = 1
break
else:
raise error, 'bad output rate'
def setsampwidth(self, width):
for (raw, cooked) in self.sampwidthlist:
if width == raw:
self.config.setwidth(cooked)
self.inited_width = 1
break
else:
if width == 0:
import AL
self.inited_width = 0
self.config.setwidth(AL.SAMPLE_16)
self.converter = self.ulaw2lin
else:
raise error, 'bad sample width'
def setnchannels(self, nchannels):
for (raw, cooked) in self.nchannelslist:
if nchannels == raw:
self.config.setchannels(cooked)
self.inited_nchannels = 1
break
else:
raise error, 'bad # of channels'
def writeframes(self, data):
if not (self.inited_outrate and self.inited_nchannels):
raise error, 'params not specified'
if not self.port:
import al, AL
self.port = al.openport('Python', 'w', self.config)
self.oldparams = self.params[:]
al.getparams(AL.DEFAULT_DEVICE, self.oldparams)
al.setparams(AL.DEFAULT_DEVICE, self.params)
if self.converter:
data = self.converter(data)
self.port.writesamps(data)
def getfilled(self):
if self.port:
return self.port.getfilled()
else:
return 0
def getfillable(self):
if self.port:
return self.port.getfillable()
else:
return self.config.getqueuesize()
# private methods
## if 0: access *: private
def ulaw2lin(self, data):
import audioop
return audioop.ulaw2lin(data, 2)
class Play_Audio_sun:
## if 0: access outrate, sampwidth, nchannels, inited_outrate, inited_width, \
## inited_nchannels, converter: private
def __init__(self):
self.outrate = 0
self.sampwidth = 0
self.nchannels = 0
self.inited_outrate = 0
self.inited_width = 0
self.inited_nchannels = 0
self.converter = None
self.port = None
return
def __del__(self):
self.stop()
def setoutrate(self, rate):
self.outrate = rate
self.inited_outrate = 1
def setsampwidth(self, width):
self.sampwidth = width
self.inited_width = 1
def setnchannels(self, nchannels):
self.nchannels = nchannels
self.inited_nchannels = 1
def writeframes(self, data):
if not (self.inited_outrate and self.inited_width and self.inited_nchannels):
raise error, 'params not specified'
if not self.port:
import sunaudiodev, SUNAUDIODEV
self.port = sunaudiodev.open('w')
info = self.port.getinfo()
info.o_sample_rate = self.outrate
info.o_channels = self.nchannels
if self.sampwidth == 0:
info.o_precision = 8
self.o_encoding = SUNAUDIODEV.ENCODING_ULAW
# XXX Hack, hack -- leave defaults
else:
info.o_precision = 8 * self.sampwidth
info.o_encoding = SUNAUDIODEV.ENCODING_LINEAR
self.port.setinfo(info)
if self.converter:
data = self.converter(data)
self.port.write(data)
def wait(self):
if not self.port:
return
self.port.drain()
self.stop()
def stop(self):
if self.port:
self.port.flush()
self.port.close()
self.port = None
def getfilled(self):
if self.port:
return self.port.obufcount()
else:
return 0
## # Nobody remembers what this method does, and it's broken. :-(
## def getfillable(self):
## return BUFFERSIZE - self.getfilled()
def AudioDev():
# Dynamically try to import and use a platform specific module.
try:
import al
except ImportError:
try:
import sunaudiodev
return Play_Audio_sun()
except ImportError:
try:
import Audio_mac
except ImportError:
raise error, 'no audio device'
else:
return Audio_mac.Play_Audio_mac()
else:
return Play_Audio_sgi()
def test(fn = None):
import sys
if sys.argv[1:]:
fn = sys.argv[1]
else:
fn = 'f:just samples:just.aif'
import aifc
af = aifc.open(fn, 'r')
print fn, af.getparams()
p = AudioDev()
p.setoutrate(af.getframerate())
p.setsampwidth(af.getsampwidth())
p.setnchannels(af.getnchannels())
BUFSIZ = af.getframerate()/af.getsampwidth()/af.getnchannels()
while 1:
data = af.readframes(BUFSIZ)
if not data: break
print len(data)
p.writeframes(data)
p.wait()
if __name__ == '__main__':
test()
#! /usr/bin/env python
"""RFC 3548: Base16, Base32, Base64 Data Encodings"""
# Modified 04-Oct-1995 by Jack Jansen to use binascii module
# Modified 30-Dec-2003 by Barry Warsaw to add full RFC 3548 support
import re
import struct
import string
import binascii
__all__ = [
# Legacy interface exports traditional RFC 1521 Base64 encodings
'encode', 'decode', 'encodestring', 'decodestring',
# Generalized interface for other encodings
'b64encode', 'b64decode', 'b32encode', 'b32decode',
'b16encode', 'b16decode',
# Standard Base64 encoding
'standard_b64encode', 'standard_b64decode',
# Some common Base64 alternatives. As referenced by RFC 3548, see thread
# starting at:
#
# http://zgp.org/pipermail/p2p-hackers/2001-September/000316.html
'urlsafe_b64encode', 'urlsafe_b64decode',
]
_translation = [chr(_x) for _x in range(256)]
EMPTYSTRING = ''
def _translate(s, altchars):
translation = _translation[:]
for k, v in altchars.items():
translation[ord(k)] = v
return s.translate(''.join(translation))
# Base64 encoding/decoding uses binascii
def b64encode(s, altchars=None):
    """Encode a string using Base64.

    s is the string to encode. Optional altchars must be a string of at least
    length 2 (additional characters are ignored) which specifies an
    alternative alphabet for the '+' and '/' characters. This allows an
    application to e.g. generate url or filesystem safe Base64 strings.

    The encoded string is returned.
    """
    # Strip off the trailing newline
    encoded = binascii.b2a_base64(s)[:-1]
    if altchars is not None:
        return encoded.translate(string.maketrans(b'+/', altchars[:2]))
    return encoded
def b64decode(s, altchars=None):
"""Decode a Base64 encoded string.
s is the string to decode. Optional altchars must be a string of at least
length 2 (additional characters are ignored) which specifies the
alternative alphabet used instead of the '+' and '/' characters.
The decoded string is returned. A TypeError is raised if s is
incorrectly padded. Characters that are neither in the normal base-64
alphabet nor the alternative alphabet are discarded prior to the padding
check.
"""
if altchars is not None:
s = s.translate(string.maketrans(altchars[:2], '+/'))
try:
return binascii.a2b_base64(s)
except binascii.Error, msg:
# Transform this exception for consistency
raise TypeError(msg)
def standard_b64encode(s):
"""Encode a string using the standard Base64 alphabet.
s is the string to encode. The encoded string is returned.
"""
return b64encode(s)
def standard_b64decode(s):
"""Decode a string encoded with the standard Base64 alphabet.
Argument s is the string to decode. The decoded string is returned. A
TypeError is raised if the string is incorrectly padded. Characters that
are not in the standard alphabet are discarded prior to the padding
check.
"""
return b64decode(s)
_urlsafe_encode_translation = string.maketrans(b'+/', b'-_')
_urlsafe_decode_translation = string.maketrans(b'-_', b'+/')
def urlsafe_b64encode(s):
"""Encode a string using the URL- and filesystem-safe Base64 alphabet.
Argument s is the string to encode. The encoded string is returned. The
alphabet uses '-' instead of '+' and '_' instead of '/'.
"""
return b64encode(s).translate(_urlsafe_encode_translation)
def urlsafe_b64decode(s):
"""Decode a string using the URL- and filesystem-safe Base64 alphabet.
Argument s is the string to decode. The decoded string is returned. A
TypeError is raised if the string is incorrectly padded. Characters that
are not in the URL-safe base-64 alphabet, and are not a plus '+' or slash
'/', are discarded prior to the padding check.
The alphabet uses '-' instead of '+' and '_' instead of '/'.
"""
return b64decode(s.translate(_urlsafe_decode_translation))
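# A short sketch comparing the standard and URL-safe alphabets, using the
# standard-library base64 module. The two input bytes are an assumption
# chosen because they encode to both '+' and '/'; the b"._" altchars value
# is likewise only illustrative.

```python
import base64

raw = b"\xfb\xff"

standard = base64.b64encode(raw)                 # uses '+' and '/'
urlsafe = base64.urlsafe_b64encode(raw)          # uses '-' and '_'
custom = base64.b64encode(raw, altchars=b"._")   # caller-chosen pair

round_trip = base64.urlsafe_b64decode(urlsafe)
```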
# Base32 encoding/decoding must be done in Python
_b32alphabet = {
0: 'A', 9: 'J', 18: 'S', 27: '3',
1: 'B', 10: 'K', 19: 'T', 28: '4',
2: 'C', 11: 'L', 20: 'U', 29: '5',
3: 'D', 12: 'M', 21: 'V', 30: '6',
4: 'E', 13: 'N', 22: 'W', 31: '7',
5: 'F', 14: 'O', 23: 'X',
6: 'G', 15: 'P', 24: 'Y',
7: 'H', 16: 'Q', 25: 'Z',
8: 'I', 17: 'R', 26: '2',
}
_b32tab = _b32alphabet.items()
_b32tab.sort()
_b32tab = [v for k, v in _b32tab]
_b32rev = dict([(v, long(k)) for k, v in _b32alphabet.items()])
def b32encode(s):
"""Encode a string using Base32.
s is the string to encode. The encoded string is returned.
"""
parts = []
quanta, leftover = divmod(len(s), 5)
# Pad the last quantum with zero bits if necessary
if leftover:
s += ('\0' * (5 - leftover))
quanta += 1
for i in range(quanta):
# c1 and c2 are 16 bits wide, c3 is 8 bits wide. The intent of this
# code is to process the 40 bits in units of 5 bits. So we take the 1
# leftover bit of c1 and tack it onto c2. Then we take the 2 leftover
# bits of c2 and tack them onto c3. The shifts and masks are intended
# to give us values of exactly 5 bits in width.
c1, c2, c3 = struct.unpack('!HHB', s[i*5:(i+1)*5])
c2 += (c1 & 1) << 16 # 17 bits wide
c3 += (c2 & 3) << 8 # 10 bits wide
parts.extend([_b32tab[c1 >> 11], # bits 1 - 5
_b32tab[(c1 >> 6) & 0x1f], # bits 6 - 10
_b32tab[(c1 >> 1) & 0x1f], # bits 11 - 15
_b32tab[c2 >> 12], # bits 16 - 20 (1 - 5)
_b32tab[(c2 >> 7) & 0x1f], # bits 21 - 25 (6 - 10)
_b32tab[(c2 >> 2) & 0x1f], # bits 26 - 30 (11 - 15)
_b32tab[c3 >> 5], # bits 31 - 35 (1 - 5)
_b32tab[c3 & 0x1f], # bits 36 - 40 (1 - 5)
])
encoded = EMPTYSTRING.join(parts)
# Adjust for any leftover partial quanta
if leftover == 1:
return encoded[:-6] + '======'
elif leftover == 2:
return encoded[:-4] + '===='
elif leftover == 3:
return encoded[:-3] + '==='
elif leftover == 4:
return encoded[:-1] + '='
return encoded
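# A small sketch of the padding rule at the end of b32encode: the number of
# '=' characters depends on how many bytes are left over in the final
# 5-byte quantum. It calls the standard-library base64 module; the sample
# input b"abcde" is an assumption.

```python
import base64

pad_counts = {}
for n in range(1, 6):
    encoded = base64.b32encode(b"abcde"[:n])
    # Count trailing '=' characters in the 8-character output block.
    pad_counts[n] = len(encoded) - len(encoded.rstrip(b"="))

full_quantum = base64.b32encode(b"abcde")
```

# Leftovers of 1, 2, 3 and 4 bytes give 6, 4, 3 and 1 pad characters,
# matching the branches in the function above; a full quantum needs none.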
def b32decode(s, casefold=False, map01=None):
"""Decode a Base32 encoded string.
s is the string to decode. Optional casefold is a flag specifying whether
a lowercase alphabet is acceptable as input. For security purposes, the
default is False.
RFC 3548 allows for optional mapping of the digit 0 (zero) to the letter O
(oh), and for optional mapping of the digit 1 (one) to either the letter I
(eye) or letter L (el). The optional argument map01 when not None,
specifies which letter the digit 1 should be mapped to (when map01 is not
None, the digit 0 is always mapped to the letter O). For security
purposes the default is None, so that 0 and 1 are not allowed in the
input.
The decoded string is returned. A TypeError is raised if s is
incorrectly padded or if there are non-alphabet characters present in the
string.
"""
quanta, leftover = divmod(len(s), 8)
if leftover:
raise TypeError('Incorrect padding')
# Handle section 2.4 zero and one mapping. The flag map01 will be either
# False, or the character to map the digit 1 (one) to. It should be
# either L (el) or I (eye).
if map01:
s = s.translate(string.maketrans(b'01', b'O' + map01))
if casefold:
s = s.upper()
# Strip off pad characters from the right. We need to count the pad
# characters because this will tell us how many null bytes to remove from
# the end of the decoded string.
padchars = 0
mo = re.search('(?P<pad>[=]*)$', s)
if mo:
padchars = len(mo.group('pad'))
if padchars > 0:
s = s[:-padchars]
# Now decode the full quanta
parts = []
acc = 0
shift = 35
for c in s:
val = _b32rev.get(c)
if val is None:
raise TypeError('Non-base32 digit found')
acc += _b32rev[c] << shift
shift -= 5
if shift < 0:
parts.append(binascii.unhexlify('%010x' % acc))
acc = 0
shift = 35
# Process the last, partial quantum
last = binascii.unhexlify('%010x' % acc)
if padchars == 0:
last = '' # No characters
elif padchars == 1:
last = last[:-1]
elif padchars == 3:
last = last[:-2]
elif padchars == 4:
last = last[:-3]
elif padchars == 6:
last = last[:-4]
else:
raise TypeError('Incorrect padding')
parts.append(last)
return EMPTYSTRING.join(parts)
# RFC 3548, Base 16 Alphabet specifies uppercase, but hexlify() returns
# lowercase. The RFC also recommends against accepting input case
# insensitively.
def b16encode(s):
"""Encode a string using Base16.
s is the string to encode. The encoded string is returned.
"""
return binascii.hexlify(s).upper()
def b16decode(s, casefold=False):
"""Decode a Base16 encoded string.
s is the string to decode. Optional casefold is a flag specifying whether
a lowercase alphabet is acceptable as input. For security purposes, the
default is False.
The decoded string is returned. A TypeError is raised if s is
incorrectly padded or if there are non-alphabet characters present in the
string.
"""
if casefold:
s = s.upper()
if re.search('[^0-9A-F]', s):
raise TypeError('Non-base16 digit found')
return binascii.unhexlify(s)
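The padding rule in b32encode() and the casefold and map01 options of b32decode() can be seen in a short round trip. This is a minimal sketch, assuming the Python 3 `base64` module rather than the Python 2 code above: it keeps the same function names but works on bytes and reports bad input with `binascii.Error` (a `ValueError` subclass) instead of `TypeError`. The helper name `pad_count` is made up for illustration.

```python
import base64

def pad_count(data):
    """Return how many '=' characters b32encode() appends for data."""
    encoded = base64.b32encode(data)
    return len(encoded) - len(encoded.rstrip(b"="))

# Input length modulo 5 decides the padding: 1 -> 6, 2 -> 4, 3 -> 3, 4 -> 1, 0 -> 0.
pads = [pad_count(b"abcde"[:n]) for n in range(1, 6)]

# Lowercase input is rejected unless casefold=True is passed.
try:
    base64.b32decode(b"me======")
    lowercase_accepted = True
except (TypeError, ValueError):
    lowercase_accepted = False
folded = base64.b32decode(b"me======", casefold=True)

# map01 maps the digit 1 to the given letter (and the digit 0 to the letter O).
mapped = base64.b32decode(b"M1======", map01=b"I")

# Base16 output is uppercase hex; lowercase input needs casefold=True to decode.
hex_text = base64.b16encode(b"\x01\xab")
hex_back = base64.b16decode(b"01ab", casefold=True)
```

Catching both `TypeError` and `ValueError` keeps the sketch indifferent to which of the two exception types a given Python version raises.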
# Legacy interface. This code could be cleaned up since I don't believe
# binascii has any line length limitations. It just doesn't seem worth it
# though.
MAXLINESIZE = 76 # Excluding the CRLF
MAXBINSIZE = (MAXLINESIZE//4)*3
def encode(input, output):
"""Encode a file."""
while True:
s = input.read(MAXBINSIZE)
if not s:
break
while len(s) < MAXBINSIZE:
ns = input.read(MAXBINSIZE-len(s))
if not ns:
break
s += ns
line = binascii.b2a_base64(s)
output.write(line)
def decode(input, output):
"""Decode a file."""
while True:
line = input.readline()
if not line:
break
s = binascii.a2b_base64(line)
output.write(s)
def encodestring(s):
"""Encode a string into multiple lines of base-64 data."""
pieces = []
for i in range(0, len(s), MAXBINSIZE):
chunk = s[i : i + MAXBINSIZE]
pieces.append(binascii.b2a_base64(chunk))
return "".join(pieces)
def decodestring(s):
"""Decode a string."""
return binascii.a2b_base64(s)
# Useable as a script...
def test():
"""Small test program"""
import sys, getopt
try:
opts, args = getopt.getopt(sys.argv[1:], 'deut')
except getopt.error, msg:
sys.stdout = sys.stderr
print msg
print """usage: %s [-d|-e|-u|-t] [file|-]
-d, -u: decode
-e: encode (default)
-t: encode and decode string 'Aladdin:open sesame'"""%sys.argv[0]
sys.exit(2)
func = encode
for o, a in opts:
if o == '-e': func = encode
if o == '-d': func = decode
if o == '-u': func = decode
if o == '-t': test1(); return
if args and args[0] != '-':
with open(args[0], 'rb') as f:
func(f, sys.stdout)
else:
func(sys.stdin, sys.stdout)
def test1():
s0 = "Aladdin:open sesame"
s1 = encodestring(s0)
s2 = decodestring(s1)
print s0, repr(s1), s2
if __name__ == '__main__':
test()
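The legacy interface above splits its input into MAXBINSIZE-byte chunks so that every full output line is exactly MAXLINESIZE characters long. The following is a small sketch of that arithmetic, assuming Python 3 bytes semantics; `encode_lines` is a hypothetical stand-in for encodestring(), not the function itself.

```python
import binascii

MAXLINESIZE = 76                      # excluding the line terminator
MAXBINSIZE = (MAXLINESIZE // 4) * 3   # 57 raw bytes per encoded line

def encode_lines(data):
    """Encode bytes as base64, one output line per MAXBINSIZE-byte chunk."""
    pieces = []
    for i in range(0, len(data), MAXBINSIZE):
        # b2a_base64 appends a newline to each encoded chunk.
        pieces.append(binascii.b2a_base64(data[i:i + MAXBINSIZE]))
    return b"".join(pieces)

data = bytes(bytearray(range(120)))   # 57 + 57 + 6 bytes
encoded = encode_lines(data)
line_lengths = [len(line) for line in encoded.splitlines()]
decoded = binascii.a2b_base64(encoded)
```

Because 57 is a multiple of 3, only the final line can carry `=` padding, which is why the joined lines decode in one call.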
"""HTTP server base class.
Note: the class in this module doesn't implement any HTTP request; see
SimpleHTTPServer for simple implementations of GET, HEAD and POST
(including CGI scripts). It does, however, optionally implement HTTP/1.1
persistent connections, as of version 0.3.
Contents:
- BaseHTTPRequestHandler: HTTP request handler base class
- test: test function
XXX To do:
- log requests even later (to capture byte count)
- log user-agent header and other interesting goodies
- send error log to separate file
"""
# See also:
#
# HTTP Working Group T. Berners-Lee
# INTERNET-DRAFT R. T. Fielding
# <draft-ietf-http-v10-spec-00.txt> H. Frystyk Nielsen
# Expires September 8, 1995 March 8, 1995
#
# URL: http://www.ics.uci.edu/pub/ietf/http/draft-ietf-http-v10-spec-00.txt
#
# and
#
# Network Working Group R. Fielding
# Request for Comments: 2616 et al
# Obsoletes: 2068 June 1999
# Category: Standards Track
#
# URL: http://www.faqs.org/rfcs/rfc2616.html
# Log files
# ---------
#
# Here's a quote from the NCSA httpd docs about log file format.
#
# | The logfile format is as follows. Each line consists of:
# |
# | host rfc931 authuser [DD/Mon/YYYY:hh:mm:ss] "request" ddd bbbb
# |
# | host: Either the DNS name or the IP number of the remote client
# | rfc931: Any information returned by identd for this person,
# | - otherwise.
# | authuser: If user sent a userid for authentication, the user name,
# | - otherwise.
# | DD: Day
# | Mon: Month (calendar name)
# | YYYY: Year
# | hh: hour (24-hour format, the machine's timezone)
# | mm: minutes
# | ss: seconds
# | request: The first line of the HTTP request as sent by the client.
# | ddd: the status code returned by the server, - if not available.
# | bbbb: the total number of bytes sent,
# | *not including the HTTP/1.0 header*, - if not available
# |
# | You can determine the name of the file accessed through request.
#
# (Actually, the latter is only true if you know the server configuration
# at the time the request was made!)
__version__ = "0.3"
__all__ = ["HTTPServer", "BaseHTTPRequestHandler"]
import sys
import time
import socket # For gethostbyaddr()
from warnings import filterwarnings, catch_warnings
with catch_warnings():
if sys.py3kwarning:
filterwarnings("ignore", ".*mimetools has been removed",
DeprecationWarning)
import mimetools
import SocketServer
# Default error message template
DEFAULT_ERROR_MESSAGE = """\
<head>
<title>Error response</title>
</head>
<body>
<h1>Error response</h1>
<p>Error code %(code)d.
<p>Message: %(message)s.
<p>Error code explanation: %(code)s = %(explain)s.
</body>
"""
DEFAULT_ERROR_CONTENT_TYPE = "text/html"
def _quote_html(html):
return html.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
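The order of the replacements in _quote_html() matters: the ampersand has to be escaped first, otherwise the ampersands introduced by the other two replacements get escaped again. A minimal sketch of both orders follows; the function names `quote_html` and `quote_html_wrong_order` are illustrative only.

```python
def quote_html(text):
    """Escape the three characters that are special in HTML text."""
    return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")

def quote_html_wrong_order(text):
    """Same replacements with '&' handled last, which double-escapes."""
    return text.replace("<", "&lt;").replace(">", "&gt;").replace("&", "&amp;")

safe = quote_html("<script>alert(1)</script> & more")
broken = quote_html_wrong_order("<b>")
```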
class HTTPServer(SocketServer.TCPServer):
allow_reuse_address = 1 # Seems to make sense in testing environment
def server_bind(self):
"""Override server_bind to store the server name."""
SocketServer.TCPServer.server_bind(self)
host, port = self.socket.getsockname()[:2]
self.server_name = socket.getfqdn(host)
self.server_port = port
class BaseHTTPRequestHandler(SocketServer.StreamRequestHandler):
"""HTTP request handler base class.
The following explanation of HTTP serves to guide you through the
code as well as to expose any misunderstandings I may have about
HTTP (so you don't need to read the code to figure out I'm wrong
:-).
HTTP (HyperText Transfer Protocol) is an extensible protocol on
top of a reliable stream transport (e.g. TCP/IP). The protocol
recognizes three parts to a request:
1. One line identifying the request type and path
2. An optional set of RFC-822-style headers
3. An optional data part
The headers and data are separated by a blank line.
The first line of the request has the form
<command> <path> <version>
where <command> is a (case-sensitive) keyword such as GET or POST,
<path> is a string containing path information for the request,
and <version> should be the string "HTTP/1.0" or "HTTP/1.1".
<path> is encoded using the URL encoding scheme (using %xx to signify
the ASCII character with hex code xx).
The specification requires that lines be separated by CRLF but,
for compatibility with the widest range of clients, recommends that
servers also handle LF. Similarly, whitespace in the request line
is treated sensibly (allowing multiple spaces between components
and allowing trailing whitespace).
Similarly, for output, lines ought to be separated by CRLF pairs
but most clients grok LF characters just fine.
If the first line of the request has the form
<command> <path>
(i.e. <version> is left out) then this is assumed to be an HTTP
0.9 request; this form has no optional headers and data part and
the reply consists of just the data.
The reply form of the HTTP 1.x protocol again has three parts:
1. One line giving the response code
2. An optional set of RFC-822-style headers
3. The data
Again, the headers and data are separated by a blank line.
The response code line has the form
<version> <responsecode> <responsestring>
where <version> is the protocol version ("HTTP/1.0" or "HTTP/1.1"),
<responsecode> is a 3-digit response code indicating success or
failure of the request, and <responsestring> is an optional
human-readable string explaining what the response code means.
This server parses the request and the headers, and then calls a
function specific to the request type (<command>). Specifically,
a request SPAM will be handled by a method do_SPAM(). If no
such method exists the server sends an error response to the
client. If it exists, it is called with no arguments:
do_SPAM()
Note that the request name is case sensitive (i.e. SPAM and spam
are different requests).
The various request details are stored in instance variables:
- client_address is the client IP address in the form (host,
port);
- command, path and version are the broken-down request line;
- headers is an instance of mimetools.Message (or a derived
class) containing the header information;
- rfile is a file object open for reading positioned at the
start of the optional input data part;
- wfile is a file object open for writing.
IT IS IMPORTANT TO ADHERE TO THE PROTOCOL FOR WRITING!
The first thing to be written must be the response line. Then
follow 0 or more header lines, then a blank line, and then the
actual data (if any). The meaning of the header lines depends on
the command executed by the server; in most cases, when data is
returned, there should be at least one header line of the form
Content-type: <type>/<subtype>
where <type> and <subtype> should be registered MIME types,
e.g. "text/html" or "text/plain".
"""
# The Python system version, truncated to its first component.
sys_version = "Python/" + sys.version.split()[0]
# The server software version. You may want to override this.
# The format is multiple whitespace-separated strings,
# where each string is of the form name[/version].
server_version = "BaseHTTP/" + __version__
# The default request version. This only affects responses up until
# the point where the request line is parsed, so it mainly decides what
# the client gets back when sending a malformed request line.
# Most web servers default to HTTP 0.9, i.e. don't send a status line.
default_request_version = "HTTP/0.9"
def parse_request(self):
"""Parse a request (internal).
The request should be stored in self.raw_requestline; the results
are in self.command, self.path, self.request_version and
self.headers.
Return True for success, False for failure; on failure, an
error is sent back.
"""
self.command = None # set in case of error on the first line
self.request_version = version = self.default_request_version
self.close_connection = 1
requestline = self.raw_requestline
requestline = requestline.rstrip('\r\n')
self.requestline = requestline
words = requestline.split()
if len(words) == 3:
command, path, version = words
if version[:5] != 'HTTP/':
self.send_error(400, "Bad request version (%r)" % version)
return False
try:
base_version_number = version.split('/', 1)[1]
version_number = base_version_number.split(".")
# RFC 2145 section 3.1 says there can be only one "." and
# - major and minor numbers MUST be treated as
# separate integers;
# - HTTP/2.4 is a lower version than HTTP/2.13, which in
# turn is lower than HTTP/12.3;
# - Leading zeros MUST be ignored by recipients.
if len(version_number) != 2:
raise ValueError
version_number = int(version_number[0]), int(version_number[1])
except (ValueError, IndexError):
self.send_error(400, "Bad request version (%r)" % version)
return False
if version_number >= (1, 1) and self.protocol_version >= "HTTP/1.1":
self.close_connection = 0
if version_number >= (2, 0):
self.send_error(505,
"Invalid HTTP Version (%s)" % base_version_number)
return False
elif len(words) == 2:
command, path = words
self.close_connection = 1
if command != 'GET':
self.send_error(400,
"Bad HTTP/0.9 request type (%r)" % command)
return False
elif not words:
return False
else:
self.send_error(400, "Bad request syntax (%r)" % requestline)
return False
self.command, self.path, self.request_version = command, path, version
# Examine the headers and look for a Connection directive
self.headers = self.MessageClass(self.rfile, 0)
conntype = self.headers.get('Connection', "")
if conntype.lower() == 'close':
self.close_connection = 1
elif (conntype.lower() == 'keep-alive' and
self.protocol_version >= "HTTP/1.1"):
self.close_connection = 0
return True
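The version handling in parse_request() follows the RFC 2145 rules quoted in its comments: major and minor numbers are separate integers, so versions must be compared as tuples and not as strings. The sketch below isolates just that step; `parse_http_version` is a hypothetical helper, and unlike parse_request() it raises ValueError instead of sending a 400 response.

```python
def parse_http_version(version):
    """Turn 'HTTP/<major>.<minor>' into a (major, minor) tuple of ints.

    Raises ValueError for anything that does not have exactly that shape.
    """
    if version[:5] != "HTTP/":
        raise ValueError("bad request version %r" % (version,))
    fields = version.split("/", 1)[1].split(".")
    if len(fields) != 2:
        raise ValueError("bad request version %r" % (version,))
    # int() ignores leading zeros, so "01" and "1" compare equal.
    return int(fields[0]), int(fields[1])

ordered = [parse_http_version(v) for v in ("HTTP/2.4", "HTTP/2.13", "HTTP/12.3")]
```

Comparing the raw strings would order "HTTP/2.13" before "HTTP/2.4", which is the mistake the tuple form avoids.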
def handle_one_request(self):
"""Handle a single HTTP request.
You normally don't need to override this method; see the class
__doc__ string for information on how to handle specific HTTP
commands such as GET and POST.
"""
try:
self.raw_requestline = self.rfile.readline(65537)
if len(self.raw_requestline) > 65536:
self.requestline = ''
self.request_version = ''
self.command = ''
self.send_error(414)
return
if not self.raw_requestline:
self.close_connection = 1
return
if not self.parse_request():
# An error code has been sent, just exit
return
mname = 'do_' + self.command
if not hasattr(self, mname):
self.send_error(501, "Unsupported method (%r)" % self.command)
return
method = getattr(self, mname)
method()
self.wfile.flush() #actually send the response if not already done.
except socket.timeout, e:
#a read or a write timed out. Discard this connection
self.log_error("Request timed out: %r", e)
self.close_connection = 1
return
def handle(self):
"""Handle multiple requests if necessary."""
self.close_connection = 1
self.handle_one_request()
while not self.close_connection:
self.handle_one_request()
def send_error(self, code, message=None):
"""Send and log an error reply.
Arguments are the error code, and a detailed message.
The detailed message defaults to the short entry matching the
response code.
This sends an error response (so it must be called before any
output has been generated), logs the error, and finally sends
a piece of HTML explaining the error to the user.
"""
try:
short, long = self.responses[code]
except KeyError:
short, long = '???', '???'
if message is None:
message = short
explain = long
self.log_error("code %d, message %s", code, message)
self.send_response(code, message)
self.send_header('Connection', 'close')
# Message body is omitted for cases described in:
# - RFC7230: 3.3. 1xx, 204(No Content), 304(Not Modified)
# - RFC7231: 6.3.6. 205(Reset Content)
content = None
if code >= 200 and code not in (204, 205, 304):
# HTML encode to prevent Cross Site Scripting attacks
# (see bug #1100201)
content = (self.error_message_format % {
'code': code,
'message': _quote_html(message),
'explain': explain
})
self.send_header("Content-Type", self.error_content_type)
self.end_headers()
if self.command != 'HEAD' and content:
self.wfile.write(content)
error_message_format = DEFAULT_ERROR_MESSAGE
error_content_type = DEFAULT_ERROR_CONTENT_TYPE
def send_response(self, code, message=None):
"""Send the response header and log the response code.
Also send two standard headers with the server software
version and the current date.
"""
self.log_request(code)
if message is None:
if code in self.responses:
message = self.responses[code][0]
else:
message = ''
if self.request_version != 'HTTP/0.9':
self.wfile.write("%s %d %s\r\n" %
(self.protocol_version, code, message))
# print (self.protocol_version, code, message)
self.send_header('Server', self.version_string())
self.send_header('Date', self.date_time_string())
def send_header(self, keyword, value):
"""Send a MIME header."""
if self.request_version != 'HTTP/0.9':
self.wfile.write("%s: %s\r\n" % (keyword, value))
if keyword.lower() == 'connection':
if value.lower() == 'close':
self.close_connection = 1
elif value.lower() == 'keep-alive':
self.close_connection = 0
def end_headers(self):
"""Send the blank line ending the MIME headers."""
if self.request_version != 'HTTP/0.9':
self.wfile.write("\r\n")
def log_request(self, code='-', size='-'):
"""Log an accepted request.
This is called by send_response().
"""
self.log_message('"%s" %s %s',
self.requestline, str(code), str(size))
def log_error(self, format, *args):
"""Log an error.
This is called when a request cannot be fulfilled. By
default it passes the message on to log_message().
Arguments are the same as for log_message().
XXX This should go to the separate error log.
"""
self.log_message(format, *args)
def log_message(self, format, *args):
"""Log an arbitrary message.
This is used by all other logging functions. Override
it if you have specific logging wishes.
The first argument, FORMAT, is a format string for the
message to be logged. If the format string contains
any % escapes requiring parameters, they should be
specified as subsequent arguments (it's just like
printf!).
The client ip address and current date/time are prefixed to every
message.
"""
sys.stderr.write("%s - - [%s] %s\n" %
(self.client_address[0],
self.log_date_time_string(),
format%args))
def version_string(self):
"""Return the server software version string."""
return self.server_version + ' ' + self.sys_version
def date_time_string(self, timestamp=None):
"""Return the current date and time formatted for a message header."""
if timestamp is None:
timestamp = time.time()
year, month, day, hh, mm, ss, wd, y, z = time.gmtime(timestamp)
s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
self.weekdayname[wd],
day, self.monthname[month], year,
hh, mm, ss)
return s
def log_date_time_string(self):
"""Return the current time formatted for logging."""
now = time.time()
year, month, day, hh, mm, ss, x, y, z = time.localtime(now)
s = "%02d/%3s/%04d %02d:%02d:%02d" % (
day, self.monthname[month], year, hh, mm, ss)
return s
weekdayname = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
monthname = [None,
'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
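The header date produced by date_time_string() is always in GMT and uses fixed English day and month names, independent of the locale. A standalone sketch of the same formatting follows; `http_date`, `WEEKDAYS` and `MONTHS` are illustrative names, not part of the class.

```python
import time

WEEKDAYS = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
MONTHS = [None,
          "Jan", "Feb", "Mar", "Apr", "May", "Jun",
          "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]

def http_date(timestamp):
    """Format a Unix timestamp as an HTTP header date in GMT."""
    year, month, day, hh, mm, ss, wd, _yday, _isdst = time.gmtime(timestamp)
    return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        WEEKDAYS[wd], day, MONTHS[month], year, hh, mm, ss)

epoch = http_date(0)
```

The leading None in `MONTHS` lets the 1-based month from time.gmtime() index the list directly.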
def address_string(self):
"""Return the client address formatted for logging.
This version looks up the fully qualified hostname using
socket.getfqdn(), which prefers a name that contains at least one dot.
"""
host, port = self.client_address[:2]
return socket.getfqdn(host)
# Essentially static class variables
# The version of the HTTP protocol we support.
# Set this to HTTP/1.1 to enable automatic keepalive
protocol_version = "HTTP/1.0"
# The Message-like class used to parse headers
MessageClass = mimetools.Message
# Table mapping response codes to messages; entries have the
# form {code: (shortmessage, longmessage)}.
# See RFC 2616.
responses = {
100: ('Continue', 'Request received, please continue'),
101: ('Switching Protocols',
'Switching to new protocol; obey Upgrade header'),
200: ('OK', 'Request fulfilled, document follows'),
201: ('Created', 'Document created, URL follows'),
202: ('Accepted',
'Request accepted, processing continues off-line'),
203: ('Non-Authoritative Information', 'Request fulfilled from cache'),
204: ('No Content', 'Request fulfilled, nothing follows'),
205: ('Reset Content', 'Clear input form for further input.'),
206: ('Partial Content', 'Partial content follows.'),
300: ('Multiple Choices',
'Object has several resources -- see URI list'),
301: ('Moved Permanently', 'Object moved permanently -- see URI list'),
302: ('Found', 'Object moved temporarily -- see URI list'),
303: ('See Other', 'Object moved -- see Method and URL list'),
304: ('Not Modified',
'Document has not changed since given time'),
305: ('Use Proxy',
'You must use proxy specified in Location to access this '
'resource.'),
307: ('Temporary Redirect',
'Object moved temporarily -- see URI list'),
400: ('Bad Request',
'Bad request syntax or unsupported method'),
401: ('Unauthorized',
'No permission -- see authorization schemes'),
402: ('Payment Required',
'No payment -- see charging schemes'),
403: ('Forbidden',
'Request forbidden -- authorization will not help'),
404: ('Not Found', 'Nothing matches the given URI'),
405: ('Method Not Allowed',
'Specified method is invalid for this resource.'),
406: ('Not Acceptable', 'URI not available in preferred format.'),
407: ('Proxy Authentication Required', 'You must authenticate with '
'this proxy before proceeding.'),
408: ('Request Timeout', 'Request timed out; try again later.'),
409: ('Conflict', 'Request conflict.'),
410: ('Gone',
'URI no longer exists and has been permanently removed.'),
411: ('Length Required', 'Client must specify Content-Length.'),
412: ('Precondition Failed', 'Precondition in headers is false.'),
413: ('Request Entity Too Large', 'Entity is too large.'),
414: ('Request-URI Too Long', 'URI is too long.'),
415: ('Unsupported Media Type', 'Entity body in unsupported format.'),
416: ('Requested Range Not Satisfiable',
'Cannot satisfy request range.'),
417: ('Expectation Failed',
'Expect condition could not be satisfied.'),
500: ('Internal Server Error', 'Server got itself in trouble'),
501: ('Not Implemented',
'Server does not support this operation'),
502: ('Bad Gateway', 'Invalid responses from another server/proxy.'),
503: ('Service Unavailable',
'The server cannot process the request due to a high load'),
504: ('Gateway Timeout',
'The gateway server did not receive a timely response'),
505: ('HTTP Version Not Supported', 'Cannot fulfill request.'),
}
def test(HandlerClass = BaseHTTPRequestHandler,
ServerClass = HTTPServer, protocol="HTTP/1.0"):
"""Test the HTTP request handler class.
This runs an HTTP server on port 8000 (or the first command line
argument).
"""
if sys.argv[1:]:
port = int(sys.argv[1])
else:
port = 8000
server_address = ('', port)
HandlerClass.protocol_version = protocol
httpd = ServerClass(server_address, HandlerClass)
sa = httpd.socket.getsockname()
print "Serving HTTP on", sa[0], "port", sa[1], "..."
httpd.serve_forever()
if __name__ == '__main__':
test()
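The do_SPAM() dispatch described in the class docstring can be tried end to end. This is a sketch under stated assumptions: it uses `http.server` and `http.client`, the Python 3 homes of these classes, rather than the Python 2 module above; `SpamHandler` and `demo` are made-up names; and it binds to an ephemeral port on 127.0.0.1.

```python
import threading
from http.client import HTTPConnection
from http.server import BaseHTTPRequestHandler, HTTPServer

class SpamHandler(BaseHTTPRequestHandler):
    def do_SPAM(self):
        """Called for requests whose method is exactly 'SPAM'."""
        body = b"spam, spam, spam\n"
        self.send_response(200)                      # response line first
        self.send_header("Content-Type", "text/plain")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()                           # blank line
        self.wfile.write(body)                       # then the data

    def log_message(self, format, *args):
        pass  # keep the demo quiet instead of logging to stderr

def demo():
    """Send one supported and one unsupported method; return status codes."""
    server = HTTPServer(("127.0.0.1", 0), SpamHandler)
    thread = threading.Thread(target=server.serve_forever)
    thread.daemon = True
    thread.start()
    statuses = {}
    try:
        for method in ("SPAM", "EGGS"):
            conn = HTTPConnection("127.0.0.1", server.server_port, timeout=5)
            conn.request(method, "/")
            response = conn.getresponse()
            response.read()
            statuses[method] = response.status
            conn.close()
    finally:
        server.shutdown()
        server.server_close()
    return statuses
```

The handler has no do_EGGS() method, so the second request exercises the "Unsupported method" error path.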
"""Bastionification utility.
A bastion (for another object -- the 'original') is an object that has
the same methods as the original but does not give access to its
instance variables. Bastions have a number of uses, but the most
obvious one is to provide code executing in restricted mode with a
safe interface to an object implemented in unrestricted mode.
The bastionification routine has an optional second argument which is
a filter function. Only those methods for which the filter method
(called with the method name as argument) returns true are accessible.
The default filter method returns true unless the method name begins
with an underscore.
There are a number of possible implementations of bastions. We use a
'lazy' approach where the bastion's __getattr__() discipline does all
the work for a particular method the first time it is used. This is
usually fastest, especially if the user doesn't call all available
methods. The retrieved methods are stored as instance variables of
the bastion, so the overhead is only incurred on the first use of each
method.
Detail: the bastion class has a __repr__() discipline which includes
the repr() of the original object. This is precomputed when the
bastion is created.
"""
from warnings import warnpy3k
warnpy3k("the Bastion module has been removed in Python 3.0", stacklevel=2)
del warnpy3k
__all__ = ["BastionClass", "Bastion"]
from types import MethodType
class BastionClass:
"""Helper class used by the Bastion() function.
You could subclass this and pass the subclass as the bastionclass
argument to the Bastion() function, as long as the constructor has
the same signature (a get() function and a name for the object).
"""
def __init__(self, get, name):
"""Constructor.
Arguments:
get - a function that gets the attribute value (by name)
name - a human-readable name for the original object
(suggestion: use repr(object))
"""
self._get_ = get
self._name_ = name
def __repr__(self):
"""Return a representation string.
This includes the name passed in to the constructor, so that
if you print the bastion during debugging, at least you have
some idea of what it is.
"""
return "<Bastion for %s>" % self._name_
def __getattr__(self, name):
"""Get an as-yet undefined attribute value.
This calls the get() function that was passed to the
constructor. The result is stored as an instance variable so
that the next time the same attribute is requested,
__getattr__() won't be invoked.
If the get() function raises an exception, this is simply
passed on -- exceptions are not cached.
"""
attribute = self._get_(name)
self.__dict__[name] = attribute
return attribute
def Bastion(object, filter = lambda name: name[:1] != '_',
name=None, bastionclass=BastionClass):
"""Create a bastion for an object, using an optional filter.
See the Bastion module's documentation for background.
Arguments:
object - the original object
filter - a predicate that decides whether a function name is OK;
by default all names are OK that don't start with '_'
name - the name of the object; default repr(object)
bastionclass - class used to create the bastion; default BastionClass
"""
raise RuntimeError, "This code is not secure in Python 2.2 and later"
# Note: we define *two* ad-hoc functions here, get1 and get2.
# Both are intended to be called in the same way: get(name).
# It is clear that the real work (getting the attribute
# from the object and calling the filter) is done in get1.
# Why can't we pass get1 to the bastion? Because the user
# would be able to override the filter argument! With get2,
# overriding the default argument is no security loophole:
# all it does is call it.
# Also notice that we can't place the object and filter as
# instance variables on the bastion object itself, since
# the user has full access to all instance variables!
def get1(name, object=object, filter=filter):
"""Internal function for Bastion(). See source comments."""
if filter(name):
attribute = getattr(object, name)
if type(attribute) == MethodType:
return attribute
raise AttributeError, name
def get2(name, get1=get1):
"""Internal function for Bastion(). See source comments."""
return get1(name)
if name is None:
name = repr(object)
return bastionclass(get2, name)
def _test():
"""Test the Bastion() function."""
class Original:
def __init__(self):
self.sum = 0
def add(self, n):
self._add(n)
def _add(self, n):
self.sum = self.sum + n
def total(self):
return self.sum
o = Original()
b = Bastion(o)
testcode = """if 1:
b.add(81)
b.add(18)
print "b.total() =", b.total()
try:
print "b.sum =", b.sum,
except:
print "inaccessible"
else:
print "accessible"
try:
print "b._add =", b._add,
except:
print "inaccessible"
else:
print "accessible"
try:
print "b._get_.func_defaults =", map(type, b._get_.func_defaults),
except:
print "inaccessible"
else:
print "accessible"
\n"""
exec testcode
print '='*20, "Using rexec:", '='*20
import rexec
r = rexec.RExec()
m = r.add_module('__main__')
m.b = b
r.r_exec(testcode)
if __name__ == '__main__':
_test()
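The lazy __getattr__() caching described in the module docstring can be illustrated on its own. This is only a sketch of the caching behaviour, assuming Python 3 and hypothetical names (`LazyMethodProxy`, `make_proxy`, `Original`); like Bastion() itself, which refuses to run, it is not a security boundary, since the wrapped object remains reachable through the cached bound methods.

```python
import types

class LazyMethodProxy(object):
    """Resolve attributes through get() once, then cache them."""
    def __init__(self, get, name):
        self._get_ = get
        self._name_ = name

    def __repr__(self):
        return "<Proxy for %s>" % self._name_

    def __getattr__(self, name):
        # Only reached when normal lookup fails, i.e. on first use.
        attribute = self._get_(name)
        self.__dict__[name] = attribute
        return attribute

def make_proxy(obj, allow=lambda name: not name.startswith("_")):
    """Return (proxy, lookups); lookups records every call to get()."""
    lookups = []
    def get(name):
        lookups.append(name)
        if allow(name):
            attribute = getattr(obj, name)
            if isinstance(attribute, types.MethodType):
                return attribute
        raise AttributeError(name)
    return LazyMethodProxy(get, repr(obj)), lookups

class Original(object):
    def __init__(self):
        self.sum = 0
    def add(self, n):
        self._add(n)
    def _add(self, n):
        self.sum = self.sum + n
    def total(self):
        return self.sum

proxy, lookups = make_proxy(Original())
proxy.add(81)
proxy.add(18)   # served from the proxy's __dict__, no second lookup
```

Failed lookups are not cached, so every access to a filtered name calls get() again.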
"""Debugger basics"""
import fnmatch
import sys
import os
import types
__all__ = ["BdbQuit","Bdb","Breakpoint"]
class BdbQuit(Exception):
"""Exception to give up completely"""
class Bdb:
"""Generic Python debugger base class.
This class takes care of details of the trace facility;
a derived class should implement user interaction.
The standard debugger class (pdb.Pdb) is an example.
"""
def __init__(self, skip=None):
self.skip = set(skip) if skip else None
self.breaks = {}
self.fncache = {}
self.frame_returning = None
def canonic(self, filename):
if filename == "<" + filename[1:-1] + ">":
return filename
canonic = self.fncache.get(filename)
if not canonic:
canonic = os.path.abspath(filename)
canonic = os.path.normcase(canonic)
self.fncache[filename] = canonic
return canonic
def reset(self):
import linecache
linecache.checkcache()
self.botframe = None
self._set_stopinfo(None, None)
def trace_dispatch(self, frame, event, arg):
if self.quitting:
return # None
if event == 'line':
return self.dispatch_line(frame)
if event == 'call':
return self.dispatch_call(frame, arg)
if event == 'return':
return self.dispatch_return(frame, arg)
if event == 'exception':
return self.dispatch_exception(frame, arg)
if event == 'c_call':
return self.trace_dispatch
if event == 'c_exception':
return self.trace_dispatch
if event == 'c_return':
return self.trace_dispatch
print 'bdb.Bdb.dispatch: unknown debugging event:', repr(event)
return self.trace_dispatch
def dispatch_line(self, frame):
if self.stop_here(frame) or self.break_here(frame):
self.user_line(frame)
if self.quitting: raise BdbQuit
return self.trace_dispatch
def dispatch_call(self, frame, arg):
# XXX 'arg' is no longer used
if self.botframe is None:
# First call of dispatch since reset()
self.botframe = frame.f_back # (CT) Note that this may also be None!
return self.trace_dispatch
if not (self.stop_here(frame) or self.break_anywhere(frame)):
# No need to trace this function
return # None
self.user_call(frame, arg)
if self.quitting: raise BdbQuit
return self.trace_dispatch
def dispatch_return(self, frame, arg):
if self.stop_here(frame) or frame == self.returnframe:
try:
self.frame_returning = frame
self.user_return(frame, arg)
finally:
self.frame_returning = None
if self.quitting: raise BdbQuit
return self.trace_dispatch
def dispatch_exception(self, frame, arg):
if self.stop_here(frame):
self.user_exception(frame, arg)
if self.quitting: raise BdbQuit
return self.trace_dispatch
# Normally derived classes don't override the following
# methods, but they may if they want to redefine the
# definition of stopping and breakpoints.
def is_skipped_module(self, module_name):
for pattern in self.skip:
if fnmatch.fnmatch(module_name, pattern):
return True
return False
def stop_here(self, frame):
# (CT) stopframe may now also be None, see dispatch_call.
# (CT) the former test for None is therefore removed from here.
if self.skip and \
self.is_skipped_module(frame.f_globals.get('__name__')):
return False
if frame is self.stopframe:
if self.stoplineno == -1:
return False
return frame.f_lineno >= self.stoplineno
while frame is not None and frame is not self.stopframe:
if frame is self.botframe:
return True
frame = frame.f_back
return False
def break_here(self, frame):
filename = self.canonic(frame.f_code.co_filename)
if not filename in self.breaks:
return False
lineno = frame.f_lineno
if not lineno in self.breaks[filename]:
# The line itself has no breakpoint, but maybe the line is the
# first line of a function with breakpoint set by function name.
lineno = frame.f_code.co_firstlineno
if not lineno in self.breaks[filename]:
return False
# flag says ok to delete temp. bp
(bp, flag) = effective(filename, lineno, frame)
if bp:
self.currentbp = bp.number
if (flag and bp.temporary):
self.do_clear(str(bp.number))
return True
else:
return False
def do_clear(self, arg):
raise NotImplementedError, "subclass of bdb must implement do_clear()"
def break_anywhere(self, frame):
return self.canonic(frame.f_code.co_filename) in self.breaks
# Derived classes should override the user_* methods
# to gain control.
def user_call(self, frame, argument_list):
"""This method is called when there is the remote possibility
that we ever need to stop in this function."""
pass
def user_line(self, frame):
"""This method is called when we stop or break at this line."""
pass
def user_return(self, frame, return_value):
"""This method is called when a return trap is set here."""
pass
def user_exception(self, frame, exc_info):
"""This method is called if an exception occurs,
but only if we are to stop at or just below this level."""
exc_type, exc_value, exc_traceback = exc_info
pass
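A derived class gains control by overriding the user_* hooks. The following minimal sketch, assuming the Python 3 `bdb` module (which keeps the same class and hook names), records every line event instead of interacting with a user; `LineRecorder` and `sample` are made-up names.

```python
import bdb

class LineRecorder(bdb.Bdb):
    """Collect (function name, line number) for each line that is stepped."""
    def __init__(self):
        bdb.Bdb.__init__(self)
        self.lines = []

    def user_line(self, frame):
        # Called by dispatch_line() whenever stop_here() or break_here() is true.
        self.lines.append((frame.f_code.co_name, frame.f_lineno))

def sample():
    x = 1
    y = x + 1
    return y

recorder = LineRecorder()
result = recorder.runcall(sample)   # traces sample() and returns its value
```

runcall() resets the stepping state before the call, so the recorder stops on every line of the traced function without any breakpoints being set.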
def _set_stopinfo(self, stopframe, returnframe, stoplineno=0):
self.stopframe = stopframe
self.returnframe = returnframe
self.quitting = 0
# stoplineno >= 0 means: stop at line >= the stoplineno
# stoplineno -1 means: don't stop at all
self.stoplineno = stoplineno
# Derived classes and clients can call the following methods
# to affect the stepping state.
def set_until(self, frame): #the name "until" is borrowed from gdb
"""Stop when a line with a number greater than the current one is
reached, or when returning from the current frame."""
self._set_stopinfo(frame, frame, frame.f_lineno+1)
def set_step(self):
"""Stop after one line of code."""
# Issue #13183: pdb skips frames after hitting a breakpoint and running
# step commands.
# Restore the trace function in the caller (that may not have been set
# for performance reasons) when returning from the current frame.
if self.frame_returning:
caller_frame = self.frame_returning.f_back
if caller_frame and not caller_frame.f_trace:
caller_frame.f_trace = self.trace_dispatch
self._set_stopinfo(None, None)
def set_next(self, frame):
"""Stop on the next line in or below the given frame."""
self._set_stopinfo(frame, None)
def set_return(self, frame):
"""Stop when returning from the given frame."""
self._set_stopinfo(frame.f_back, frame)
def set_trace(self, frame=None):
"""Start debugging from `frame`.
If frame is not specified, debugging starts from caller's frame.
"""
if frame is None:
frame = sys._getframe().f_back
self.reset()
while frame:
frame.f_trace = self.trace_dispatch
self.botframe = frame
frame = frame.f_back
self.set_step()
sys.settrace(self.trace_dispatch)
def set_continue(self):
# Don't stop except at breakpoints or when finished
self._set_stopinfo(self.botframe, None, -1)
if not self.breaks:
# no breakpoints; run without debugger overhead
sys.settrace(None)
frame = sys._getframe().f_back
while frame and frame is not self.botframe:
del frame.f_trace
frame = frame.f_back
def set_quit(self):
self.stopframe = self.botframe
self.returnframe = None
self.quitting = 1
sys.settrace(None)
# Derived classes and clients can call the following methods
# to manipulate breakpoints. These methods return an
# error message if something went wrong, None if all is well.
# Set_break prints out the breakpoint line and file:lineno.
# Call self.get_*break*() to see the breakpoints or better
# for bp in Breakpoint.bpbynumber: if bp: bp.bpprint().
def set_break(self, filename, lineno, temporary=0, cond = None,
funcname=None):
filename = self.canonic(filename)
import linecache # Import as late as possible
line = linecache.getline(filename, lineno)
if not line:
return 'Line %s:%d does not exist' % (filename,
lineno)
if not filename in self.breaks:
self.breaks[filename] = []
list = self.breaks[filename]
if not lineno in list:
list.append(lineno)
bp = Breakpoint(filename, lineno, temporary, cond, funcname)
def _prune_breaks(self, filename, lineno):
if (filename, lineno) not in Breakpoint.bplist:
self.breaks[filename].remove(lineno)
if not self.breaks[filename]:
del self.breaks[filename]
def clear_break(self, filename, lineno):
filename = self.canonic(filename)
if not filename in self.breaks:
return 'There are no breakpoints in %s' % filename
if lineno not in self.breaks[filename]:
return 'There is no breakpoint at %s:%d' % (filename,
lineno)
# If there's only one bp in the list for that file,line
# pair, then remove the breaks entry
for bp in Breakpoint.bplist[filename, lineno][:]:
bp.deleteMe()
self._prune_breaks(filename, lineno)
def clear_bpbynumber(self, arg):
try:
number = int(arg)
except:
return 'Non-numeric breakpoint number (%s)' % arg
try:
bp = Breakpoint.bpbynumber[number]
except IndexError:
return 'Breakpoint number (%d) out of range' % number
if not bp:
return 'Breakpoint (%d) already deleted' % number
bp.deleteMe()
self._prune_breaks(bp.file, bp.line)
def clear_all_file_breaks(self, filename):
filename = self.canonic(filename)
if not filename in self.breaks:
return 'There are no breakpoints in %s' % filename
for line in self.breaks[filename]:
blist = Breakpoint.bplist[filename, line]
for bp in blist:
bp.deleteMe()
del self.breaks[filename]
def clear_all_breaks(self):
if not self.breaks:
return 'There are no breakpoints'
for bp in Breakpoint.bpbynumber:
if bp:
bp.deleteMe()
self.breaks = {}
def get_break(self, filename, lineno):
filename = self.canonic(filename)
return filename in self.breaks and \
lineno in self.breaks[filename]
def get_breaks(self, filename, lineno):
filename = self.canonic(filename)
return filename in self.breaks and \
lineno in self.breaks[filename] and \
Breakpoint.bplist[filename, lineno] or []
def get_file_breaks(self, filename):
filename = self.canonic(filename)
if filename in self.breaks:
return self.breaks[filename]
else:
return []
def get_all_breaks(self):
return self.breaks
# Derived classes and clients can call the following method
# to get a data structure representing a stack trace.
def get_stack(self, f, t):
stack = []
if t and t.tb_frame is f:
t = t.tb_next
while f is not None:
stack.append((f, f.f_lineno))
if f is self.botframe:
break
f = f.f_back
stack.reverse()
i = max(0, len(stack) - 1)
while t is not None:
stack.append((t.tb_frame, t.tb_lineno))
t = t.tb_next
if f is None:
i = max(0, len(stack) - 1)
return stack, i
#
def format_stack_entry(self, frame_lineno, lprefix=': '):
import linecache, repr
frame, lineno = frame_lineno
filename = self.canonic(frame.f_code.co_filename)
s = '%s(%r)' % (filename, lineno)
if frame.f_code.co_name:
s = s + frame.f_code.co_name
else:
s = s + "<lambda>"
if '__args__' in frame.f_locals:
args = frame.f_locals['__args__']
else:
args = None
if args:
s = s + repr.repr(args)
else:
s = s + '()'
if '__return__' in frame.f_locals:
rv = frame.f_locals['__return__']
s = s + '->'
s = s + repr.repr(rv)
line = linecache.getline(filename, lineno, frame.f_globals)
if line: s = s + lprefix + line.strip()
return s
# The following two methods can be called by clients to use
# a debugger to debug a statement, given as a string.
def run(self, cmd, globals=None, locals=None):
if globals is None:
import __main__
globals = __main__.__dict__
if locals is None:
locals = globals
self.reset()
sys.settrace(self.trace_dispatch)
if not isinstance(cmd, types.CodeType):
cmd = cmd+'\n'
try:
exec cmd in globals, locals
except BdbQuit:
pass
finally:
self.quitting = 1
sys.settrace(None)
def runeval(self, expr, globals=None, locals=None):
if globals is None:
import __main__
globals = __main__.__dict__
if locals is None:
locals = globals
self.reset()
sys.settrace(self.trace_dispatch)
if not isinstance(expr, types.CodeType):
expr = expr+'\n'
try:
return eval(expr, globals, locals)
except BdbQuit:
pass
finally:
self.quitting = 1
sys.settrace(None)
def runctx(self, cmd, globals, locals):
# B/W compatibility
self.run(cmd, globals, locals)
# This method is more useful to debug a single function call.
def runcall(self, func, *args, **kwds):
self.reset()
sys.settrace(self.trace_dispatch)
res = None
try:
res = func(*args, **kwds)
except BdbQuit:
pass
finally:
self.quitting = 1
sys.settrace(None)
return res
def set_trace():
Bdb().set_trace()
class Breakpoint:
"""Breakpoint class
Implements temporary breakpoints, ignore counts, disabling and
(re)-enabling, and conditionals.
Breakpoints are indexed by number through bpbynumber and by
the file,line tuple using bplist. The former points to a
single instance of class Breakpoint. The latter points to a
list of such instances since there may be more than one
breakpoint per line.
"""
# XXX Keeping state in the class is a mistake -- this means
# you cannot have more than one active Bdb instance.
next = 1 # Next bp to be assigned
bplist = {} # indexed by (file, lineno) tuple
bpbynumber = [None] # Each entry is None or an instance of Bpt
# index 0 is unused, except for marking an
# effective break .... see effective()
def __init__(self, file, line, temporary=0, cond=None, funcname=None):
self.funcname = funcname
# Needed if funcname is not None.
self.func_first_executable_line = None
self.file = file # This better be in canonical form!
self.line = line
self.temporary = temporary
self.cond = cond
self.enabled = 1
self.ignore = 0
self.hits = 0
self.number = Breakpoint.next
Breakpoint.next = Breakpoint.next + 1
# Build the two lists
self.bpbynumber.append(self)
if (file, line) in self.bplist:
self.bplist[file, line].append(self)
else:
self.bplist[file, line] = [self]
def deleteMe(self):
index = (self.file, self.line)
self.bpbynumber[self.number] = None # No longer in list
self.bplist[index].remove(self)
if not self.bplist[index]:
# No more bp for this f:l combo
del self.bplist[index]
def enable(self):
self.enabled = 1
def disable(self):
self.enabled = 0
def bpprint(self, out=None):
if out is None:
out = sys.stdout
if self.temporary:
disp = 'del '
else:
disp = 'keep '
if self.enabled:
disp = disp + 'yes '
else:
disp = disp + 'no '
print >>out, '%-4dbreakpoint %s at %s:%d' % (self.number, disp,
self.file, self.line)
if self.cond:
print >>out, '\tstop only if %s' % (self.cond,)
if self.ignore:
print >>out, '\tignore next %d hits' % (self.ignore)
if (self.hits):
if (self.hits > 1): ss = 's'
else: ss = ''
print >>out, ('\tbreakpoint already hit %d time%s' %
(self.hits, ss))
# -----------end of Breakpoint class----------
def checkfuncname(b, frame):
"""Check whether we should break here because of `b.funcname`."""
if not b.funcname:
# Breakpoint was set via line number.
if b.line != frame.f_lineno:
# Breakpoint was set at a line with a def statement and the function
# defined is called: don't break.
return False
return True
# Breakpoint set via function name.
if frame.f_code.co_name != b.funcname:
# It's not a function call, but rather execution of def statement.
return False
# We are in the right frame.
if not b.func_first_executable_line:
# The function is entered for the 1st time.
b.func_first_executable_line = frame.f_lineno
if b.func_first_executable_line != frame.f_lineno:
# But we are not at the first line number: don't break.
return False
return True
# Determines if there is an effective (active) breakpoint at this
# line of code. Returns a (breakpoint, flag) tuple, or (None, None) if none.
def effective(file, line, frame):
"""Determine which breakpoint for this file:line is to be acted upon.
Called only if we know there is a bpt at this
location. Returns breakpoint that was triggered and a flag
that indicates if it is ok to delete a temporary bp.
"""
possibles = Breakpoint.bplist[file,line]
for i in range(0, len(possibles)):
b = possibles[i]
if b.enabled == 0:
continue
if not checkfuncname(b, frame):
continue
# Count every hit when bp is enabled
b.hits = b.hits + 1
if not b.cond:
# If unconditional, and ignoring,
# go on to next, else break
if b.ignore > 0:
b.ignore = b.ignore -1
continue
else:
# breakpoint and marker that's ok
# to delete if temporary
return (b,1)
else:
# Conditional bp.
# Ignore count applies only to those bpt hits where the
# condition evaluates to true.
try:
val = eval(b.cond, frame.f_globals,
frame.f_locals)
if val:
if b.ignore > 0:
b.ignore = b.ignore -1
# continue
else:
return (b,1)
# else:
# continue
except:
# if eval fails, most conservative
# thing is to stop on breakpoint
# regardless of ignore count.
# Don't delete temporary,
# as another hint to user.
return (b,0)
return (None, None)
# -------------------- testing --------------------
class Tdb(Bdb):
def user_call(self, frame, args):
name = frame.f_code.co_name
if not name: name = '???'
print '+++ call', name, args
def user_line(self, frame):
import linecache
name = frame.f_code.co_name
if not name: name = '???'
fn = self.canonic(frame.f_code.co_filename)
line = linecache.getline(fn, frame.f_lineno, frame.f_globals)
print '+++', fn, frame.f_lineno, name, ':', line.strip()
def user_return(self, frame, retval):
print '+++ return', retval
def user_exception(self, frame, exc_stuff):
print '+++ exception', exc_stuff
self.set_continue()
def foo(n):
print 'foo(', n, ')'
x = bar(n*10)
print 'bar returned', x
def bar(a):
print 'bar(', a, ')'
return a/2
def test():
t = Tdb()
t.run('import bdb; bdb.foo(10)')
# end
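# A minimal sketch of the "override the user_* methods" pattern described
# above, written so that it also runs on Python 3. LineLogger, add and
# add_lines are hypothetical names used only for illustration; the sketch
# assumes the standard bdb module is importable and relies on runcall()
# starting in step mode, so user_line() fires for every traced line.

```python
import bdb


class LineLogger(bdb.Bdb):
    """Hypothetical Bdb subclass that records each line it stops at."""

    def __init__(self):
        bdb.Bdb.__init__(self)
        self.lines = []

    def user_line(self, frame):
        # Record the function name and line number instead of prompting.
        self.lines.append((frame.f_code.co_name, frame.f_lineno))


def add(a, b):
    total = a + b
    return total


logger = LineLogger()
result = logger.runcall(add, 2, 3)

# Keep only the stops inside add(); other frames are ignored on purpose.
add_lines = [lineno for name, lineno in logger.lines if name == "add"]
print("result=%r, stops inside add(): %r" % (result, add_lines))
```

# Collecting the stops in a list rather than printing them keeps the
# subclass easy to inspect afterwards.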
"""Macintosh binhex compression/decompression.
easy interface:
binhex(inputfilename, outputfilename)
hexbin(inputfilename, outputfilename)
"""
#
# Jack Jansen, CWI, August 1995.
#
# The module is supposed to be as compatible as possible. Especially the
# easy interface should work "as expected" on any platform.
# XXXX Note: currently, textfiles appear in mac-form on all platforms.
# We seem to lack a simple character-translate in python.
# (we should probably use ISO-Latin-1 on all but the mac platform).
# XXXX The simple routines are too simple: they expect to hold the complete
# files in-core. Should be fixed.
# XXXX It would be nice to handle AppleDouble format on unix
# (for servers serving macs).
# XXXX I don't understand what happens when you get 0x90 times the same byte on
# input. The resulting code (xx 90 90) would appear to be interpreted as an
# escaped *value* of 0x90. All coders I've seen appear to ignore this nicety...
#
import sys
import os
import struct
import binascii
__all__ = ["binhex","hexbin","Error"]
class Error(Exception):
pass
# States (what have we written)
_DID_HEADER = 0
_DID_DATA = 1
# Various constants
REASONABLY_LARGE=32768 # Minimal amount we pass the rle-coder
LINELEN=64
RUNCHAR=chr(0x90) # run-length introducer
#
# This code is no longer byte-order dependent
#
# Workarounds for non-mac machines.
try:
from Carbon.File import FSSpec, FInfo
from MacOS import openrf
def getfileinfo(name):
finfo = FSSpec(name).FSpGetFInfo()
dir, file = os.path.split(name)
# XXX Get resource/data sizes
fp = open(name, 'rb')
fp.seek(0, 2)
dlen = fp.tell()
fp = openrf(name, '*rb')
fp.seek(0, 2)
rlen = fp.tell()
return file, finfo, dlen, rlen
def openrsrc(name, *mode):
if not mode:
mode = '*rb'
else:
mode = '*' + mode[0]
return openrf(name, mode)
except ImportError:
#
# Glue code for non-macintosh usage
#
class FInfo:
def __init__(self):
self.Type = '????'
self.Creator = '????'
self.Flags = 0
def getfileinfo(name):
finfo = FInfo()
# Quick check for textfile
fp = open(name)
data = fp.read(256)
for c in data:
if not c.isspace() and (c<' ' or ord(c) > 0x7f):
break
else:
finfo.Type = 'TEXT'
fp.seek(0, 2)
dsize = fp.tell()
fp.close()
dir, file = os.path.split(name)
file = file.replace(':', '-', 1)
return file, finfo, dsize, 0
class openrsrc:
def __init__(self, *args):
pass
def read(self, *args):
return ''
def write(self, *args):
pass
def close(self):
pass
class _Hqxcoderengine:
"""Write data to the coder in 3-byte chunks"""
def __init__(self, ofp):
self.ofp = ofp
self.data = ''
self.hqxdata = ''
self.linelen = LINELEN-1
def write(self, data):
self.data = self.data + data
datalen = len(self.data)
todo = (datalen//3)*3
data = self.data[:todo]
self.data = self.data[todo:]
if not data:
return
self.hqxdata = self.hqxdata + binascii.b2a_hqx(data)
self._flush(0)
def _flush(self, force):
first = 0
while first <= len(self.hqxdata)-self.linelen:
last = first + self.linelen
self.ofp.write(self.hqxdata[first:last]+'\n')
self.linelen = LINELEN
first = last
self.hqxdata = self.hqxdata[first:]
if force:
self.ofp.write(self.hqxdata + ':\n')
def close(self):
if self.data:
self.hqxdata = \
self.hqxdata + binascii.b2a_hqx(self.data)
self._flush(1)
self.ofp.close()
del self.ofp
class _Rlecoderengine:
"""Write data to the RLE-coder in suitably large chunks"""
def __init__(self, ofp):
self.ofp = ofp
self.data = ''
def write(self, data):
self.data = self.data + data
if len(self.data) < REASONABLY_LARGE:
return
rledata = binascii.rlecode_hqx(self.data)
self.ofp.write(rledata)
self.data = ''
def close(self):
if self.data:
rledata = binascii.rlecode_hqx(self.data)
self.ofp.write(rledata)
self.ofp.close()
del self.ofp
class BinHex:
def __init__(self, name_finfo_dlen_rlen, ofp):
name, finfo, dlen, rlen = name_finfo_dlen_rlen
if type(ofp) == type(''):
ofname = ofp
ofp = open(ofname, 'w')
ofp.write('(This file must be converted with BinHex 4.0)\n\n:')
hqxer = _Hqxcoderengine(ofp)
self.ofp = _Rlecoderengine(hqxer)
self.crc = 0
if finfo is None:
finfo = FInfo()
self.dlen = dlen
self.rlen = rlen
self._writeinfo(name, finfo)
self.state = _DID_HEADER
def _writeinfo(self, name, finfo):
nl = len(name)
if nl > 63:
raise Error, 'Filename too long'
d = chr(nl) + name + '\0'
d2 = finfo.Type + finfo.Creator
# Force all structs to be packed with big-endian
d3 = struct.pack('>h', finfo.Flags)
d4 = struct.pack('>ii', self.dlen, self.rlen)
info = d + d2 + d3 + d4
self._write(info)
self._writecrc()
def _write(self, data):
self.crc = binascii.crc_hqx(data, self.crc)
self.ofp.write(data)
def _writecrc(self):
# XXXX Should this be here??
# self.crc = binascii.crc_hqx('\0\0', self.crc)
if self.crc < 0:
fmt = '>h'
else:
fmt = '>H'
self.ofp.write(struct.pack(fmt, self.crc))
self.crc = 0
def write(self, data):
if self.state != _DID_HEADER:
raise Error, 'Writing data at the wrong time'
self.dlen = self.dlen - len(data)
self._write(data)
def close_data(self):
if self.dlen != 0:
raise Error, 'Incorrect data size, diff=%r' % (self.dlen,)
self._writecrc()
self.state = _DID_DATA
def write_rsrc(self, data):
if self.state < _DID_DATA:
self.close_data()
if self.state != _DID_DATA:
raise Error, 'Writing resource data at the wrong time'
self.rlen = self.rlen - len(data)
self._write(data)
def close(self):
if self.state is None:
return
try:
if self.state < _DID_DATA:
self.close_data()
if self.state != _DID_DATA:
raise Error, 'Close at the wrong time'
if self.rlen != 0:
raise Error, \
"Incorrect resource-datasize, diff=%r" % (self.rlen,)
self._writecrc()
finally:
self.state = None
ofp = self.ofp
del self.ofp
ofp.close()
def binhex(inp, out):
"""(infilename, outfilename) - Create binhex-encoded copy of a file"""
finfo = getfileinfo(inp)
ofp = BinHex(finfo, out)
ifp = open(inp, 'rb')
# XXXX Do textfile translation on non-mac systems
while 1:
d = ifp.read(128000)
if not d: break
ofp.write(d)
ofp.close_data()
ifp.close()
ifp = openrsrc(inp, 'rb')
while 1:
d = ifp.read(128000)
if not d: break
ofp.write_rsrc(d)
ofp.close()
ifp.close()
class _Hqxdecoderengine:
"""Read data via the decoder in 4-byte chunks"""
def __init__(self, ifp):
self.ifp = ifp
self.eof = 0
def read(self, totalwtd):
"""Read at least totalwtd bytes (or until EOF)"""
decdata = ''
wtd = totalwtd
#
# The loop here is convoluted, since we don't really know how
# much to decode: there may be newlines in the incoming data.
while wtd > 0:
if self.eof: return decdata
wtd = ((wtd+2)//3)*4
data = self.ifp.read(wtd)
#
# Next problem: there may not be a complete number of
# bytes in what we pass to a2b. Solve by yet another
# loop.
#
while 1:
try:
decdatacur, self.eof = \
binascii.a2b_hqx(data)
break
except binascii.Incomplete:
pass
newdata = self.ifp.read(1)
if not newdata:
raise Error, \
'Premature EOF on binhex file'
data = data + newdata
decdata = decdata + decdatacur
wtd = totalwtd - len(decdata)
if not decdata and not self.eof:
raise Error, 'Premature EOF on binhex file'
return decdata
def close(self):
self.ifp.close()
class _Rledecoderengine:
"""Read data via the RLE-coder"""
def __init__(self, ifp):
self.ifp = ifp
self.pre_buffer = ''
self.post_buffer = ''
self.eof = 0
def read(self, wtd):
if wtd > len(self.post_buffer):
self._fill(wtd-len(self.post_buffer))
rv = self.post_buffer[:wtd]
self.post_buffer = self.post_buffer[wtd:]
return rv
def _fill(self, wtd):
self.pre_buffer = self.pre_buffer + self.ifp.read(wtd+4)
if self.ifp.eof:
self.post_buffer = self.post_buffer + \
binascii.rledecode_hqx(self.pre_buffer)
self.pre_buffer = ''
return
#
# Obfuscated code ahead. We have to take care that we don't
# end up with an orphaned RUNCHAR later on. So, we keep a couple
# of bytes in the buffer, depending on what the end of
# the buffer looks like:
# '\220\0\220' - Keep 3 bytes: repeated \220 (escaped as \220\0)
# '?\220' - Keep 2 bytes: repeated something-else
# '\220\0' - Escaped \220: Keep 2 bytes.
# '?\220?' - Complete repeat sequence: decode all
# otherwise: keep 1 byte.
#
mark = len(self.pre_buffer)
if self.pre_buffer[-3:] == RUNCHAR + '\0' + RUNCHAR:
mark = mark - 3
elif self.pre_buffer[-1] == RUNCHAR:
mark = mark - 2
elif self.pre_buffer[-2:] == RUNCHAR + '\0':
mark = mark - 2
elif self.pre_buffer[-2] == RUNCHAR:
pass # Decode all
else:
mark = mark - 1
self.post_buffer = self.post_buffer + \
binascii.rledecode_hqx(self.pre_buffer[:mark])
self.pre_buffer = self.pre_buffer[mark:]
def close(self):
self.ifp.close()
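# The run-length scheme that _Rledecoderengine guards against can be
# sketched in pure Python. This is an illustrative decoder, not the
# binascii implementation: rle_decode is a hypothetical helper, and it
# assumes a marker sequence "byte, 0x90, n" expands to n copies of byte in
# total, while "0x90, 0x00" stands for a literal 0x90.

```python
RUNCHAR = 0x90  # run-length introducer, as in the module above


def rle_decode(data):
    """Expand BinHex-style run-length encoded bytes (illustrative sketch)."""
    src = bytearray(data)
    out = bytearray()
    i = 0
    while i < len(src):
        byte = src[i]
        if byte != RUNCHAR:
            out.append(byte)
            i += 1
            continue
        if i + 1 >= len(src):
            # This is the "orphaned RUNCHAR" case the buffering code avoids.
            raise ValueError("orphaned RUNCHAR at end of input")
        count = src[i + 1]
        if count == 0:
            out.append(RUNCHAR)  # escaped literal 0x90
        else:
            if not out:
                raise ValueError("run marker with nothing to repeat")
            # One copy is already in the output; add the remaining ones.
            out.extend(out[-1:] * (count - 1))
        i += 2
    return bytes(out)


decoded_run = rle_decode(b"a\x90\x04b")
decoded_escape = rle_decode(b"\x90\x00")
decoded_escaped_run = rle_decode(b"\x90\x00\x90\x03")
```

# Splitting the input between a marker and its count is what raises the
# ValueError here, which is why _fill() holds back the last few bytes.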
class HexBin:
def __init__(self, ifp):
if type(ifp) == type(''):
ifp = open(ifp)
#
# Find initial colon.
#
while 1:
ch = ifp.read(1)
if not ch:
raise Error, "No binhex data found"
# Cater for \r\n terminated lines (which show up as \n\r, hence
# all lines start with \r)
if ch == '\r':
continue
if ch == ':':
break
if ch != '\n':
dummy = ifp.readline()
hqxifp = _Hqxdecoderengine(ifp)
self.ifp = _Rledecoderengine(hqxifp)
self.crc = 0
self._readheader()
def _read(self, len):
data = self.ifp.read(len)
self.crc = binascii.crc_hqx(data, self.crc)
return data
def _checkcrc(self):
filecrc = struct.unpack('>h', self.ifp.read(2))[0] & 0xffff
#self.crc = binascii.crc_hqx('\0\0', self.crc)
# XXXX Is this needed??
self.crc = self.crc & 0xffff
if filecrc != self.crc:
raise Error, 'CRC error, computed %x, read %x' \
%(self.crc, filecrc)
self.crc = 0
def _readheader(self):
len = self._read(1)
fname = self._read(ord(len))
rest = self._read(1+4+4+2+4+4)
self._checkcrc()
type = rest[1:5]
creator = rest[5:9]
flags = struct.unpack('>h', rest[9:11])[0]
self.dlen = struct.unpack('>l', rest[11:15])[0]
self.rlen = struct.unpack('>l', rest[15:19])[0]
self.FName = fname
self.FInfo = FInfo()
self.FInfo.Creator = creator
self.FInfo.Type = type
self.FInfo.Flags = flags
self.state = _DID_HEADER
def read(self, *n):
if self.state != _DID_HEADER:
raise Error, 'Read data at wrong time'
if n:
n = n[0]
n = min(n, self.dlen)
else:
n = self.dlen
rv = ''
while len(rv) < n:
rv = rv + self._read(n-len(rv))
self.dlen = self.dlen - n
return rv
def close_data(self):
if self.state != _DID_HEADER:
raise Error, 'close_data at wrong time'
if self.dlen:
dummy = self._read(self.dlen)
self._checkcrc()
self.state = _DID_DATA
def read_rsrc(self, *n):
if self.state == _DID_HEADER:
self.close_data()
if self.state != _DID_DATA:
raise Error, 'Read resource data at wrong time'
if n:
n = n[0]
n = min(n, self.rlen)
else:
n = self.rlen
self.rlen = self.rlen - n
return self._read(n)
def close(self):
if self.state is None:
return
try:
if self.rlen:
dummy = self.read_rsrc(self.rlen)
self._checkcrc()
finally:
self.state = None
self.ifp.close()
def hexbin(inp, out):
"""(infilename, outfilename) - Decode binhexed file"""
ifp = HexBin(inp)
finfo = ifp.FInfo
if not out:
out = ifp.FName
ofp = open(out, 'wb')
# XXXX Do translation on non-mac systems
while 1:
d = ifp.read(128000)
if not d: break
ofp.write(d)
ofp.close()
ifp.close_data()
d = ifp.read_rsrc(128000)
if d:
ofp = openrsrc(out, 'wb')
ofp.write(d)
while 1:
d = ifp.read_rsrc(128000)
if not d: break
ofp.write(d)
ofp.close()
ifp.close()
def _test():
fname = sys.argv[1]
binhex(fname, fname+'.hqx')
hexbin(fname+'.hqx', fname+'.viahqx')
#hexbin(fname, fname+'.unpacked')
sys.exit(1)
if __name__ == '__main__':
_test()
"""Bisection algorithms."""
def insort_right(a, x, lo=0, hi=None):
"""Insert item x in list a, and keep it sorted assuming a is sorted.
If x is already in a, insert it to the right of the rightmost x.
Optional args lo (default 0) and hi (default len(a)) bound the
slice of a to be searched.
"""
if lo < 0:
raise ValueError('lo must be non-negative')
if hi is None:
hi = len(a)
while lo < hi:
mid = (lo+hi)//2
if x < a[mid]: hi = mid
else: lo = mid+1
a.insert(lo, x)
insort = insort_right # backward compatibility
def bisect_right(a, x, lo=0, hi=None):
"""Return the index where to insert item x in list a, assuming a is sorted.
The return value i is such that all e in a[:i] have e <= x, and all e in
a[i:] have e > x. So if x already appears in the list, a.insert(i, x) will
insert just after the rightmost x already there.
Optional args lo (default 0) and hi (default len(a)) bound the
slice of a to be searched.
"""
if lo < 0:
raise ValueError('lo must be non-negative')
if hi is None:
hi = len(a)
while lo < hi:
mid = (lo+hi)//2
if x < a[mid]: hi = mid
else: lo = mid+1
return lo
bisect = bisect_right # backward compatibility
def insort_left(a, x, lo=0, hi=None):
"""Insert item x in list a, and keep it sorted assuming a is sorted.
If x is already in a, insert it to the left of the leftmost x.
Optional args lo (default 0) and hi (default len(a)) bound the
slice of a to be searched.
"""
if lo < 0:
raise ValueError('lo must be non-negative')
if hi is None:
hi = len(a)
while lo < hi:
mid = (lo+hi)//2
if a[mid] < x: lo = mid+1
else: hi = mid
a.insert(lo, x)
def bisect_left(a, x, lo=0, hi=None):
"""Return the index where to insert item x in list a, assuming a is sorted.
The return value i is such that all e in a[:i] have e < x, and all e in
a[i:] have e >= x. So if x already appears in the list, a.insert(i, x) will
insert just before the leftmost x already there.
Optional args lo (default 0) and hi (default len(a)) bound the
slice of a to be searched.
"""
if lo < 0:
raise ValueError('lo must be non-negative')
if hi is None:
hi = len(a)
while lo < hi:
mid = (lo+hi)//2
if a[mid] < x: lo = mid+1
else: hi = mid
return lo
# Overwrite above definitions with a fast C implementation
try:
from _bisect import *
except ImportError:
pass
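# A short usage sketch for the functions above, assuming the standard
# bisect module is importable. The list "scores" and the result names are
# made up for illustration. It shows how the left and right variants
# differ when the searched value is already present.

```python
import bisect

scores = [10, 20, 20, 30]

# Index before the leftmost 20 and index after the rightmost 20.
left = bisect.bisect_left(scores, 20)
right = bisect.bisect_right(scores, 20)

# right - left is the number of items equal to 20.
count_of_20 = right - left

# insort keeps the list sorted while inserting.
bisect.insort(scores, 25)
print(left, right, count_of_20, scores)
```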
"""Calendar printing functions
Note when comparing these calendars to the ones printed by cal(1): By
default, these calendars have Monday as the first day of the week, and
Sunday as the last (the European convention). Use setfirstweekday() to
set the first day of the week (0=Monday, 6=Sunday)."""
import sys
import datetime
import locale as _locale
__all__ = ["IllegalMonthError", "IllegalWeekdayError", "setfirstweekday",
"firstweekday", "isleap", "leapdays", "weekday", "monthrange",
"monthcalendar", "prmonth", "month", "prcal", "calendar",
"timegm", "month_name", "month_abbr", "day_name", "day_abbr"]
# Exception raised for bad input (with string parameter for details)
error = ValueError
# Exceptions raised for bad input
class IllegalMonthError(ValueError):
def __init__(self, month):
self.month = month
def __str__(self):
return "bad month number %r; must be 1-12" % self.month
class IllegalWeekdayError(ValueError):
def __init__(self, weekday):
self.weekday = weekday
def __str__(self):
return "bad weekday number %r; must be 0 (Monday) to 6 (Sunday)" % self.weekday
# Constants for months referenced later
January = 1
February = 2
# Number of days per month (except for February in leap years)
mdays = [0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
# This module used to have hard-coded lists of day and month names, as
# English strings. The classes following emulate a read-only version of
# that, but supply localized names. Note that the values are computed
# fresh on each call, in case the user changes locale between calls.
class _localized_month:
_months = [datetime.date(2001, i+1, 1).strftime for i in range(12)]
_months.insert(0, lambda x: "")
def __init__(self, format):
self.format = format
def __getitem__(self, i):
funcs = self._months[i]
if isinstance(i, slice):
return [f(self.format) for f in funcs]
else:
return funcs(self.format)
def __len__(self):
return 13
class _localized_day:
# January 1, 2001, was a Monday.
_days = [datetime.date(2001, 1, i+1).strftime for i in range(7)]
def __init__(self, format):
self.format = format
def __getitem__(self, i):
funcs = self._days[i]
if isinstance(i, slice):
return [f(self.format) for f in funcs]
else:
return funcs(self.format)
def __len__(self):
return 7
# Full and abbreviated names of weekdays
day_name = _localized_day('%A')
day_abbr = _localized_day('%a')
# Full and abbreviated names of months (1-based arrays!!!)
month_name = _localized_month('%B')
month_abbr = _localized_month('%b')
# Constants for weekdays
(MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, SATURDAY, SUNDAY) = range(7)
def isleap(year):
"""Return True for leap years, False for non-leap years."""
return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
def leapdays(y1, y2):
"""Return number of leap years in range [y1, y2).
Assume y1 <= y2."""
y1 -= 1
y2 -= 1
return (y2//4 - y1//4) - (y2//100 - y1//100) + (y2//400 - y1//400)
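# A quick cross-check of the closed-form count in leapdays() against a
# brute-force loop over isleap(), assuming the standard calendar module is
# importable. brute_force_leapdays is a hypothetical helper used only here.

```python
import calendar


def brute_force_leapdays(y1, y2):
    """Count leap years in [y1, y2) one year at a time."""
    return sum(1 for year in range(y1, y2) if calendar.isleap(year))


# 2000, 2004, 2008, 2012, 2016 and 2020 fall in [2000, 2021).
fast = calendar.leapdays(2000, 2021)
slow = brute_force_leapdays(2000, 2021)

# 1900 is divisible by 100 but not by 400, so it is not a leap year.
century = calendar.leapdays(1900, 1901)
print(fast, slow, century)
```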
def weekday(year, month, day):
"""Return weekday (0-6 ~ Mon-Sun) for year (datetime.MINYEAR-datetime.MAXYEAR), month (1-12),
day (1-31)."""
return datetime.date(year, month, day).weekday()
def monthrange(year, month):
"""Return weekday (0-6 ~ Mon-Sun) and number of days (28-31) for
year, month."""
if not 1 <= month <= 12:
raise IllegalMonthError(month)
day1 = weekday(year, month, 1)
ndays = mdays[month] + (month == February and isleap(year))
return day1, ndays
class Calendar(object):
"""
Base calendar class. This class doesn't do any formatting. It simply
provides data to subclasses.
"""
def __init__(self, firstweekday=0):
self.firstweekday = firstweekday # 0 = Monday, 6 = Sunday
def getfirstweekday(self):
return self._firstweekday % 7
def setfirstweekday(self, firstweekday):
self._firstweekday = firstweekday
firstweekday = property(getfirstweekday, setfirstweekday)
def iterweekdays(self):
"""
Return an iterator for one week of weekday numbers starting with the
configured first one.
"""
for i in range(self.firstweekday, self.firstweekday + 7):
yield i%7
def itermonthdates(self, year, month):
"""
Return an iterator for one month. The iterator will yield datetime.date
values and will always iterate through complete weeks, so it will yield
dates outside the specified month.
"""
date = datetime.date(year, month, 1)
# Go back to the beginning of the week
days = (date.weekday() - self.firstweekday) % 7
date -= datetime.timedelta(days=days)
oneday = datetime.timedelta(days=1)
while True:
yield date
try:
date += oneday
except OverflowError:
# Adding one day could fail after datetime.MAXYEAR
break
if date.month != month and date.weekday() == self.firstweekday:
break
def itermonthdays2(self, year, month):
"""
Like itermonthdates(), but will yield (day number, weekday number)
tuples. For days outside the specified month the day number is 0.
"""
for i, d in enumerate(self.itermonthdays(year, month), self.firstweekday):
yield d, i % 7
def itermonthdays(self, year, month):
"""
Like itermonthdates(), but will yield day numbers. For days outside
the specified month the day number is 0.
"""
day1, ndays = monthrange(year, month)
days_before = (day1 - self.firstweekday) % 7
for _ in range(days_before):
yield 0
for d in range(1, ndays + 1):
yield d
days_after = (self.firstweekday - day1 - ndays) % 7
for _ in range(days_after):
yield 0
def monthdatescalendar(self, year, month):
"""
Return a matrix (list of lists) representing a month's calendar.
Each row represents a week; week entries are datetime.date values.
"""
dates = list(self.itermonthdates(year, month))
return [ dates[i:i+7] for i in range(0, len(dates), 7) ]
def monthdays2calendar(self, year, month):
"""
Return a matrix representing a month's calendar.
Each row represents a week; week entries are
(day number, weekday number) tuples. Day numbers outside this month
are zero.
"""
days = list(self.itermonthdays2(year, month))
return [ days[i:i+7] for i in range(0, len(days), 7) ]
def monthdayscalendar(self, year, month):
"""
Return a matrix representing a month's calendar.
Each row represents a week; days outside this month are zero.
"""
days = list(self.itermonthdays(year, month))
return [ days[i:i+7] for i in range(0, len(days), 7) ]
def yeardatescalendar(self, year, width=3):
"""
Return the data for the specified year ready for formatting. The return
value is a list of month rows. Each month row contains up to width months.
Each month contains between 4 and 6 weeks and each week contains 1-7
days. Days are datetime.date objects.
"""
months = [
self.monthdatescalendar(year, i)
for i in range(January, January+12)
]
return [months[i:i+width] for i in range(0, len(months), width) ]
def yeardays2calendar(self, year, width=3):
"""
Return the data for the specified year ready for formatting (similar to
yeardatescalendar()). Entries in the week lists are
(day number, weekday number) tuples. Day numbers outside this month are
zero.
"""
months = [
self.monthdays2calendar(year, i)
for i in range(January, January+12)
]
return [months[i:i+width] for i in range(0, len(months), width) ]
def yeardayscalendar(self, year, width=3):
"""
Return the data for the specified year ready for formatting (similar to
yeardatescalendar()). Entries in the week lists are day numbers.
Day numbers outside this month are zero.
"""
months = [
self.monthdayscalendar(year, i)
for i in range(January, January+12)
]
return [months[i:i+width] for i in range(0, len(months), width) ]
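# Illustration only (not part of this module): a short usage sketch of the
# year*calendar() methods above, showing how the width argument groups the
# twelve months into rows. The year and width are arbitrary example values.

```python
import calendar

cal = calendar.Calendar(firstweekday=0)

# Twelve months grouped five per row: two full rows and one short row.
rows = cal.yeardayscalendar(2021, width=5)
months_per_row = [len(row) for row in rows]

# February 2021 starts on a Monday and has 28 days, so with Monday as the
# first weekday it fills exactly four weeks with no zero padding.
february = rows[0][1]
```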
class TextCalendar(Calendar):
"""
Subclass of Calendar that outputs a calendar as a simple plain text
similar to the UNIX program cal.
"""
def prweek(self, theweek, width):
"""
Print a single week (no newline).
"""
print self.formatweek(theweek, width),
def formatday(self, day, weekday, width):
"""
Returns a formatted day.
"""
if day == 0:
s = ''
else:
s = '%2i' % day # right-align single-digit days
return s.center(width)
def formatweek(self, theweek, width):
"""
Returns a single week in a string (no newline).
"""
return ' '.join(self.formatday(d, wd, width) for (d, wd) in theweek)
def formatweekday(self, day, width):
"""
Returns a formatted week day name.
"""
if width >= 9:
names = day_name
else:
names = day_abbr
return names[day][:width].center(width)
def formatweekheader(self, width):
"""
Return a header for a week.
"""
return ' '.join(self.formatweekday(i, width) for i in self.iterweekdays())
def formatmonthname(self, theyear, themonth, width, withyear=True):
"""
Return a formatted month name.
"""
s = month_name[themonth]
if withyear:
s = "%s %r" % (s, theyear)
return s.center(width)
def prmonth(self, theyear, themonth, w=0, l=0):
"""
Print a month's calendar.
"""
print self.formatmonth(theyear, themonth, w, l),
def formatmonth(self, theyear, themonth, w=0, l=0):
"""
Return a month's calendar string (multi-line).
"""
w = max(2, w)
l = max(1, l)
s = self.formatmonthname(theyear, themonth, 7 * (w + 1) - 1)
s = s.rstrip()
s += '\n' * l
s += self.formatweekheader(w).rstrip()
s += '\n' * l
for week in self.monthdays2calendar(theyear, themonth):
s += self.formatweek(week, w).rstrip()
s += '\n' * l
return s
def formatyear(self, theyear, w=2, l=1, c=6, m=3):
"""
Returns a year's calendar as a multi-line string.
"""
w = max(2, w)
l = max(1, l)
c = max(2, c)
colwidth = (w + 1) * 7 - 1
v = []
a = v.append
a(repr(theyear).center(colwidth*m+c*(m-1)).rstrip())
a('\n'*l)
header = self.formatweekheader(w)
for (i, row) in enumerate(self.yeardays2calendar(theyear, m)):
# months in this row
months = range(m*i+1, min(m*(i+1)+1, 13))
a('\n'*l)
names = (self.formatmonthname(theyear, k, colwidth, False)
for k in months)
a(formatstring(names, colwidth, c).rstrip())
a('\n'*l)
headers = (header for k in months)
a(formatstring(headers, colwidth, c).rstrip())
a('\n'*l)
# max number of weeks for this row
height = max(len(cal) for cal in row)
for j in range(height):
weeks = []
for cal in row:
if j >= len(cal):
weeks.append('')
else:
weeks.append(self.formatweek(cal[j], w))
a(formatstring(weeks, colwidth, c).rstrip())
a('\n' * l)
return ''.join(v)
def pryear(self, theyear, w=0, l=0, c=6, m=3):
"""Print a year's calendar."""
print self.formatyear(theyear, w, l, c, m)
class HTMLCalendar(Calendar):
"""
This calendar returns complete HTML pages.
"""
# CSS classes for the day <td>s
cssclasses = ["mon", "tue", "wed", "thu", "fri", "sat", "sun"]
def formatday(self, day, weekday):
"""
Return a day as a table cell.
"""
if day == 0:
return '<td class="noday"> </td>' # day outside month
else:
return '<td class="%s">%d</td>' % (self.cssclasses[weekday], day)
def formatweek(self, theweek):
"""
Return a complete week as a table row.
"""
s = ''.join(self.formatday(d, wd) for (d, wd) in theweek)
return '<tr>%s</tr>' % s
def formatweekday(self, day):
"""
Return a weekday name as a table header.
"""
return '<th class="%s">%s</th>' % (self.cssclasses[day], day_abbr[day])
def formatweekheader(self):
"""
Return a header for a week as a table row.
"""
s = ''.join(self.formatweekday(i) for i in self.iterweekdays())
return '<tr>%s</tr>' % s
def formatmonthname(self, theyear, themonth, withyear=True):
"""
Return a month name as a table row.
"""
if withyear:
s = '%s %s' % (month_name[themonth], theyear)
else:
s = '%s' % month_name[themonth]
return '<tr><th colspan="7" class="month">%s</th></tr>' % s
def formatmonth(self, theyear, themonth, withyear=True):
"""
Return a formatted month as a table.
"""
v = []
a = v.append
a('<table border="0" cellpadding="0" cellspacing="0" class="month">')
a('\n')
a(self.formatmonthname(theyear, themonth, withyear=withyear))
a('\n')
a(self.formatweekheader())
a('\n')
for week in self.monthdays2calendar(theyear, themonth):
a(self.formatweek(week))
a('\n')
a('</table>')
a('\n')
return ''.join(v)
def formatyear(self, theyear, width=3):
"""
Return a formatted year as a table of tables.
"""
v = []
a = v.append
width = max(width, 1)
a('<table border="0" cellpadding="0" cellspacing="0" class="year">')
a('\n')
a('<tr><th colspan="%d" class="year">%s</th></tr>' % (width, theyear))
for i in range(January, January+12, width):
# months in this row
months = range(i, min(i+width, 13))
a('<tr>')
for m in months:
a('<td>')
a(self.formatmonth(theyear, m, withyear=False))
a('</td>')
a('</tr>')
a('</table>')
return ''.join(v)
def formatyearpage(self, theyear, width=3, css='calendar.css', encoding=None):
"""
Return a formatted year as a complete HTML page.
"""
if encoding is None:
encoding = sys.getdefaultencoding()
v = []
a = v.append
a('<?xml version="1.0" encoding="%s"?>\n' % encoding)
a('<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">\n')
a('<html>\n')
a('<head>\n')
a('<meta http-equiv="Content-Type" content="text/html; charset=%s" />\n' % encoding)
if css is not None:
a('<link rel="stylesheet" type="text/css" href="%s" />\n' % css)
a('<title>Calendar for %d</title>\n' % theyear)
a('</head>\n')
a('<body>\n')
a(self.formatyear(theyear, width))
a('</body>\n')
a('</html>\n')
return ''.join(v).encode(encoding, "xmlcharrefreplace")
class TimeEncoding:
def __init__(self, locale):
self.locale = locale
def __enter__(self):
self.oldlocale = _locale.getlocale(_locale.LC_TIME)
_locale.setlocale(_locale.LC_TIME, self.locale)
return _locale.getlocale(_locale.LC_TIME)[1]
def __exit__(self, *args):
_locale.setlocale(_locale.LC_TIME, self.oldlocale)
class LocaleTextCalendar(TextCalendar):
"""
This class can be passed a locale name in the constructor and will return
month and weekday names in the specified locale. If this locale includes
an encoding, all strings containing month and weekday names will be returned
as unicode.
"""
def __init__(self, firstweekday=0, locale=None):
TextCalendar.__init__(self, firstweekday)
if locale is None:
locale = _locale.getdefaultlocale()
self.locale = locale
def formatweekday(self, day, width):
with TimeEncoding(self.locale) as encoding:
if width >= 9:
names = day_name
else:
names = day_abbr
name = names[day]
if encoding is not None:
name = name.decode(encoding)
return name[:width].center(width)
def formatmonthname(self, theyear, themonth, width, withyear=True):
with TimeEncoding(self.locale) as encoding:
s = month_name[themonth]
if encoding is not None:
s = s.decode(encoding)
if withyear:
s = "%s %r" % (s, theyear)
return s.center(width)
class LocaleHTMLCalendar(HTMLCalendar):
"""
This class can be passed a locale name in the constructor and will return
month and weekday names in the specified locale. If this locale includes
an encoding, all strings containing month and weekday names will be returned
as unicode.
"""
def __init__(self, firstweekday=0, locale=None):
HTMLCalendar.__init__(self, firstweekday)
if locale is None:
locale = _locale.getdefaultlocale()
self.locale = locale
def formatweekday(self, day):
with TimeEncoding(self.locale) as encoding:
s = day_abbr[day]
if encoding is not None:
s = s.decode(encoding)
return '<th class="%s">%s</th>' % (self.cssclasses[day], s)
def formatmonthname(self, theyear, themonth, withyear=True):
with TimeEncoding(self.locale) as encoding:
s = month_name[themonth]
if encoding is not None:
s = s.decode(encoding)
if withyear:
s = '%s %s' % (s, theyear)
return '<tr><th colspan="7" class="month">%s</th></tr>' % s
# Support for old module level interface
c = TextCalendar()
firstweekday = c.getfirstweekday
def setfirstweekday(firstweekday):
try:
firstweekday.__index__
except AttributeError:
raise IllegalWeekdayError(firstweekday)
if not MONDAY <= firstweekday <= SUNDAY:
raise IllegalWeekdayError(firstweekday)
c.firstweekday = firstweekday
monthcalendar = c.monthdayscalendar
prweek = c.prweek
week = c.formatweek
weekheader = c.formatweekheader
prmonth = c.prmonth
month = c.formatmonth
calendar = c.formatyear
prcal = c.pryear
# Spacing of month columns for multi-column year calendar
_colwidth = 7*3 - 1 # Amount printed by prweek()
_spacing = 6 # Number of spaces between columns
def format(cols, colwidth=_colwidth, spacing=_spacing):
"""Prints multi-column formatting for year calendars"""
print formatstring(cols, colwidth, spacing)
def formatstring(cols, colwidth=_colwidth, spacing=_spacing):
"""Returns a string formatted from n strings, centered within n columns."""
spacing *= ' '
return spacing.join(c.center(colwidth) for c in cols)
EPOCH = 1970
_EPOCH_ORD = datetime.date(EPOCH, 1, 1).toordinal()
def timegm(tuple):
"""Unrelated but handy function to calculate Unix timestamp from GMT."""
year, month, day, hour, minute, second = tuple[:6]
days = datetime.date(year, month, 1).toordinal() - _EPOCH_ORD + day - 1
hours = days*24 + hour
minutes = hours*60 + minute
seconds = minutes*60 + second
return seconds
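# Illustration only (not part of this module): a sketch that repeats the
# arithmetic of timegm() above and checks it on a known timestamp. The name
# timegm_sketch is hypothetical; the real function is calendar.timegm().

```python
import calendar
import datetime

EPOCH_ORD = datetime.date(1970, 1, 1).toordinal()


def timegm_sketch(t):
    """Whole days since the epoch, then hours, minutes and seconds."""
    year, month, day, hour, minute, second = t[:6]
    # Ordinal of the 1st of the month, shifted to the requested day.
    days = datetime.date(year, month, 1).toordinal() - EPOCH_ORD + day - 1
    return ((days * 24 + hour) * 60 + minute) * 60 + second


# 2000-01-01 00:00:00 UTC is 10957 days after the epoch.
stamp = timegm_sketch((2000, 1, 1, 0, 0, 0))
```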
def main(args):
import optparse
parser = optparse.OptionParser(usage="usage: %prog [options] [year [month]]")
parser.add_option(
"-w", "--width",
dest="width", type="int", default=2,
help="width of date column (default 2, text only)"
)
parser.add_option(
"-l", "--lines",
dest="lines", type="int", default=1,
help="number of lines for each week (default 1, text only)"
)
parser.add_option(
"-s", "--spacing",
dest="spacing", type="int", default=6,
help="spacing between months (default 6, text only)"
)
parser.add_option(
"-m", "--months",
dest="months", type="int", default=3,
help="months per row (default 3, text only)"
)
parser.add_option(
"-c", "--css",
dest="css", default="calendar.css",
help="CSS to use for page (html only)"
)
parser.add_option(
"-L", "--locale",
dest="locale", default=None,
help="locale to be used for month and weekday names"
)
parser.add_option(
"-e", "--encoding",
dest="encoding", default=None,
help="Encoding to use for output"
)
parser.add_option(
"-t", "--type",
dest="type", default="text",
choices=("text", "html"),
help="output type (text or html)"
)
(options, args) = parser.parse_args(args)
if options.locale and not options.encoding:
parser.error("if --locale is specified --encoding is required")
sys.exit(1)
locale = options.locale, options.encoding
if options.type == "html":
if options.locale:
cal = LocaleHTMLCalendar(locale=locale)
else:
cal = HTMLCalendar()
encoding = options.encoding
if encoding is None:
encoding = sys.getdefaultencoding()
optdict = dict(encoding=encoding, css=options.css)
if len(args) == 1:
print cal.formatyearpage(datetime.date.today().year, **optdict)
elif len(args) == 2:
print cal.formatyearpage(int(args[1]), **optdict)
else:
parser.error("incorrect number of arguments")
sys.exit(1)
else:
if options.locale:
cal = LocaleTextCalendar(locale=locale)
else:
cal = TextCalendar()
optdict = dict(w=options.width, l=options.lines)
if len(args) != 3:
optdict["c"] = options.spacing
optdict["m"] = options.months
if len(args) == 1:
result = cal.formatyear(datetime.date.today().year, **optdict)
elif len(args) == 2:
result = cal.formatyear(int(args[1]), **optdict)
elif len(args) == 3:
result = cal.formatmonth(int(args[1]), int(args[2]), **optdict)
else:
parser.error("incorrect number of arguments")
sys.exit(1)
if options.encoding:
result = result.encode(options.encoding)
print result
if __name__ == "__main__":
main(sys.argv)
#! /usr/local/bin/python
# NOTE: the above "/usr/local/bin/python" is NOT a mistake. It is
# intentionally NOT "/usr/bin/env python". On many systems
# (e.g. Solaris), /usr/local/bin is not in $PATH as passed to CGI
# scripts, and /usr/local/bin is the default directory where Python is
# installed, so /usr/bin/env would be unable to find python. Granted,
# binary installations by Linux vendors often install Python in
# /usr/bin. So let those vendors patch cgi.py to match their choice
# of installation.
"""Support module for CGI (Common Gateway Interface) scripts.
This module defines a number of utilities for use by CGI scripts
written in Python.
"""
# XXX Perhaps there should be a slimmed version that doesn't contain
# all those backwards compatible and debugging classes and functions?
# History
# -------
#
# Michael McLay started this module. Steve Majewski changed the
# interface to SvFormContentDict and FormContentDict. The multipart
# parsing was inspired by code submitted by Andreas Paepcke. Guido van
# Rossum rewrote, reformatted and documented the module and is currently
# responsible for its maintenance.
#
__version__ = "2.6"
# Imports
# =======
from operator import attrgetter
import sys
import os
import UserDict
import urlparse
from warnings import filterwarnings, catch_warnings, warn
with catch_warnings():
if sys.py3kwarning:
filterwarnings("ignore", ".*mimetools has been removed",
DeprecationWarning)
filterwarnings("ignore", ".*rfc822 has been removed",
DeprecationWarning)
import mimetools
import rfc822
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
__all__ = ["MiniFieldStorage", "FieldStorage", "FormContentDict",
"SvFormContentDict", "InterpFormContentDict", "FormContent",
"parse", "parse_qs", "parse_qsl", "parse_multipart",
"parse_header", "print_exception", "print_environ",
"print_form", "print_directory", "print_arguments",
"print_environ_usage", "escape"]
# Logging support
# ===============
logfile = "" # Filename to log to, if not empty
logfp = None # File object to log to, if not None
def initlog(*allargs):
"""Write a log message, if there is a log file.
Even though this function is called initlog(), you should always
use log(); log is a variable that is set either to initlog
(initially), to dolog (once the log file has been opened), or to
nolog (when logging is disabled).
The first argument is a format string; the remaining arguments (if
any) are arguments to the % operator, so e.g.
log("%s: %s", "a", "b")
will write "a: b" to the log file, followed by a newline.
If the global logfp is not None, it should be a file object to
which log data is written.
If the global logfp is None, the global logfile may be a string
giving a filename to open, in append mode. This file should be
world writable!!! If the file can't be opened, logging is
silently disabled (since there is no safe place where we could
send an error message).
"""
global logfp, log
if logfile and not logfp:
try:
logfp = open(logfile, "a")
except IOError:
pass
if not logfp:
log = nolog
else:
log = dolog
log(*allargs)
def dolog(fmt, *args):
"""Write a log message to the log file. See initlog() for docs."""
logfp.write(fmt%args + "\n")
def nolog(*allargs):
"""Dummy function, assigned to log when logging is disabled."""
pass
log = initlog # The current logging function
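# Illustration only (not part of this module): a sketch of the same
# "first call picks the real handler" pattern as initlog()/dolog()/nolog()
# above. As an assumption made for the example, it rebinds an instance
# attribute instead of the module-level name log, and the class name LazyLog
# is hypothetical.

```python
import io


class LazyLog(object):
    def __init__(self, stream=None):
        self.stream = stream
        # Until the first call, log points at the initialising handler.
        self.log = self._initlog

    def _initlog(self, fmt, *args):
        # Decide once whether logging is enabled, then forward the message.
        self.log = self._dolog if self.stream is not None else self._nolog
        self.log(fmt, *args)

    def _dolog(self, fmt, *args):
        self.stream.write(fmt % args + "\n")

    def _nolog(self, fmt, *args):
        pass


buf = io.StringIO()
logger = LazyLog(buf)
logger.log("%s: %s", "a", "b")

silent = LazyLog()          # no stream: logging is disabled
silent.log("%s", "dropped")
```

# After the first call, later calls go straight to the chosen handler without
# re-checking whether a log stream is configured.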
# Parsing functions
# =================
# Maximum input we will accept when REQUEST_METHOD is POST
# 0 ==> unlimited input
maxlen = 0
def parse(fp=None, environ=os.environ, keep_blank_values=0, strict_parsing=0):
"""Parse a query in the environment or from a file (default stdin)
Arguments, all optional:
fp : file pointer; default: sys.stdin
environ : environment dictionary; default: os.environ
keep_blank_values: flag indicating whether blank values in
percent-encoded forms should be treated as blank strings.
A true value indicates that blanks should be retained as
blank strings. The default false value indicates that
blank values are to be ignored and treated as if they were
not included.
strict_parsing: flag indicating what to do with parsing errors.
If false (the default), errors are silently ignored.
If true, errors raise a ValueError exception.
"""
if fp is None:
fp = sys.stdin
if 'REQUEST_METHOD' not in environ:
environ['REQUEST_METHOD'] = 'GET' # For testing stand-alone
if environ['REQUEST_METHOD'] == 'POST':
ctype, pdict = parse_header(environ['CONTENT_TYPE'])
if ctype == 'multipart/form-data':
return parse_multipart(fp, pdict)
elif ctype == 'application/x-www-form-urlencoded':
clength = int(environ['CONTENT_LENGTH'])
if maxlen and clength > maxlen:
raise ValueError, 'Maximum content length exceeded'
qs = fp.read(clength)
else:
qs = '' # Unknown content-type
if 'QUERY_STRING' in environ:
if qs: qs = qs + '&'
qs = qs + environ['QUERY_STRING']
elif sys.argv[1:]:
if qs: qs = qs + '&'
qs = qs + sys.argv[1]
environ['QUERY_STRING'] = qs # XXX Shouldn't, really
elif 'QUERY_STRING' in environ:
qs = environ['QUERY_STRING']
else:
if sys.argv[1:]:
qs = sys.argv[1]
else:
qs = ""
environ['QUERY_STRING'] = qs # XXX Shouldn't, really
return urlparse.parse_qs(qs, keep_blank_values, strict_parsing)
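# Illustration only (not part of this module): a sketch of what the
# keep_blank_values and strict_parsing flags documented in parse() do once
# they reach parse_qs(). The import fallback is an assumption so the example
# runs on both Python 2 (urlparse) and Python 3 (urllib.parse).

```python
try:
    from urllib.parse import parse_qs   # Python 3
except ImportError:
    from urlparse import parse_qs       # Python 2, as used by this module

qs = "a=1&b=&a=2"

# Default: the blank value for "b" is dropped.
default = parse_qs(qs)
# keep_blank_values=True retains it as an empty string.
with_blanks = parse_qs(qs, keep_blank_values=True)

# strict_parsing=True turns a malformed field (no "=") into a ValueError.
try:
    parse_qs("a", strict_parsing=True)
    strict_error = None
except ValueError as exc:
    strict_error = exc
```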
# parse query string function called from urlparse,
# this is done in order to maintain backward compatibility.
def parse_qs(qs, keep_blank_values=0, strict_parsing=0):
"""Parse a query given as a string argument."""
warn("cgi.parse_qs is deprecated, use urlparse.parse_qs instead",
PendingDeprecationWarning, 2)
return urlparse.parse_qs(qs, keep_blank_values, strict_parsing)
def parse_qsl(qs, keep_blank_values=0, strict_parsing=0, max_num_fields=None):
"""Parse a query given as a string argument."""
warn("cgi.parse_qsl is deprecated, use urlparse.parse_qsl instead",
PendingDeprecationWarning, 2)
return urlparse.parse_qsl(qs, keep_blank_values, strict_parsing,
max_num_fields)
def parse_multipart(fp, pdict):
"""Parse multipart input.
Arguments:
fp : input file
pdict: dictionary containing other parameters of content-type header
Returns a dictionary just like parse_qs(): keys are the field names, each
value is a list of values for that field. This is easy to use but not
much good if you are expecting megabytes to be uploaded -- in that case,
use the FieldStorage class instead which is much more flexible. Note
that content-type is the raw, unparsed contents of the content-type
header.
XXX This does not parse nested multipart parts -- use FieldStorage for
that.
XXX This should really be subsumed by FieldStorage altogether -- no
point in having two implementations of the same parsing algorithm.
Also, FieldStorage protects itself better against certain DoS attacks
by limiting the size of the data read in one chunk. The API here
does not support that kind of protection. This also affects parse()
since it can call parse_multipart().
"""
boundary = ""
if 'boundary' in pdict:
boundary = pdict['boundary']
if not valid_boundary(boundary):
raise ValueError, ('Invalid boundary in multipart form: %r'
% (boundary,))
nextpart = "--" + boundary
lastpart = "--" + boundary + "--"
partdict = {}
terminator = ""
while terminator != lastpart:
bytes = -1
data = None
if terminator:
# At start of next part. Read headers first.
headers = mimetools.Message(fp)
clength = headers.getheader('content-length')
if clength:
try:
bytes = int(clength)
except ValueError:
pass
if bytes > 0:
if maxlen and bytes > maxlen:
raise ValueError, 'Maximum content length exceeded'
data = fp.read(bytes)
else:
data = ""
# Read lines until end of part.
lines = []
while 1:
line = fp.readline()
if not line:
terminator = lastpart # End outer loop
break
if line[:2] == "--":
terminator = line.strip()
if terminator in (nextpart, lastpart):
break
lines.append(line)
# Done with part.
if data is None:
continue
if bytes < 0:
if lines:
# Strip final line terminator
line = lines[-1]
if line[-2:] == "\r\n":
line = line[:-2]
elif line[-1:] == "\n":
line = line[:-1]
lines[-1] = line
data = "".join(lines)
line = headers['content-disposition']
if not line:
continue
key, params = parse_header(line)
if key != 'form-data':
continue
if 'name' in params:
name = params['name']
else:
continue
if name in partdict:
partdict[name].append(data)
else:
partdict[name] = [data]
return partdict
def _parseparam(s):
while s[:1] == ';':
s = s[1:]
end = s.find(';')
while end > 0 and (s.count('"', 0, end) - s.count('\\"', 0, end)) % 2:
end = s.find(';', end + 1)
if end < 0:
end = len(s)
f = s[:end]
yield f.strip()
s = s[end:]
def parse_header(line):
"""Parse a Content-type like header.
Return the main content-type and a dictionary of options.
"""
parts = _parseparam(';' + line)
key = parts.next()
pdict = {}
for p in parts:
i = p.find('=')
if i >= 0:
name = p[:i].strip().lower()
value = p[i+1:].strip()
if len(value) >= 2 and value[0] == value[-1] == '"':
value = value[1:-1]
value = value.replace('\\\\', '\\').replace('\\"', '"')
pdict[name] = value
return key, pdict
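# Illustration only (not part of this module): the quote-aware splitting done
# by _parseparam() above, repeated here with comments so the example runs on
# its own. The sample header is made up; note the ";" inside the quoted
# filename, which must not be treated as a parameter separator.

```python
def _parseparam(s):
    while s[:1] == ';':
        s = s[1:]
        end = s.find(';')
        # An odd number of unescaped quotes before this ";" means it sits
        # inside a quoted string, so look for the next one.
        while end > 0 and (s.count('"', 0, end) - s.count('\\"', 0, end)) % 2:
            end = s.find(';', end + 1)
        if end < 0:
            end = len(s)
        yield s[:end].strip()
        s = s[end:]


pieces = list(_parseparam(';form-data; name="file"; filename="a;b.txt"'))
```

# parse_header() then takes the first piece as the key and splits each
# remaining piece at the first "=" to build the options dictionary.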
# Classes for field storage
# =========================
class MiniFieldStorage:
"""Like FieldStorage, for use when no file uploads are possible."""
# Dummy attributes
filename = None
list = None
type = None
file = None
type_options = {}
disposition = None
disposition_options = {}
headers = {}
def __init__(self, name, value):
"""Constructor from field name and value."""
self.name = name
self.value = value
# self.file = StringIO(value)
def __repr__(self):
"""Return printable representation."""
return "MiniFieldStorage(%r, %r)" % (self.name, self.value)
class FieldStorage:
"""Store a sequence of fields, reading multipart/form-data.
This class provides naming, typing, files stored on disk, and
more. At the top level, it is accessible like a dictionary, whose
keys are the field names. (Note: None can occur as a field name.)
The items are either a Python list (if there are multiple values) or
another FieldStorage or MiniFieldStorage object. If it's a single
object, it has the following attributes:
name: the field name, if specified; otherwise None
filename: the filename, if specified; otherwise None; this is the
client side filename, *not* the file name on which it is
stored (that's a temporary file you don't deal with)
value: the value as a *string*; for file uploads, this
transparently reads the file every time you request the value
file: the file(-like) object from which you can read the data;
None if the data is stored as a simple string
type: the content-type, or None if not specified
type_options: dictionary of options specified on the content-type
line
disposition: content-disposition, or None if not specified
disposition_options: dictionary of corresponding options
headers: a dictionary(-like) object (sometimes rfc822.Message or a
subclass thereof) containing *all* headers
The class is subclassable, mostly for the purpose of overriding
the make_file() method, which is called internally to come up with
a file open for reading and writing. This makes it possible to
override the default choice of storing all files in a temporary
directory and unlinking them as soon as they have been opened.
"""
def __init__(self, fp=None, headers=None, outerboundary="",
environ=os.environ, keep_blank_values=0, strict_parsing=0,
max_num_fields=None):
"""Constructor. Read multipart/* until last part.
Arguments, all optional:
fp : file pointer; default: sys.stdin
(not used when the request method is GET)
headers : header dictionary-like object; default:
taken from environ as per CGI spec
outerboundary : terminating multipart boundary
(for internal use only)
environ : environment dictionary; default: os.environ
keep_blank_values: flag indicating whether blank values in
percent-encoded forms should be treated as blank strings.
A true value indicates that blanks should be retained as
blank strings. The default false value indicates that
blank values are to be ignored and treated as if they were
not included.
strict_parsing: flag indicating what to do with parsing errors.
If false (the default), errors are silently ignored.
If true, errors raise a ValueError exception.
max_num_fields: int. If set, then __init__ throws a ValueError
if there are more than n fields read by parse_qsl().
"""
method = 'GET'
self.keep_blank_values = keep_blank_values
self.strict_parsing = strict_parsing
self.max_num_fields = max_num_fields
if 'REQUEST_METHOD' in environ:
method = environ['REQUEST_METHOD'].upper()
self.qs_on_post = None
if method == 'GET' or method == 'HEAD':
if 'QUERY_STRING' in environ:
qs = environ['QUERY_STRING']
elif sys.argv[1:]:
qs = sys.argv[1]
else:
qs = ""
fp = StringIO(qs)
if headers is None:
headers = {'content-type':
"application/x-www-form-urlencoded"}
if headers is None:
headers = {}
if method == 'POST':
# Set default content-type for POST to what's traditional
headers['content-type'] = "application/x-www-form-urlencoded"
if 'CONTENT_TYPE' in environ:
headers['content-type'] = environ['CONTENT_TYPE']
if 'QUERY_STRING' in environ:
self.qs_on_post = environ['QUERY_STRING']
if 'CONTENT_LENGTH' in environ:
headers['content-length'] = environ['CONTENT_LENGTH']
self.fp = fp or sys.stdin
self.headers = headers
self.outerboundary = outerboundary
# Process content-disposition header
cdisp, pdict = "", {}
if 'content-disposition' in self.headers:
cdisp, pdict = parse_header(self.headers['content-disposition'])
self.disposition = cdisp
self.disposition_options = pdict
self.name = None
if 'name' in pdict:
self.name = pdict['name']
self.filename = None
if 'filename' in pdict:
self.filename = pdict['filename']
# Process content-type header
#
# Honor any existing content-type header. But if there is no
# content-type header, use some sensible defaults. Assume
# outerboundary is "" at the outer level, but something non-false
# inside a multi-part. The default for an inner part is text/plain,
# but for an outer part it should be urlencoded. This should catch
# bogus clients which erroneously forget to include a content-type
# header.
#
# See below for what we do if there does exist a content-type header,
# but it happens to be something we don't understand.
if 'content-type' in self.headers:
ctype, pdict = parse_header(self.headers['content-type'])
elif self.outerboundary or method != 'POST':
ctype, pdict = "text/plain", {}
else:
ctype, pdict = 'application/x-www-form-urlencoded', {}
self.type = ctype
self.type_options = pdict
self.innerboundary = ""
if 'boundary' in pdict:
self.innerboundary = pdict['boundary']
clen = -1
if 'content-length' in self.headers:
try:
clen = int(self.headers['content-length'])
except ValueError:
pass
if maxlen and clen > maxlen:
raise ValueError, 'Maximum content length exceeded'
self.length = clen
self.list = self.file = None
self.done = 0
if ctype == 'application/x-www-form-urlencoded':
self.read_urlencoded()
elif ctype[:10] == 'multipart/':
self.read_multi(environ, keep_blank_values, strict_parsing)
else:
self.read_single()
def __repr__(self):
"""Return a printable representation."""
return "FieldStorage(%r, %r, %r)" % (
self.name, self.filename, self.value)
def __iter__(self):
return iter(self.keys())
def __getattr__(self, name):
if name != 'value':
raise AttributeError, name
if self.file:
self.file.seek(0)
value = self.file.read()
self.file.seek(0)
elif self.list is not None:
value = self.list
else:
value = None
return value
def __getitem__(self, key):
"""Dictionary style indexing."""
if self.list is None:
raise TypeError, "not indexable"
found = []
for item in self.list:
if item.name == key: found.append(item)
if not found:
raise KeyError, key
if len(found) == 1:
return found[0]
else:
return found
def getvalue(self, key, default=None):
"""Dictionary style get() method, including 'value' lookup."""
if key in self:
value = self[key]
if type(value) is type([]):
return map(attrgetter('value'), value)
else:
return value.value
else:
return default
def getfirst(self, key, default=None):
""" Return the first value received."""
if key in self:
value = self[key]
if type(value) is type([]):
return value[0].value
else:
return value.value
else:
return default
def getlist(self, key):
""" Return list of received values."""
if key in self:
value = self[key]
if type(value) is type([]):
return map(attrgetter('value'), value)
else:
return [value.value]
else:
return []
def keys(self):
"""Dictionary style keys() method."""
if self.list is None:
raise TypeError, "not indexable"
return list(set(item.name for item in self.list))
def has_key(self, key):
"""Dictionary style has_key() method."""
if self.list is None:
raise TypeError, "not indexable"
return any(item.name == key for item in self.list)
def __contains__(self, key):
"""Dictionary style __contains__ method."""
if self.list is None:
raise TypeError, "not indexable"
return any(item.name == key for item in self.list)
def __len__(self):
"""Dictionary style len(x) support."""
return len(self.keys())
def __nonzero__(self):
return bool(self.list)
def read_urlencoded(self):
"""Internal: read data in query string format."""
qs = self.fp.read(self.length)
if self.qs_on_post:
qs += '&' + self.qs_on_post
query = urlparse.parse_qsl(qs, self.keep_blank_values,
self.strict_parsing, self.max_num_fields)
self.list = [MiniFieldStorage(key, value) for key, value in query]
self.skip_lines()
FieldStorageClass = None
def read_multi(self, environ, keep_blank_values, strict_parsing):
"""Internal: read a part that is itself multipart."""
ib = self.innerboundary
if not valid_boundary(ib):
raise ValueError, 'Invalid boundary in multipart form: %r' % (ib,)
self.list = []
if self.qs_on_post:
query = urlparse.parse_qsl(self.qs_on_post,
self.keep_blank_values,
self.strict_parsing,
self.max_num_fields)
self.list.extend(MiniFieldStorage(key, value)
for key, value in query)
FieldStorageClass = None
# Propagate max_num_fields into the sub class appropriately
max_num_fields = self.max_num_fields
if max_num_fields is not None:
max_num_fields -= len(self.list)
klass = self.FieldStorageClass or self.__class__
part = klass(self.fp, {}, ib,
environ, keep_blank_values, strict_parsing,
max_num_fields)
# Throw first part away
while not part.done:
headers = rfc822.Message(self.fp)
part = klass(self.fp, headers, ib,
environ, keep_blank_values, strict_parsing,
max_num_fields)
if max_num_fields is not None:
max_num_fields -= 1
if part.list:
max_num_fields -= len(part.list)
if max_num_fields < 0:
raise ValueError('Max number of fields exceeded')
self.list.append(part)
self.skip_lines()
def read_single(self):
"""Internal: read an atomic part."""
if self.length >= 0:
self.read_binary()
self.skip_lines()
else:
self.read_lines()
self.file.seek(0)
bufsize = 8*1024 # I/O buffering size for copy to file
def read_binary(self):
"""Internal: read binary data."""
self.file = self.make_file('b')
todo = self.length
if todo >= 0:
while todo > 0:
data = self.fp.read(min(todo, self.bufsize))
if not data:
self.done = -1
break
self.file.write(data)
todo = todo - len(data)
def read_lines(self):
"""Internal: read lines until EOF or outerboundary."""
self.file = self.__file = StringIO()
if self.outerboundary:
self.read_lines_to_outerboundary()
else:
self.read_lines_to_eof()
def __write(self, line):
if self.__file is not None:
if self.__file.tell() + len(line) > 1000:
self.file = self.make_file('')
self.file.write(self.__file.getvalue())
self.__file = None
self.file.write(line)
def read_lines_to_eof(self):
"""Internal: read lines until EOF."""
while 1:
line = self.fp.readline(1<<16)
if not line:
self.done = -1
break
self.__write(line)
def read_lines_to_outerboundary(self):
"""Internal: read lines until outerboundary."""
next = "--" + self.outerboundary
last = next + "--"
delim = ""
last_line_lfend = True
while 1:
line = self.fp.readline(1<<16)
if not line:
self.done = -1
break
if delim == "\r":
line = delim + line
delim = ""
if line[:2] == "--" and last_line_lfend:
strippedline = line.strip()
if strippedline == next:
break
if strippedline == last:
self.done = 1
break
odelim = delim
if line[-2:] == "\r\n":
delim = "\r\n"
line = line[:-2]
last_line_lfend = True
elif line[-1] == "\n":
delim = "\n"
line = line[:-1]
last_line_lfend = True
elif line[-1] == "\r":
# We may interrupt \r\n sequences if they span the 2**16
# byte boundary
delim = "\r"
line = line[:-1]
last_line_lfend = False
else:
delim = ""
last_line_lfend = False
self.__write(odelim + line)
def skip_lines(self):
"""Internal: skip lines until outer boundary if defined."""
if not self.outerboundary or self.done:
return
next = "--" + self.outerboundary
last = next + "--"
last_line_lfend = True
while 1:
line = self.fp.readline(1<<16)
if not line:
self.done = -1
break
if line[:2] == "--" and last_line_lfend:
strippedline = line.strip()
if strippedline == next:
break
if strippedline == last:
self.done = 1
break
last_line_lfend = line.endswith('\n')
def make_file(self, binary=None):
"""Overridable: return a readable & writable file.
The file will be used as follows:
- data is written to it
- seek(0)
- data is read from it
The 'binary' argument is unused -- the file is always opened
in binary mode.
This version opens a temporary file for reading and writing,
and immediately deletes (unlinks) it. The trick (on Unix!) is
that the file can still be used, but it can't be opened by
another process, and it will automatically be deleted when it
is closed or when the current process terminates.
If you want a more permanent file, you derive a class which
overrides this method. If you want a visible temporary file
that is nevertheless automatically deleted when the script
terminates, try defining a __del__ method in a derived class
which unlinks the temporary files you have created.
"""
import tempfile
return tempfile.TemporaryFile("w+b")
# Backwards Compatibility Classes
# ===============================
class FormContentDict(UserDict.UserDict):
"""Form content as dictionary with a list of values per field.
form = FormContentDict()
form[key] -> [value, value, ...]
key in form -> Boolean
form.keys() -> [key, key, ...]
form.values() -> [[val, val, ...], [val, val, ...], ...]
form.items() -> [(key, [val, val, ...]), (key, [val, val, ...]), ...]
form.dict == {key: [val, val, ...], ...}
"""
def __init__(self, environ=os.environ, keep_blank_values=0, strict_parsing=0):
self.dict = self.data = parse(environ=environ,
keep_blank_values=keep_blank_values,
strict_parsing=strict_parsing)
self.query_string = environ['QUERY_STRING']
class SvFormContentDict(FormContentDict):
"""Form content as dictionary expecting a single value per field.
If you only expect a single value for each field, then form[key]
will return that single value. It will raise an IndexError if
that expectation is not true. If you expect a field to have
possible multiple values, then you can use form.getlist(key) to
get all of the values. values() and items() are a compromise:
they return single strings where there is a single value, and
lists of strings otherwise.
"""
def __getitem__(self, key):
if len(self.dict[key]) > 1:
raise IndexError, 'expecting a single value'
return self.dict[key][0]
def getlist(self, key):
return self.dict[key]
def values(self):
result = []
for value in self.dict.values():
if len(value) == 1:
result.append(value[0])
else: result.append(value)
return result
def items(self):
result = []
for key, value in self.dict.items():
if len(value) == 1:
result.append((key, value[0]))
else: result.append((key, value))
return result
class InterpFormContentDict(SvFormContentDict):
"""This class is present for backwards compatibility only."""
def __getitem__(self, key):
v = SvFormContentDict.__getitem__(self, key)
if v[0] in '0123456789+-.':
try: return int(v)
except ValueError:
try: return float(v)
except ValueError: pass
return v.strip()
def values(self):
result = []
for key in self.keys():
try:
result.append(self[key])
except IndexError:
result.append(self.dict[key])
return result
def items(self):
result = []
for key in self.keys():
try:
result.append((key, self[key]))
except IndexError:
result.append((key, self.dict[key]))
return result
class FormContent(FormContentDict):
"""This class is present for backwards compatibility only."""
def values(self, key):
if key in self.dict :return self.dict[key]
else: return None
def indexed_value(self, key, location):
if key in self.dict:
if len(self.dict[key]) > location:
return self.dict[key][location]
else: return None
else: return None
def value(self, key):
if key in self.dict: return self.dict[key][0]
else: return None
def length(self, key):
return len(self.dict[key])
def stripped(self, key):
if key in self.dict: return self.dict[key][0].strip()
else: return None
def pars(self):
return self.dict
# Test/debug code
# ===============
def test(environ=os.environ):
"""Robust test CGI script, usable as main program.
Write minimal HTTP headers and dump all information provided to
the script in HTML form.
"""
print "Content-type: text/html"
print
sys.stderr = sys.stdout
try:
form = FieldStorage() # Replace with other classes to test those
print_directory()
print_arguments()
print_form(form)
print_environ(environ)
print_environ_usage()
def f():
exec "testing print_exception() -- <I>italics?</I>"
def g(f=f):
f()
print "<H3>What follows is a test, not an actual exception:</H3>"
g()
except:
print_exception()
print "<H1>Second try with a small maxlen...</H1>"
global maxlen
maxlen = 50
try:
form = FieldStorage() # Replace with other classes to test those
print_directory()
print_arguments()
print_form(form)
print_environ(environ)
except:
print_exception()
def print_exception(type=None, value=None, tb=None, limit=None):
if type is None:
type, value, tb = sys.exc_info()
import traceback
print
print "<H3>Traceback (most recent call last):</H3>"
list = traceback.format_tb(tb, limit) + \
traceback.format_exception_only(type, value)
print "<PRE>%s<B>%s</B></PRE>" % (
escape("".join(list[:-1])),
escape(list[-1]),
)
del tb
def print_environ(environ=os.environ):
"""Dump the shell environment as HTML."""
keys = environ.keys()
keys.sort()
print
print "<H3>Shell Environment:</H3>"
print "<DL>"
for key in keys:
print "<DT>", escape(key), "<DD>", escape(environ[key])
print "</DL>"
print
def print_form(form):
"""Dump the contents of a form as HTML."""
keys = form.keys()
keys.sort()
print
print "<H3>Form Contents:</H3>"
if not keys:
print "<P>No form fields."
print "<DL>"
for key in keys:
print "<DT>" + escape(key) + ":",
value = form[key]
print "<i>" + escape(repr(type(value))) + "</i>"
print "<DD>" + escape(repr(value))
print "</DL>"
print
def print_directory():
"""Dump the current directory as HTML."""
print
print "<H3>Current Working Directory:</H3>"
try:
pwd = os.getcwd()
except os.error, msg:
print "os.error:", escape(str(msg))
else:
print escape(pwd)
print
def print_arguments():
print
print "<H3>Command Line Arguments:</H3>"
print
print sys.argv
print
def print_environ_usage():
"""Dump a list of environment variables used by CGI as HTML."""
print """
<H3>These environment variables could have been set:</H3>
<UL>
<LI>AUTH_TYPE
<LI>CONTENT_LENGTH
<LI>CONTENT_TYPE
<LI>DATE_GMT
<LI>DATE_LOCAL
<LI>DOCUMENT_NAME
<LI>DOCUMENT_ROOT
<LI>DOCUMENT_URI
<LI>GATEWAY_INTERFACE
<LI>LAST_MODIFIED
<LI>PATH
<LI>PATH_INFO
<LI>PATH_TRANSLATED
<LI>QUERY_STRING
<LI>REMOTE_ADDR
<LI>REMOTE_HOST
<LI>REMOTE_IDENT
<LI>REMOTE_USER
<LI>REQUEST_METHOD
<LI>SCRIPT_NAME
<LI>SERVER_NAME
<LI>SERVER_PORT
<LI>SERVER_PROTOCOL
<LI>SERVER_ROOT
<LI>SERVER_SOFTWARE
</UL>
In addition, HTTP headers sent by the server may be passed in the
environment as well. Here are some common variable names:
<UL>
<LI>HTTP_ACCEPT
<LI>HTTP_CONNECTION
<LI>HTTP_HOST
<LI>HTTP_PRAGMA
<LI>HTTP_REFERER
<LI>HTTP_USER_AGENT
</UL>
"""
# Utilities
# =========
def escape(s, quote=None):
    '''Replace special characters "&", "<" and ">" to HTML-safe sequences.
    If the optional flag quote is true, the quotation mark character (")
    is also translated.'''
    s = s.replace("&", "&amp;") # Must be done first!
    s = s.replace("<", "&lt;")
    s = s.replace(">", "&gt;")
    if quote:
        s = s.replace('"', "&quot;")
    return s
def valid_boundary(s, _vb_pattern="^[ -~]{0,200}[!-~]$"):
import re
return re.match(_vb_pattern, s)
# Invoke mainline
# ===============
# Call test() when this file is run as a script (not imported as a module)
if __name__ == '__main__':
test()
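The ordering note in escape() above ("Must be done first!") can be illustrated with a small standalone sketch. The names escape_html and escape_wrong_order are hypothetical and only mirror the helper for demonstration; the sketch is written so that it also runs on Python 3.

```python
def escape_html(text, quote=False):
    """Hypothetical stand-in that mirrors the escape() helper above."""
    text = text.replace("&", "&amp;")  # ampersand first, see below
    text = text.replace("<", "&lt;")
    text = text.replace(">", "&gt;")
    if quote:
        text = text.replace('"', "&quot;")
    return text


def escape_wrong_order(text):
    """Deliberately wrong: escaping '&' last re-escapes the new entities."""
    text = text.replace("<", "&lt;")
    text = text.replace("&", "&amp;")
    return text


sample = '<a href="x">Tom & Jerry</a>'
print(escape_html(sample, quote=True))
print(escape_wrong_order("<b>"))  # the '&' of '&lt;' gets escaped again
```

Replacing "&" before the other characters keeps the ampersands introduced by "&lt;" and "&gt;" from being escaped a second time.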
"""CGI-savvy HTTP Server.
This module builds on SimpleHTTPServer by implementing GET and POST
requests to cgi-bin scripts.
If the os.fork() function is not present (e.g. on Windows),
os.popen2() is used as a fallback, with slightly altered semantics; if
that function is not present either (e.g. on Macintosh), only Python
scripts are supported, and they are executed by the current process.
In all cases, the implementation is intentionally naive -- all
requests are executed synchronously.
SECURITY WARNING: DON'T USE THIS CODE UNLESS YOU ARE INSIDE A FIREWALL
-- it may execute arbitrary Python code or external programs.
Note that status code 200 is sent prior to execution of a CGI script, so
scripts cannot send other status codes such as 302 (redirect).
"""
__version__ = "0.4"
__all__ = ["CGIHTTPRequestHandler"]
import os
import sys
import urllib
import BaseHTTPServer
import SimpleHTTPServer
import select
import copy
class CGIHTTPRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
"""Complete HTTP server with GET, HEAD and POST commands.
GET and HEAD also support running CGI scripts.
The POST command is *only* implemented for CGI scripts.
"""
# Determine platform specifics
have_fork = hasattr(os, 'fork')
have_popen2 = hasattr(os, 'popen2')
have_popen3 = hasattr(os, 'popen3')
# Make rfile unbuffered -- we need to read one line and then pass
# the rest to a subprocess, so we can't use buffered input.
rbufsize = 0
def do_POST(self):
"""Serve a POST request.
This is only implemented for CGI scripts.
"""
if self.is_cgi():
self.run_cgi()
else:
self.send_error(501, "Can only POST to CGI scripts")
def send_head(self):
"""Version of send_head that supports CGI scripts"""
if self.is_cgi():
return self.run_cgi()
else:
return SimpleHTTPServer.SimpleHTTPRequestHandler.send_head(self)
def is_cgi(self):
"""Test whether self.path corresponds to a CGI script.
Returns True and updates the cgi_info attribute to the tuple
(dir, rest) if self.path requires running a CGI script.
Returns False otherwise.
If any exception is raised, the caller should assume that
self.path was rejected as invalid and act accordingly.
The default implementation tests whether the normalized url
path begins with one of the strings in self.cgi_directories
(and the next character is a '/' or the end of the string).
"""
collapsed_path = _url_collapse_path(self.path)
dir_sep = collapsed_path.find('/', 1)
head, tail = collapsed_path[:dir_sep], collapsed_path[dir_sep+1:]
if head in self.cgi_directories:
self.cgi_info = head, tail
return True
return False
cgi_directories = ['/cgi-bin', '/htbin']
def is_executable(self, path):
"""Test whether argument path is an executable file."""
return executable(path)
def is_python(self, path):
"""Test whether argument path is a Python script."""
head, tail = os.path.splitext(path)
return tail.lower() in (".py", ".pyw")
def run_cgi(self):
"""Execute a CGI script."""
dir, rest = self.cgi_info
path = dir + '/' + rest
i = path.find('/', len(dir)+1)
while i >= 0:
nextdir = path[:i]
nextrest = path[i+1:]
scriptdir = self.translate_path(nextdir)
if os.path.isdir(scriptdir):
dir, rest = nextdir, nextrest
i = path.find('/', len(dir)+1)
else:
break
# find an explicit query string, if present.
rest, _, query = rest.partition('?')
# dissect the part after the directory name into a script name &
# a possible additional path, to be stored in PATH_INFO.
i = rest.find('/')
if i >= 0:
script, rest = rest[:i], rest[i:]
else:
script, rest = rest, ''
scriptname = dir + '/' + script
scriptfile = self.translate_path(scriptname)
if not os.path.exists(scriptfile):
self.send_error(404, "No such CGI script (%r)" % scriptname)
return
if not os.path.isfile(scriptfile):
self.send_error(403, "CGI script is not a plain file (%r)" %
scriptname)
return
ispy = self.is_python(scriptname)
if not ispy:
if not (self.have_fork or self.have_popen2 or self.have_popen3):
self.send_error(403, "CGI script is not a Python script (%r)" %
scriptname)
return
if not self.is_executable(scriptfile):
self.send_error(403, "CGI script is not executable (%r)" %
scriptname)
return
# Reference: http://hoohoo.ncsa.uiuc.edu/cgi/env.html
# XXX Much of the following could be prepared ahead of time!
env = copy.deepcopy(os.environ)
env['SERVER_SOFTWARE'] = self.version_string()
env['SERVER_NAME'] = self.server.server_name
env['GATEWAY_INTERFACE'] = 'CGI/1.1'
env['SERVER_PROTOCOL'] = self.protocol_version
env['SERVER_PORT'] = str(self.server.server_port)
env['REQUEST_METHOD'] = self.command
uqrest = urllib.unquote(rest)
env['PATH_INFO'] = uqrest
env['PATH_TRANSLATED'] = self.translate_path(uqrest)
env['SCRIPT_NAME'] = scriptname
if query:
env['QUERY_STRING'] = query
host = self.address_string()
if host != self.client_address[0]:
env['REMOTE_HOST'] = host
env['REMOTE_ADDR'] = self.client_address[0]
authorization = self.headers.getheader("authorization")
if authorization:
authorization = authorization.split()
if len(authorization) == 2:
import base64, binascii
env['AUTH_TYPE'] = authorization[0]
if authorization[0].lower() == "basic":
try:
authorization = base64.decodestring(authorization[1])
except binascii.Error:
pass
else:
authorization = authorization.split(':')
if len(authorization) == 2:
env['REMOTE_USER'] = authorization[0]
# XXX REMOTE_IDENT
if self.headers.typeheader is None:
env['CONTENT_TYPE'] = self.headers.type
else:
env['CONTENT_TYPE'] = self.headers.typeheader
length = self.headers.getheader('content-length')
if length:
env['CONTENT_LENGTH'] = length
referer = self.headers.getheader('referer')
if referer:
env['HTTP_REFERER'] = referer
accept = []
for line in self.headers.getallmatchingheaders('accept'):
if line[:1] in "\t\n\r ":
accept.append(line.strip())
else:
accept = accept + line[7:].split(',')
env['HTTP_ACCEPT'] = ','.join(accept)
ua = self.headers.getheader('user-agent')
if ua:
env['HTTP_USER_AGENT'] = ua
co = filter(None, self.headers.getheaders('cookie'))
if co:
env['HTTP_COOKIE'] = ', '.join(co)
# XXX Other HTTP_* headers
# Since we're setting the env in the parent, provide empty
# values to override previously set values
for k in ('QUERY_STRING', 'REMOTE_HOST', 'CONTENT_LENGTH',
'HTTP_USER_AGENT', 'HTTP_COOKIE', 'HTTP_REFERER'):
env.setdefault(k, "")
self.send_response(200, "Script output follows")
decoded_query = query.replace('+', ' ')
if self.have_fork:
# Unix -- fork as we should
args = [script]
if '=' not in decoded_query:
args.append(decoded_query)
nobody = nobody_uid()
self.wfile.flush() # Always flush before forking
pid = os.fork()
if pid != 0:
# Parent
pid, sts = os.waitpid(pid, 0)
# throw away additional data [see bug #427345]
while select.select([self.rfile], [], [], 0)[0]:
if not self.rfile.read(1):
break
if sts:
self.log_error("CGI script exit status %#x", sts)
return
# Child
try:
try:
os.setuid(nobody)
except os.error:
pass
os.dup2(self.rfile.fileno(), 0)
os.dup2(self.wfile.fileno(), 1)
os.execve(scriptfile, args, env)
except:
self.server.handle_error(self.request, self.client_address)
os._exit(127)
else:
# Non Unix - use subprocess
import subprocess
cmdline = [scriptfile]
if self.is_python(scriptfile):
interp = sys.executable
if interp.lower().endswith("w.exe"):
# On Windows, use python.exe, not pythonw.exe
interp = interp[:-5] + interp[-4:]
cmdline = [interp, '-u'] + cmdline
if '=' not in query:
cmdline.append(query)
self.log_message("command: %s", subprocess.list2cmdline(cmdline))
try:
nbytes = int(length)
except (TypeError, ValueError):
nbytes = 0
p = subprocess.Popen(cmdline,
stdin = subprocess.PIPE,
stdout = subprocess.PIPE,
stderr = subprocess.PIPE,
env = env
)
if self.command.lower() == "post" and nbytes > 0:
data = self.rfile.read(nbytes)
else:
data = None
# throw away additional data [see bug #427345]
while select.select([self.rfile._sock], [], [], 0)[0]:
if not self.rfile._sock.recv(1):
break
stdout, stderr = p.communicate(data)
self.wfile.write(stdout)
if stderr:
self.log_error('%s', stderr)
p.stderr.close()
p.stdout.close()
status = p.returncode
if status:
self.log_error("CGI script exit status %#x", status)
else:
self.log_message("CGI script exited OK")
def _url_collapse_path(path):
"""
Given a URL path, remove extra '/'s and '.' path elements, collapse
any '..' references, and return the collapsed path.
Implements something akin to RFC-2396 5.2 step 6 to parse relative paths.
This function is only used by the is_cgi method and helps
prevent some security attacks.
Returns: The reconstituted URL, which will always start with a '/'.
Raises: IndexError if too many '..' occur within the path.
"""
# Query component should not be involved.
path, _, query = path.partition('?')
path = urllib.unquote(path)
# Similar to os.path.split(os.path.normpath(path)) but specific to URL
# path semantics rather than local operating system semantics.
path_parts = path.split('/')
head_parts = []
for part in path_parts[:-1]:
if part == '..':
head_parts.pop() # IndexError if more '..' than prior parts
elif part and part != '.':
head_parts.append( part )
if path_parts:
tail_part = path_parts.pop()
if tail_part:
if tail_part == '..':
head_parts.pop()
tail_part = ''
elif tail_part == '.':
tail_part = ''
else:
tail_part = ''
if query:
tail_part = '?'.join((tail_part, query))
splitpath = ('/' + '/'.join(head_parts), tail_part)
collapsed_path = "/".join(splitpath)
return collapsed_path
nobody = None
def nobody_uid():
"""Internal routine to get nobody's uid"""
global nobody
if nobody:
return nobody
try:
import pwd
except ImportError:
return -1
try:
nobody = pwd.getpwnam('nobody')[2]
except KeyError:
nobody = 1 + max(map(lambda x: x[2], pwd.getpwall()))
return nobody
def executable(path):
"""Test for executable file."""
try:
st = os.stat(path)
except os.error:
return False
return st.st_mode & 0111 != 0
def test(HandlerClass = CGIHTTPRequestHandler,
ServerClass = BaseHTTPServer.HTTPServer):
SimpleHTTPServer.test(HandlerClass, ServerClass)
if __name__ == '__main__':
test()
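The path handling done by is_cgi() and _url_collapse_path() above can be tried out in isolation. The following is a minimal sketch, assuming Python 3 (so urllib.parse.unquote replaces urllib.unquote); collapse_path and split_cgi are hypothetical names used only for illustration and are not part of the module.

```python
from urllib.parse import unquote


def collapse_path(path):
    """Collapse '.', '..' and repeated '/' in a URL path (sketch)."""
    path, _, query = path.partition('?')  # keep the query out of it
    path = unquote(path)
    parts = path.split('/')
    head = []
    for part in parts[:-1]:
        if part == '..':
            head.pop()  # IndexError if '..' climbs above the root
        elif part and part != '.':
            head.append(part)
    tail = parts[-1]
    if tail == '..':
        head.pop()
        tail = ''
    elif tail == '.':
        tail = ''
    if query:
        tail = '?'.join((tail, query))
    return '/' + '/'.join(head) + '/' + tail


def split_cgi(path, cgi_directories=('/cgi-bin', '/htbin')):
    """Return (dir, rest) if the path points into a CGI directory."""
    collapsed = collapse_path(path)
    dir_sep = collapsed.find('/', 1)
    head, tail = collapsed[:dir_sep], collapsed[dir_sep + 1:]
    if head in cgi_directories:
        return head, tail
    return None


print(collapse_path('/cgi-bin/../cgi-bin/./test.py?x=1'))
print(split_cgi('/cgi-bin/../cgi-bin/./test.py?x=1'))
print(split_cgi('/index.html'))
```

As in the handler, a path with more '..' elements than preceding directories raises IndexError, which the caller is expected to treat as a rejected path.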
"""More comprehensive traceback formatting for Python scripts.
To enable this module, do:
import cgitb; cgitb.enable()
at the top of your script. The optional arguments to enable() are:
display - if true, tracebacks are displayed in the web browser
logdir - if set, tracebacks are written to files in this directory
context - number of lines of source code to show for each stack frame
format - 'text' or 'html' controls the output format
By default, tracebacks are displayed but not saved, the context is 5 lines
and the output format is 'html' (for backwards compatibility with the
original use of this module).
Alternatively, if you have caught an exception and want cgitb to display it
for you, call cgitb.handler(). The optional argument to handler() is a
3-item tuple (etype, evalue, etb) just like the value of sys.exc_info().
The default handler displays output as HTML.
"""
import inspect
import keyword
import linecache
import os
import pydoc
import sys
import tempfile
import time
import tokenize
import traceback
import types
def reset():
"""Return a string that resets the CGI and browser to a known state."""
return '''<!--: spam
Content-Type: text/html
<body bgcolor="#f0f0f8"><font color="#f0f0f8" size="-5"> -->
<body bgcolor="#f0f0f8"><font color="#f0f0f8" size="-5"> --> -->
</font> </font> </font> </script> </object> </blockquote> </pre>
</table> </table> </table> </table> </table> </font> </font> </font>'''
__UNDEF__ = [] # a special sentinel object
def small(text):
if text:
return '<small>' + text + '</small>'
else:
return ''
def strong(text):
if text:
return '<strong>' + text + '</strong>'
else:
return ''
def grey(text):
if text:
return '<font color="#909090">' + text + '</font>'
else:
return ''
def lookup(name, frame, locals):
"""Find the value for a given name in the given environment."""
if name in locals:
return 'local', locals[name]
if name in frame.f_globals:
return 'global', frame.f_globals[name]
if '__builtins__' in frame.f_globals:
builtins = frame.f_globals['__builtins__']
if type(builtins) is type({}):
if name in builtins:
return 'builtin', builtins[name]
else:
if hasattr(builtins, name):
return 'builtin', getattr(builtins, name)
return None, __UNDEF__
def scanvars(reader, frame, locals):
"""Scan one logical line of Python and look up values of variables used."""
vars, lasttoken, parent, prefix, value = [], None, None, '', __UNDEF__
for ttype, token, start, end, line in tokenize.generate_tokens(reader):
if ttype == tokenize.NEWLINE: break
if ttype == tokenize.NAME and token not in keyword.kwlist:
if lasttoken == '.':
if parent is not __UNDEF__:
value = getattr(parent, token, __UNDEF__)
vars.append((prefix + token, prefix, value))
else:
where, value = lookup(token, frame, locals)
vars.append((token, where, value))
elif token == '.':
prefix += lasttoken + '.'
parent = value
else:
parent, prefix = None, ''
lasttoken = token
return vars
def html(einfo, context=5):
"""Return a nice HTML document describing a given traceback."""
etype, evalue, etb = einfo
if type(etype) is types.ClassType:
etype = etype.__name__
pyver = 'Python ' + sys.version.split()[0] + ': ' + sys.executable
date = time.ctime(time.time())
head = '<body bgcolor="#f0f0f8">' + pydoc.html.heading(
'<big><big>%s</big></big>' %
strong(pydoc.html.escape(str(etype))),
'#ffffff', '#6622aa', pyver + '<br>' + date) + '''
<p>A problem occurred in a Python script. Here is the sequence of
function calls leading up to the error, in the order they occurred.</p>'''
indent = '<tt>' + small('&nbsp;' * 5) + '&nbsp;</tt>'
frames = []
records = inspect.getinnerframes(etb, context)
for frame, file, lnum, func, lines, index in records:
if file:
file = os.path.abspath(file)
link = '<a href="file://%s">%s</a>' % (file, pydoc.html.escape(file))
else:
file = link = '?'
args, varargs, varkw, locals = inspect.getargvalues(frame)
call = ''
if func != '?':
call = 'in ' + strong(pydoc.html.escape(func)) + \
inspect.formatargvalues(args, varargs, varkw, locals,
formatvalue=lambda value: '=' + pydoc.html.repr(value))
highlight = {}
def reader(lnum=[lnum]):
highlight[lnum[0]] = 1
try: return linecache.getline(file, lnum[0])
finally: lnum[0] += 1
vars = scanvars(reader, frame, locals)
rows = ['<tr><td bgcolor="#d8bbff">%s%s %s</td></tr>' %
('<big>&nbsp;</big>', link, call)]
if index is not None:
i = lnum - index
for line in lines:
num = small('&nbsp;' * (5-len(str(i))) + str(i)) + '&nbsp;'
if i in highlight:
line = '<tt>=&gt;%s%s</tt>' % (num, pydoc.html.preformat(line))
rows.append('<tr><td bgcolor="#ffccee">%s</td></tr>' % line)
else:
line = '<tt>&nbsp;&nbsp;%s%s</tt>' % (num, pydoc.html.preformat(line))
rows.append('<tr><td>%s</td></tr>' % grey(line))
i += 1
done, dump = {}, []
for name, where, value in vars:
if name in done: continue
done[name] = 1
if value is not __UNDEF__:
if where in ('global', 'builtin'):
name = ('<em>%s</em> ' % where) + strong(name)
elif where == 'local':
name = strong(name)
else:
name = where + strong(name.split('.')[-1])
dump.append('%s = %s' % (name, pydoc.html.repr(value)))
else:
dump.append(name + ' <em>undefined</em>')
rows.append('<tr><td>%s</td></tr>' % small(grey(', '.join(dump))))
frames.append('''
<table width="100%%" cellspacing=0 cellpadding=0 border=0>
%s</table>''' % '\n'.join(rows))
exception = ['<p>%s: %s' % (strong(pydoc.html.escape(str(etype))),
pydoc.html.escape(str(evalue)))]
if isinstance(evalue, BaseException):
for name in dir(evalue):
if name[:1] == '_': continue
value = pydoc.html.repr(getattr(evalue, name))
exception.append('\n<br>%s%s&nbsp;=\n%s' % (indent, name, value))
return head + ''.join(frames) + ''.join(exception) + '''
<!-- The above is a description of an error in a Python program, formatted
for a Web browser because the 'cgitb' module was enabled. In case you
are not reading this in a Web browser, here is the original traceback:
%s
-->
''' % pydoc.html.escape(
''.join(traceback.format_exception(etype, evalue, etb)))
def text(einfo, context=5):
"""Return a plain text document describing a given traceback."""
etype, evalue, etb = einfo
if type(etype) is types.ClassType:
etype = etype.__name__
pyver = 'Python ' + sys.version.split()[0] + ': ' + sys.executable
date = time.ctime(time.time())
head = "%s\n%s\n%s\n" % (str(etype), pyver, date) + '''
A problem occurred in a Python script. Here is the sequence of
function calls leading up to the error, in the order they occurred.
'''
frames = []
records = inspect.getinnerframes(etb, context)
for frame, file, lnum, func, lines, index in records:
file = file and os.path.abspath(file) or '?'
args, varargs, varkw, locals = inspect.getargvalues(frame)
call = ''
if func != '?':
call = 'in ' + func + \
inspect.formatargvalues(args, varargs, varkw, locals,
formatvalue=lambda value: '=' + pydoc.text.repr(value))
highlight = {}
def reader(lnum=[lnum]):
highlight[lnum[0]] = 1
try: return linecache.getline(file, lnum[0])
finally: lnum[0] += 1
vars = scanvars(reader, frame, locals)
rows = [' %s %s' % (file, call)]
if index is not None:
i = lnum - index
for line in lines:
num = '%5d ' % i
rows.append(num+line.rstrip())
i += 1
done, dump = {}, []
for name, where, value in vars:
if name in done: continue
done[name] = 1
if value is not __UNDEF__:
if where == 'global': name = 'global ' + name
elif where != 'local': name = where + name.split('.')[-1]
dump.append('%s = %s' % (name, pydoc.text.repr(value)))
else:
dump.append(name + ' undefined')
rows.append('\n'.join(dump))
frames.append('\n%s\n' % '\n'.join(rows))
exception = ['%s: %s' % (str(etype), str(evalue))]
if isinstance(evalue, BaseException):
for name in dir(evalue):
value = pydoc.text.repr(getattr(evalue, name))
exception.append('\n%s%s = %s' % (" "*4, name, value))
return head + ''.join(frames) + ''.join(exception) + '''
The above is a description of an error in a Python program. Here is
the original traceback:
%s
''' % ''.join(traceback.format_exception(etype, evalue, etb))
class Hook:
"""A hook to replace sys.excepthook that shows tracebacks in HTML."""
def __init__(self, display=1, logdir=None, context=5, file=None,
format="html"):
self.display = display # send tracebacks to browser if true
self.logdir = logdir # log tracebacks to files if not None
self.context = context # number of source code lines per frame
self.file = file or sys.stdout # place to send the output
self.format = format
def __call__(self, etype, evalue, etb):
self.handle((etype, evalue, etb))
def handle(self, info=None):
info = info or sys.exc_info()
if self.format == "html":
self.file.write(reset())
formatter = (self.format=="html") and html or text
plain = False
try:
doc = formatter(info, self.context)
except: # just in case something goes wrong
doc = ''.join(traceback.format_exception(*info))
plain = True
if self.display:
if plain:
doc = pydoc.html.escape(doc)
self.file.write('<pre>' + doc + '</pre>\n')
else:
self.file.write(doc + '\n')
else:
self.file.write('<p>A problem occurred in a Python script.\n')
if self.logdir is not None:
suffix = ['.txt', '.html'][self.format=="html"]
(fd, path) = tempfile.mkstemp(suffix=suffix, dir=self.logdir)
try:
file = os.fdopen(fd, 'w')
file.write(doc)
file.close()
msg = '%s contains the description of this error.' % path
except:
msg = 'Tried to save traceback to %s, but failed.' % path
if self.format == 'html':
self.file.write('<p>%s</p>\n' % msg)
else:
self.file.write(msg + '\n')
try:
self.file.flush()
except: pass
handler = Hook().handle
def enable(display=1, logdir=None, context=5, format="html"):
"""Install an exception handler that formats tracebacks as HTML.
The optional argument 'display' can be set to 0 to suppress sending the
traceback to the browser, and 'logdir' can be set to a directory to cause
tracebacks to be written to files there."""
sys.excepthook = Hook(display=display, logdir=logdir,
context=context, format=format)
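The Hook class above follows a simple pattern: a callable with the signature of sys.excepthook that formats the exception and writes it to a file-like object. The following is a minimal plain-text sketch of that pattern, not the module's implementation. PlainTextHook is a hypothetical name, and the sketch uses only traceback.format_exception, which is the same fallback Hook.handle() uses when its formatter fails.

```python
import io
import sys
import traceback


class PlainTextHook:
    """Hypothetical, reduced version of the Hook pattern shown above."""

    def __init__(self, file=None):
        self.file = file or sys.stdout  # place to send the output

    def __call__(self, etype, evalue, etb):
        # Same signature as sys.excepthook, so an instance can be
        # installed with: sys.excepthook = PlainTextHook()
        self.handle((etype, evalue, etb))

    def handle(self, info=None):
        info = info or sys.exc_info()
        doc = ''.join(traceback.format_exception(*info))
        self.file.write(doc + '\n')


buffer = io.StringIO()
hook = PlainTextHook(file=buffer)
try:
    1 / 0
except ZeroDivisionError:
    hook.handle()  # report a caught exception, like cgitb.handler()
report = buffer.getvalue()
print(report)
```

Writing to an in-memory buffer here is only for demonstration; in a CGI script the default of sys.stdout sends the report to the browser.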
"""Simple class to read IFF chunks.
An IFF chunk (used in formats such as AIFF, TIFF, RMFF (RealMedia File
Format)) has the following structure:
+----------------+
| ID (4 bytes) |
+----------------+
| size (4 bytes) |
+----------------+
| data |
| ... |
+----------------+
The ID is a 4-byte string which identifies the type of chunk.
The size field (a 32-bit value, encoded using big-endian byte order by
default) gives the size of the chunk data. If a format counts the 8-byte
header in the size, pass inclheader=True so that the header is subtracted.
Usually an IFF-type file consists of one or more chunks. The proposed
usage of the Chunk class defined here is to instantiate an instance at
the start of each chunk and read from the instance until it reaches
the end, after which a new instance can be instantiated. At the end
of the file, creating a new instance will fail with an EOFError
exception.
Usage:
    while True:
        try:
            chunk = Chunk(file)
        except EOFError:
            break
        chunktype = chunk.getname()
        while True:
            data = chunk.read(nbytes)
            if not data:
                break
            # do something with data
The interface is file-like. The implemented methods are:
read, close, seek, tell, isatty.
Extra methods are: skip() (called by close, skips to the end of the chunk),
getname() (returns the name (ID) of the chunk)
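The __init__ method has one required argument, a file-like object
(including a chunk instance), and three optional arguments: align, a
flag which specifies whether or not chunks are aligned on 2-byte
boundaries (default True, i.e. aligned); bigendian, which selects the
byte order of the size field (default True); and inclheader, which
states whether the size field counts the 8-byte header (default False).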
"""
class Chunk:
def __init__(self, file, align=True, bigendian=True, inclheader=False):
import struct
self.closed = False
self.align = align # whether to align to word (2-byte) boundaries
if bigendian:
strflag = '>'
else:
strflag = '<'
self.file = file
self.chunkname = file.read(4)
if len(self.chunkname) < 4:
raise EOFError
try:
self.chunksize = struct.unpack(strflag+'L', file.read(4))[0]
except struct.error:
raise EOFError
if inclheader:
self.chunksize = self.chunksize - 8 # subtract header
self.size_read = 0
try:
self.offset = self.file.tell()
except (AttributeError, IOError):
self.seekable = False
else:
self.seekable = True
def getname(self):
"""Return the name (ID) of the current chunk."""
return self.chunkname
def getsize(self):
"""Return the size of the current chunk."""
return self.chunksize
def close(self):
if not self.closed:
try:
self.skip()
finally:
self.closed = True
def isatty(self):
if self.closed:
raise ValueError, "I/O operation on closed file"
return False
def seek(self, pos, whence=0):
"""Seek to specified position into the chunk.
Default position is 0 (start of chunk).
If the file is not seekable, this will result in an error.
"""
if self.closed:
raise ValueError, "I/O operation on closed file"
if not self.seekable:
raise IOError, "cannot seek"
if whence == 1:
pos = pos + self.size_read
elif whence == 2:
pos = pos + self.chunksize
if pos < 0 or pos > self.chunksize:
raise RuntimeError
self.file.seek(self.offset + pos, 0)
self.size_read = pos
def tell(self):
if self.closed:
raise ValueError, "I/O operation on closed file"
return self.size_read
def read(self, size=-1):
"""Read at most size bytes from the chunk.
If size is omitted or negative, read until the end
of the chunk.
"""
if self.closed:
raise ValueError, "I/O operation on closed file"
if self.size_read >= self.chunksize:
return ''
if size < 0:
size = self.chunksize - self.size_read
if size > self.chunksize - self.size_read:
size = self.chunksize - self.size_read
data = self.file.read(size)
self.size_read = self.size_read + len(data)
if self.size_read == self.chunksize and \
self.align and \
(self.chunksize & 1):
dummy = self.file.read(1)
self.size_read = self.size_read + len(dummy)
return data
def skip(self):
"""Skip the rest of the chunk.
If you are not interested in the contents of the chunk,
this method should be called so that the file points to
the start of the next chunk.
"""
if self.closed:
raise ValueError, "I/O operation on closed file"
if self.seekable:
try:
n = self.chunksize - self.size_read
# maybe fix alignment
if self.align and (self.chunksize & 1):
n = n + 1
self.file.seek(n, 1)
self.size_read = self.size_read + n
return
except IOError:
pass
while self.size_read < self.chunksize:
n = min(8192, self.chunksize - self.size_read)
dummy = self.read(n)
if not dummy:
raise EOFError
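The chunk layout described in the module docstring above (4-byte ID, 4-byte big-endian size, data, optional pad byte) can be exercised on an in-memory stream. This is a minimal sketch rather than a replacement for the Chunk class: iter_chunks is a hypothetical helper, it reads each chunk fully instead of lazily, and it assumes the size field does not count the 8-byte header (the inclheader=False default).

```python
import io
import struct


def iter_chunks(stream, align=True):
    """Yield (name, data) pairs from an IFF-style byte stream (sketch)."""
    while True:
        header = stream.read(8)
        if len(header) < 8:
            return  # end of file, or a truncated header
        # '>4sL': 4-byte ID followed by a big-endian unsigned 32-bit size
        name, size = struct.unpack('>4sL', header)
        data = stream.read(size)
        if align and size & 1:
            stream.read(1)  # skip the pad byte after odd-sized chunks
        yield name, data


# Two chunks; the first has an odd size and is followed by a pad byte.
payload = (b'NAME' + struct.pack('>L', 3) + b'abc' + b'\x00' +
           b'DATA' + struct.pack('>L', 4) + b'wxyz')
chunks = list(iter_chunks(io.BytesIO(payload)))
print(chunks)
```

Skipping the pad byte after odd-sized chunks is what the align flag controls in Chunk.read() and Chunk.skip().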
# Licensed to the .NET Foundation under one or more agreements.
# The .NET Foundation licenses this file to you under the Apache 2.0 License.
# See the LICENSE file in the project root for more information.
__all__ = ["ClrClass", "ClrInterface", "accepts", "returns", "attribute", "propagate_attributes"]
import clr
clr.AddReference("Microsoft.Dynamic")
clr.AddReference("Microsoft.Scripting")
clr.AddReference("IronPython")
if clr.IsNetCoreApp:
clr.AddReference("System.Reflection.Emit")
import System
from System import Char, Void, Boolean, Array, Type, AppDomain
from System.Reflection import FieldAttributes, MethodAttributes, PropertyAttributes, ParameterAttributes
from System.Reflection import CallingConventions, TypeAttributes, AssemblyName
from System.Reflection.Emit import OpCodes, CustomAttributeBuilder, AssemblyBuilder, AssemblyBuilderAccess
from System.Runtime.InteropServices import DllImportAttribute, CallingConvention, CharSet
from Microsoft.Scripting.Generation import Snippets
from Microsoft.Scripting.Runtime import DynamicOperations
from Microsoft.Scripting.Utils import ReflectionUtils
from IronPython.Runtime import NameType, PythonContext
from IronPython.Runtime.Types import PythonType, ReflectedField, ReflectedProperty
def validate_clr_types(signature_types, var_signature = False):
if not isinstance(signature_types, tuple):
signature_types = (signature_types,)
for t in signature_types:
if type(t) is type(System.IComparable): # type overloaded on generic arity, eg IComparable and IComparable[T]
t = t[()] # select non-generic version
clr_type = clr.GetClrType(t)
if t == Void:
raise TypeError("Void cannot be used in signature")
is_typed = clr.GetPythonType(clr_type) == t
# is_typed needs to be weakened until the generated type
# gets explicitly published as the underlying CLR type
is_typed = is_typed or (hasattr(t, "__metaclass__") and t.__metaclass__ in [ClrInterface, ClrClass])
if not is_typed:
raise Exception, "Invalid CLR type %s" % str(t)
if not var_signature:
if clr_type.IsByRef:
raise TypeError("Byref can only be used as arguments and locals")
# ArgIterator is not present in Silverlight
if hasattr(System, "ArgIterator") and t == System.ArgIterator:
raise TypeError("Stack-referencing types can only be used as arguments and locals")
class TypedFunction(object):
"""
A strongly-typed function can get wrapped up as a staticmethod, a property, etc.
This class represents the raw function, but with the type information
it is decorated with.
Other information is stored as attributes on the function. See propagate_attributes
"""
def __init__(self, function, is_static = False, prop_name_if_prop_get = None, prop_name_if_prop_set = None):
self.function = function
self.is_static = is_static
self.prop_name_if_prop_get = prop_name_if_prop_get
self.prop_name_if_prop_set = prop_name_if_prop_set
class ClrType(type):
"""
Base metaclass for creating strongly-typed CLR types
"""
def is_typed_method(self, function):
if hasattr(function, "arg_types") != hasattr(function, "return_type"):
raise TypeError("One of @accepts and @returns is missing for %s" % function.func_name)
return hasattr(function, "arg_types")
def get_typed_properties(self):
for item_name, item in self.__dict__.items():
if isinstance(item, property):
if item.fget:
if not self.is_typed_method(item.fget): continue
prop_type = item.fget.return_type
else:
if not self.is_typed_method(item.fset): continue
prop_type = item.fset.arg_types[0]
validate_clr_types(prop_type)
clr_prop_type = clr.GetClrType(prop_type)
yield item, item_name, clr_prop_type
def emit_properties(self, typebld):
for prop, prop_name, clr_prop_type in self.get_typed_properties():
self.emit_property(typebld, prop, prop_name, clr_prop_type)
def emit_property(self, typebld, prop, name, clrtype):
prpbld = typebld.DefineProperty(name, PropertyAttributes.None, clrtype, None)
if prop.fget:
getter = self.emitted_methods[(prop.fget.func_name, prop.fget.arg_types)]
prpbld.SetGetMethod(getter)
if prop.fset:
setter = self.emitted_methods[(prop.fset.func_name, prop.fset.arg_types)]
prpbld.SetSetMethod(setter)
def dummy_function(self): raise RuntimeError("this should not get called")
def get_typed_methods(self):
"""
Get all the methods with @accepts (and @returns) decorators
Functions are assumed to be instance methods, unless decorated with @staticmethod
"""
# We avoid using the "types" library as it is not a builtin
FunctionType = type(ClrType.__dict__["dummy_function"])
for item_name, item in self.__dict__.items():
function = None
is_static = False
if isinstance(item, FunctionType):
function, is_static = item, False
elif isinstance(item, staticmethod):
function, is_static = getattr(self, item_name), True
elif isinstance(item, property):
if item.fget and self.is_typed_method(item.fget):
if item.fget.func_name == item_name:
# The property hides the getter. So yield the getter
yield TypedFunction(item.fget, False, item_name, None)
if item.fset and self.is_typed_method(item.fset):
if item.fset.func_name == item_name:
# The property hides the setter. So yield the setter
yield TypedFunction(item.fset, False, None, item_name)
continue
else:
continue
if self.is_typed_method(function):
yield TypedFunction(function, is_static)
def emit_methods(self, typebld):
# We need to track the generated methods so that we can emit properties
# referring to these methods.
# Also, the hash is indexed by name *and signature*. Even though Python does
# not have method overloading, property getter and setter functions can have
# the same func_name attribute
self.emitted_methods = {}
for function_info in self.get_typed_methods():
method_builder = self.emit_method(typebld, function_info)
function = function_info.function
if (function.func_name, function.arg_types) in self.emitted_methods:
raise TypeError("methods with clashing names")
self.emitted_methods[(function.func_name, function.arg_types)] = method_builder
def emit_classattribs(self, typebld):
if hasattr(self, '_clrclassattribs'):
for attrib_info in self._clrclassattribs:
if isinstance(attrib_info, type):
ci = clr.GetClrType(attrib_info).GetConstructor(())
cab = CustomAttributeBuilder(ci, ())
elif isinstance(attrib_info, CustomAttributeDecorator):
cab = attrib_info.GetBuilder()
else:
make_decorator = attrib_info()
cab = make_decorator.GetBuilder()
typebld.SetCustomAttribute(cab)
def get_clr_type_name(self):
if hasattr(self, "_clrnamespace"):
return self._clrnamespace + "." + self.__name__
else:
return self.__name__
def create_type(self, typebld):
self.emit_members(typebld)
new_type = typebld.CreateType()
self.map_members(new_type)
return new_type
class ClrInterface(ClrType):
"""
Set __metaclass__ in a Python class declaration to declare a
CLR interface type.
You need to specify object as the base-type if you do not specify any other
interfaces as the base interfaces
"""
def __init__(self, *args):
return super(ClrInterface, self).__init__(*args)
def emit_method(self, typebld, function_info):
assert(not function_info.is_static)
function = function_info.function
attributes = MethodAttributes.Public | MethodAttributes.Virtual | MethodAttributes.Abstract
method_builder = typebld.DefineMethod(
function.func_name,
attributes,
function.return_type,
function.arg_types)
instance_offset = 0 if function_info.is_static else 1
arg_names = function.func_code.co_varnames
for i in xrange(len(function.arg_types)):
# TODO - set non-trivial ParameterAttributes, default value and custom attributes
p = method_builder.DefineParameter(i + 1, ParameterAttributes.None, arg_names[i + instance_offset])
if hasattr(function, "CustomAttributeBuilders"):
for cab in function.CustomAttributeBuilders:
method_builder.SetCustomAttribute(cab)
return method_builder
def emit_members(self, typebld):
self.emit_methods(typebld)
self.emit_properties(typebld)
self.emit_classattribs(typebld)
def map_members(self, new_type): pass
interface_module_builder = None
@staticmethod
def define_interface(typename, bases):
for b in bases:
validate_clr_types(b)
if not ClrInterface.interface_module_builder:
name = AssemblyName("interfaces")
access = AssemblyBuilderAccess.Run
assembly_builder = ReflectionUtils.DefineDynamicAssembly(name, access)
ClrInterface.interface_module_builder = assembly_builder.DefineDynamicModule("interfaces")
attrs = TypeAttributes.Public | TypeAttributes.Interface | TypeAttributes.Abstract
return ClrInterface.interface_module_builder.DefineType(typename, attrs, None, bases)
def map_clr_type(self, clr_type):
"""
TODO - Currently "t = clr.GetPythonType(clr.GetClrType(C)); t == C" will be False
for C where C.__metaclass__ is ClrInterface, even though both t and C
represent the same CLR type. This can be fixed by publishing a mapping
between t and C in the IronPython runtime.
"""
pass
def __clrtype__(self):
# CFoo below will use ClrInterface as its metaclass, but the user will not expect CFoo
# to be an interface in this case:
#
# class IFoo(object):
# __metaclass__ = ClrInterface
# class CFoo(IFoo): pass
if not "__metaclass__" in self.__dict__:
return super(ClrInterface, self).__clrtype__()
bases = list(self.__bases__)
bases.remove(object)
bases = tuple(bases)
if False: # Snippets currently does not support creating interfaces
typegen = Snippets.Shared.DefineType(self.get_clr_type_name(), bases, True, False)
typebld = typegen.TypeBuilder
else:
typebld = ClrInterface.define_interface(self.get_clr_type_name(), bases)
clr_type = self.create_type(typebld)
self.map_clr_type(clr_type)
return clr_type
# Note that ClrClass inherits from ClrInterface to satisfy Python requirements of metaclasses.
# A metaclass of a subtype has to be subtype of the metaclass of a base type. As a result,
# if you define a type hierarchy as shown below, it requires ClrClass to be a subtype
# of ClrInterface:
#
# class IFoo(object):
# __metaclass__ = ClrInterface
# class CFoo(IFoo):
# __metaclass__ = ClrClass
class ClrClass(ClrInterface):
"""
Set __metaclass__ in a Python class declaration to specify strong-type
information for the class or its attributes. The Python class
retains its Python attributes, like being able to add or remove methods.
"""
# Holds the FieldInfo for a static CLR field which points to a
# Microsoft.Scripting.Runtime.DynamicOperations corresponding to the current ScriptEngine
dynamic_operations_field = None
def emit_fields(self, typebld):
if hasattr(self, "_clrfields"):
for fldname in self._clrfields:
field_type = self._clrfields[fldname]
validate_clr_types(field_type)
typebld.DefineField(
fldname,
clr.GetClrType(field_type),
FieldAttributes.Public)
def map_fields(self, new_type):
if hasattr(self, "_clrfields"):
for fldname in self._clrfields:
fldinfo = new_type.GetField(fldname)
setattr(self, fldname, ReflectedField(fldinfo))
@staticmethod
def get_dynamic_operations_field():
if ClrClass.dynamic_operations_field:
return ClrClass.dynamic_operations_field
python_context = clr.GetCurrentRuntime().GetLanguage(PythonContext)
dynamic_operations = DynamicOperations(python_context)
typegen = Snippets.Shared.DefineType(
"DynamicOperationsHolder" + str(hash(python_context)),
object,
True,
False)
typebld = typegen.TypeBuilder
typebld.DefineField(
"DynamicOperations",
DynamicOperations,
FieldAttributes.Public | FieldAttributes.Static)
new_type = typebld.CreateType()
ClrClass.dynamic_operations_field = new_type.GetField("DynamicOperations")
ClrClass.dynamic_operations_field.SetValue(None, dynamic_operations)
return ClrClass.dynamic_operations_field
def emit_typed_stub_to_python_method(self, typebld, function_info):
"""
Generate a stub method that repushes all the arguments and
dispatches to DynamicOperations.InvokeMember
"""
function = function_info.function
invoke_member = clr.GetClrType(DynamicOperations).GetMethod(
"InvokeMember",
Array[Type]((object, str, Array[object])))
# Type.GetMethod raises an AmbiguousMatchException if there is a generic and a non-generic method
# (like DynamicOperations.GetMember) with the same name and signature. So we have to do things
# the hard way
get_member_search = [m for m in clr.GetClrType(DynamicOperations).GetMethods() if m.Name == "GetMember" and not m.IsGenericMethod and m.GetParameters().Length == 2]
assert(len(get_member_search) == 1)
get_member = get_member_search[0]
set_member_search = [m for m in clr.GetClrType(DynamicOperations).GetMethods() if m.Name == "SetMember" and not m.IsGenericMethod and m.GetParameters().Length == 3]
assert(len(set_member_search) == 1)
set_member = set_member_search[0]
convert_to = clr.GetClrType(DynamicOperations).GetMethod(
"ConvertTo",
Array[Type]((object, Type)))
get_type_from_handle = clr.GetClrType(Type).GetMethod("GetTypeFromHandle")
attributes = MethodAttributes.Public
if function_info.is_static: attributes |= MethodAttributes.Static
if function.func_name == "__new__":
if function_info.is_static: raise TypeError
method_builder = typebld.DefineConstructor(
attributes,
CallingConventions.HasThis,
function.arg_types)
raise NotImplementedError("Need to call self.baseType ctor passing in self.get_python_type_field()")
else:
method_builder = typebld.DefineMethod(
function.func_name,
attributes,
function.return_type,
function.arg_types)
instance_offset = 0 if function_info.is_static else 1
arg_names = function.func_code.co_varnames
for i in xrange(len(function.arg_types)):
# TODO - set non-trivial ParameterAttributes, default value and custom attributes
p = method_builder.DefineParameter(i + 1, ParameterAttributes.None, arg_names[i + instance_offset])
ilgen = method_builder.GetILGenerator()
args_array = ilgen.DeclareLocal(Array[object])
args_count = len(function.arg_types)
ilgen.Emit(OpCodes.Ldc_I4, args_count)
ilgen.Emit(OpCodes.Newarr, object)
ilgen.Emit(OpCodes.Stloc, args_array)
for i in xrange(args_count):
arg_type = function.arg_types[i]
if clr.GetClrType(arg_type).IsByRef:
raise NotImplementedError("byref params not supported")
ilgen.Emit(OpCodes.Ldloc, args_array)
ilgen.Emit(OpCodes.Ldc_I4, i)
ilgen.Emit(OpCodes.Ldarg, i + int(not function_info.is_static))
ilgen.Emit(OpCodes.Box, arg_type)
ilgen.Emit(OpCodes.Stelem_Ref)
has_return_value = True
if function_info.prop_name_if_prop_get:
ilgen.Emit(OpCodes.Ldsfld, ClrClass.get_dynamic_operations_field())
ilgen.Emit(OpCodes.Ldarg, 0)
ilgen.Emit(OpCodes.Ldstr, function_info.prop_name_if_prop_get)
ilgen.Emit(OpCodes.Callvirt, get_member)
elif function_info.prop_name_if_prop_set:
ilgen.Emit(OpCodes.Ldsfld, ClrClass.get_dynamic_operations_field())
ilgen.Emit(OpCodes.Ldarg, 0)
ilgen.Emit(OpCodes.Ldstr, function_info.prop_name_if_prop_set)
ilgen.Emit(OpCodes.Ldarg, 1)
ilgen.Emit(OpCodes.Callvirt, set_member)
has_return_value = False
else:
ilgen.Emit(OpCodes.Ldsfld, ClrClass.get_dynamic_operations_field())
if function_info.is_static:
raise NotImplementedError("need to load Python class object from a CLR static field")
# ilgen.Emit(OpCodes.Ldsfld, class_object)
else:
ilgen.Emit(OpCodes.Ldarg, 0)
ilgen.Emit(OpCodes.Ldstr, function.func_name)
ilgen.Emit(OpCodes.Ldloc, args_array)
ilgen.Emit(OpCodes.Callvirt, invoke_member)
if has_return_value:
if function.return_type == Void:
ilgen.Emit(OpCodes.Pop)
else:
ret_val = ilgen.DeclareLocal(object)
ilgen.Emit(OpCodes.Stloc, ret_val)
ilgen.Emit(OpCodes.Ldsfld, ClrClass.get_dynamic_operations_field())
ilgen.Emit(OpCodes.Ldloc, ret_val)
ilgen.Emit(OpCodes.Ldtoken, clr.GetClrType(function.return_type))
ilgen.Emit(OpCodes.Call, get_type_from_handle)
ilgen.Emit(OpCodes.Callvirt, convert_to)
ilgen.Emit(OpCodes.Unbox_Any, function.return_type)
ilgen.Emit(OpCodes.Ret)
return method_builder
def emit_method(self, typebld, function_info):
function = function_info.function
if hasattr(function, "DllImportAttributeDecorator"):
dllImportAttributeDecorator = function.DllImportAttributeDecorator
name = function.func_name
dllName = dllImportAttributeDecorator.args[0]
entryName = function.func_name
attributes = MethodAttributes.Public | MethodAttributes.Static | MethodAttributes.PinvokeImpl
callingConvention = CallingConventions.Standard
returnType = function.return_type
returnTypeRequiredCustomModifiers = ()
returnTypeOptionalCustomModifiers = ()
parameterTypes = function.arg_types
parameterTypeRequiredCustomModifiers = None
parameterTypeOptionalCustomModifiers = None
nativeCallConv = CallingConvention.Winapi
nativeCharSet = CharSet.Auto
method_builder = typebld.DefinePInvokeMethod(
name,
dllName,
entryName,
attributes,
callingConvention,
returnType,
returnTypeRequiredCustomModifiers,
returnTypeOptionalCustomModifiers,
parameterTypes,
parameterTypeRequiredCustomModifiers,
parameterTypeOptionalCustomModifiers,
nativeCallConv,
nativeCharSet)
else:
method_builder = self.emit_typed_stub_to_python_method(typebld, function_info)
if hasattr(function, "CustomAttributeBuilders"):
for cab in function.CustomAttributeBuilders:
method_builder.SetCustomAttribute(cab)
return method_builder
def map_pinvoke_methods(self, new_type):
pythonType = clr.GetPythonType(new_type)
for function_info in self.get_typed_methods():
function = function_info.function
if hasattr(function, "DllImportAttributeDecorator"):
# Overwrite the Python function with the pinvoke_method
pinvoke_method = getattr(pythonType, function.func_name)
setattr(self, function.func_name, pinvoke_method)
def emit_python_type_field(self, typebld):
return typebld.DefineField(
"PythonType",
PythonType,
FieldAttributes.Public | FieldAttributes.Static)
def set_python_type_field(self, new_type):
self.PythonType = new_type.GetField("PythonType")
self.PythonType.SetValue(None, self)
def add_wrapper_ctors(self, baseType, typebld):
python_type_field = self.emit_python_type_field(typebld)
for ctor in baseType.GetConstructors():
ctorparams = ctor.GetParameters()
# leave out the PythonType argument
assert(ctorparams[0].ParameterType == clr.GetClrType(PythonType))
ctorparams = ctorparams[1:]
ctorbld = typebld.DefineConstructor(
ctor.Attributes,
ctor.CallingConvention,
tuple([p.ParameterType for p in ctorparams]))
ilgen = ctorbld.GetILGenerator()
ilgen.Emit(OpCodes.Ldarg, 0)
ilgen.Emit(OpCodes.Ldsfld, python_type_field)
for index in xrange(len(ctorparams)):
ilgen.Emit(OpCodes.Ldarg, index + 1)
ilgen.Emit(OpCodes.Call, ctor)
ilgen.Emit(OpCodes.Ret)
def emit_members(self, typebld):
self.emit_fields(typebld)
self.add_wrapper_ctors(self.baseType, typebld)
super(ClrClass, self).emit_members(typebld)
def map_members(self, new_type):
self.map_fields(new_type)
self.map_pinvoke_methods(new_type)
self.set_python_type_field(new_type)
super(ClrClass, self).map_members(new_type)
def __clrtype__(self):
# CDerived below will use ClrClass as its metaclass, but the user may not expect CDerived
# to be a typed .NET class in this case:
#
# class CBase(object):
# __metaclass__ = ClrClass
# class CDerived(CBase): pass
if not "__metaclass__" in self.__dict__:
return super(ClrClass, self).__clrtype__()
# Create a simple Python type first.
self.baseType = super(ClrType, self).__clrtype__()
# We will now subtype it to create a customized class with the
# CLR attributes as defined by the user
typegen = Snippets.Shared.DefineType(self.get_clr_type_name(), self.baseType, True, False)
typebld = typegen.TypeBuilder
return self.create_type(typebld)
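# Illustrative sketch of the metaclass rule described in the note above
# ClrClass: the metaclass of a subtype has to derive from the metaclass of
# every base type. Plain Python metaclasses stand in for ClrInterface and
# ClrClass here; InterfaceMeta, ClassMeta and UnrelatedMeta are hypothetical
# names, and nothing in this example touches the CLR.

```python
class InterfaceMeta(type):
    pass

class ClassMeta(InterfaceMeta):
    pass

class UnrelatedMeta(type):
    pass

# Calling a metaclass directly builds a class with that metaclass.
IFoo = InterfaceMeta("IFoo", (object,), {})

# Allowed: ClassMeta is a subtype of InterfaceMeta, the metaclass of IFoo.
CFoo = ClassMeta("CFoo", (IFoo,), {})

# A subclass that names no metaclass inherits the one from its base.
class Plain(IFoo):
    pass

# Rejected: UnrelatedMeta is not a subtype of InterfaceMeta.
try:
    UnrelatedMeta("Bad", (IFoo,), {})
    conflict = False
except TypeError:
    conflict = True
```

# The Plain case mirrors the CDerived comment in __clrtype__: the subclass
# picks up the metaclass without declaring it.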
def make_cab(attrib_type, *args, **kwds):
clrtype = clr.GetClrType(attrib_type)
argtypes = tuple(map(lambda x:clr.GetClrType(type(x)), args))
ci = clrtype.GetConstructor(argtypes)
props = ([],[])
fields = ([],[])
for kwd in kwds:
pi = clrtype.GetProperty(kwd)
if pi is not None:
props[0].append(pi)
props[1].append(kwds[kwd])
else:
fi = clrtype.GetField(kwd)
if fi is not None:
fields[0].append(fi)
fields[1].append(kwds[kwd])
else:
raise TypeError("No %s Member found on %s" % (kwd, clrtype.Name))
return CustomAttributeBuilder(ci, args,
tuple(props[0]), tuple(props[1]),
tuple(fields[0]), tuple(fields[1]))
def accepts(*args):
"""
TODO - needs to be merged with clr.accepts
"""
validate_clr_types(args, True)
def decorator(function):
function.arg_types = args
return function
return decorator
def returns(return_type = Void):
"""
TODO - needs to be merged with clr.returns
"""
if return_type != Void:
validate_clr_types(return_type)
def decorator(function):
function.return_type = return_type
return function
return decorator
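# Illustrative sketch of how the two decorators above cooperate with
# ClrType.is_typed_method: each one only attaches an attribute to the
# function, and a function counts as typed when both attributes are present.
# The names accepts_sketch, returns_sketch and is_typed are hypothetical, and
# the CLR type validation done by the real decorators is left out.

```python
def accepts_sketch(*arg_types):
    def decorator(function):
        function.arg_types = arg_types
        return function
    return decorator

def returns_sketch(return_type=None):
    def decorator(function):
        function.return_type = return_type
        return function
    return decorator

def is_typed(function):
    has_args = hasattr(function, "arg_types")
    # Exactly one of the two attributes being present is treated as an error.
    if has_args != hasattr(function, "return_type"):
        raise TypeError("One of the decorators is missing for %s" % function.__name__)
    return has_args

@accepts_sketch(int, str)
@returns_sketch(bool)
def check(count, label):
    return count > 0 and bool(label)

@accepts_sketch(int)
def half_typed(count):
    return count

def untyped(count):
    return count
```

# check is typed, untyped is simply skipped, and half_typed makes is_typed
# raise TypeError.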
class CustomAttributeDecorator(object):
"""
This represents information about a custom-attribute applied to a type or a method
Note that we cannot use an instance of System.Attribute to capture this information
as it is not possible to go from an instance of System.Attribute to an instance
of System.Reflection.Emit.CustomAttributeBuilder as the latter needs to know
how to represent information in metadata to later *recreate* a similar instance of
System.Attribute.
Also note that once a CustomAttributeBuilder is created, it is not possible to
query it. Hence, we need to store the arguments required to create the
CustomAttributeBuilder so that pseudo-custom-attributes can get to the information.
"""
def __init__(self, attrib_type, *args, **kwargs):
self.attrib_type = attrib_type
self.args = args
self.kwargs = kwargs
def __call__(self, function):
if self.attrib_type == DllImportAttribute:
function.DllImportAttributeDecorator = self
else:
if not hasattr(function, "CustomAttributeBuilders"):
function.CustomAttributeBuilders = []
function.CustomAttributeBuilders.append(self.GetBuilder())
return function
def GetBuilder(self):
assert not self.attrib_type in [DllImportAttribute]
return make_cab(self.attrib_type, *self.args, **self.kwargs)
def attribute(attrib_type):
"""
This decorator is used to specify a CustomAttribute for a type or method.
"""
def make_decorator(*args, **kwargs):
return CustomAttributeDecorator(attrib_type, *args, **kwargs)
return make_decorator
def propagate_attributes(old_function, new_function):
"""
Use this if you replace a function in a type with ClrInterface or ClrClass as the metaclass.
This will typically be needed if you are defining a decorator which wraps functions with
new functions, and want it to work in conjunction with clrtype
"""
if hasattr(old_function, "return_type"):
new_function.func_name = old_function.func_name
new_function.return_type = old_function.return_type
new_function.arg_types = old_function.arg_types
if hasattr(old_function, "CustomAttributeBuilders"):
new_function.CustomAttributeBuilders = old_function.CustomAttributeBuilders
if hasattr(old_function, "DllImportAttributeDecorator"):
new_function.DllImportAttributeDecorator = old_function.DllImportAttributeDecorator
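# Illustrative sketch of the situation propagate_attributes is meant for: a
# decorator replaces a typed function with a wrapper, so the type information
# has to be copied across. copy_type_info and counted are hypothetical names,
# and plain Python types stand in for CLR types.

```python
def copy_type_info(old_function, new_function):
    # Copy only the attributes that are actually present on the original.
    for name in ("return_type", "arg_types",
                 "CustomAttributeBuilders", "DllImportAttributeDecorator"):
        if hasattr(old_function, name):
            setattr(new_function, name, getattr(old_function, name))
    new_function.__name__ = old_function.__name__

def counted(function):
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return function(*args, **kwargs)
    wrapper.calls = 0
    copy_type_info(function, wrapper)
    return wrapper

def area(width, height):
    return width * height

# Stand-ins for what the accepts and returns decorators would attach.
area.arg_types = (int, int)
area.return_type = int

wrapped = counted(area)
result = wrapped(3, 4)
```

# The wrapper keeps the name and signature information of the function it
# replaces, so code that looks for arg_types and return_type still finds them.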
"""A generic class to build line-oriented command interpreters.
Interpreters constructed with this class obey the following conventions:
1. End of file on input is processed as the command 'EOF'.
2. A command is parsed out of each line by collecting the prefix composed
of characters in the identchars member.
3. A command `foo' is dispatched to a method 'do_foo()'; the do_ method
is passed a single argument consisting of the remainder of the line.
4. Typing an empty line repeats the last command. (Actually, it calls the
method `emptyline', which may be overridden in a subclass.)
5. There is a predefined `help' method. Given an argument `topic', it
calls the command `help_topic'. With no arguments, it lists all topics
with defined help_ functions, broken into up to three groups: documented
commands, miscellaneous help topics, and undocumented commands.
6. The command '?' is a synonym for `help'. The command '!' is a synonym
for `shell', if a do_shell method exists.
7. If completion is enabled, completing commands will be done automatically,
and completion of command arguments is done by calling complete_foo() with
arguments text, line, begidx, endidx. text is the string we are matching
against; all returned matches must begin with it. line is the current
input line (lstripped), begidx and endidx are the beginning and end
indexes of the text being matched, which could be used to provide
different completion depending upon which position the argument is in.
The `default' method may be overridden to intercept commands for which there
is no do_ method.
The `completedefault' method may be overridden to intercept completions for
commands that have no complete_ method.
The data member `self.ruler' sets the character used to draw separator lines
in the help messages. If empty, no ruler line is drawn. It defaults to "=".
If the value of `self.intro' is nonempty when the cmdloop method is called,
it is printed out on interpreter startup. This value may be overridden
via an optional argument to the cmdloop() method.
The data members `self.doc_header', `self.misc_header', and
`self.undoc_header' set the headers used for the help function's
listings of documented functions, miscellaneous topics, and undocumented
functions respectively.
These interpreters use raw_input; thus, if the readline module is loaded,
they automatically support Emacs-like command history and editing features.
"""
import string
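# Illustrative sketch of the conventions listed in the module docstring
# above: 'foo' dispatches to do_foo(), an empty line repeats the last
# command, unknown commands go to default(), and 'EOF' is an ordinary
# command. The Greeter class is a hypothetical example; commands are fed
# through cmdqueue so that no terminal input is needed.

```python
import cmd

class Greeter(cmd.Cmd):
    def __init__(self):
        # completekey=None keeps cmdloop() from touching readline.
        cmd.Cmd.__init__(self, completekey=None)
        self.seen = []

    def do_greet(self, arg):
        self.seen.append("hello " + arg)

    def do_EOF(self, arg):
        return True  # a true return value stops cmdloop()

    def default(self, line):
        self.seen.append("unknown: " + line)

shell = Greeter()
shell.cmdqueue = ["greet world", "", "bogus 1", "EOF"]
shell.cmdloop()
```

# The empty line re-runs 'greet world' through emptyline(), and 'bogus 1'
# lands in default() because there is no do_bogus method.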
__all__ = ["Cmd"]
PROMPT = '(Cmd) '
IDENTCHARS = string.ascii_letters + string.digits + '_'
class Cmd:
"""A simple framework for writing line-oriented command interpreters.
These are often useful for test harnesses, administrative tools, and
prototypes that will later be wrapped in a more sophisticated interface.
A Cmd instance or subclass instance is a line-oriented interpreter
framework. There is no good reason to instantiate Cmd itself; rather,
it's useful as a superclass of an interpreter class you define yourself
in order to inherit Cmd's methods and encapsulate action methods.
"""
prompt = PROMPT
identchars = IDENTCHARS
ruler = '='
lastcmd = ''
intro = None
doc_leader = ""
doc_header = "Documented commands (type help <topic>):"
misc_header = "Miscellaneous help topics:"
undoc_header = "Undocumented commands:"
nohelp = "*** No help on %s"
use_rawinput = 1
def __init__(self, completekey='tab', stdin=None, stdout=None):
"""Instantiate a line-oriented interpreter framework.
The optional argument 'completekey' is the readline name of a
completion key; it defaults to the Tab key. If completekey is
not None and the readline module is available, command completion
is done automatically. The optional arguments stdin and stdout
specify alternate input and output file objects; if not specified,
sys.stdin and sys.stdout are used.
"""
import sys
if stdin is not None:
self.stdin = stdin
else:
self.stdin = sys.stdin
if stdout is not None:
self.stdout = stdout
else:
self.stdout = sys.stdout
self.cmdqueue = []
self.completekey = completekey
def cmdloop(self, intro=None):
"""Repeatedly issue a prompt, accept input, parse an initial prefix
off the received input, and dispatch to action methods, passing them
the remainder of the line as argument.
"""
self.preloop()
if self.use_rawinput and self.completekey:
try:
import readline
self.old_completer = readline.get_completer()
readline.set_completer(self.complete)
readline.parse_and_bind(self.completekey+": complete")
except ImportError:
pass
try:
if intro is not None:
self.intro = intro
if self.intro:
self.stdout.write(str(self.intro)+"\n")
stop = None
while not stop:
if self.cmdqueue:
line = self.cmdqueue.pop(0)
else:
if self.use_rawinput:
try:
line = raw_input(self.prompt)
except EOFError:
line = 'EOF'
else:
self.stdout.write(self.prompt)
self.stdout.flush()
line = self.stdin.readline()
if not len(line):
line = 'EOF'
else:
line = line.rstrip('\r\n')
line = self.precmd(line)
stop = self.onecmd(line)
stop = self.postcmd(stop, line)
self.postloop()
finally:
if self.use_rawinput and self.completekey:
try:
import readline
readline.set_completer(self.old_completer)
except ImportError:
pass
def precmd(self, line):
"""Hook method executed just before the command line is
interpreted, but after the input prompt is generated and issued.
"""
return line
def postcmd(self, stop, line):
"""Hook method executed just after a command dispatch is finished."""
return stop
def preloop(self):
"""Hook method executed once when the cmdloop() method is called."""
pass
def postloop(self):
"""Hook method executed once when the cmdloop() method is about to
return.
"""
pass
def parseline(self, line):
"""Parse the line into a command name and a string containing
the arguments. Returns a tuple containing (command, args, line).
'command' and 'args' may be None if the line couldn't be parsed.
"""
line = line.strip()
if not line:
return None, None, line
elif line[0] == '?':
line = 'help ' + line[1:]
elif line[0] == '!':
if hasattr(self, 'do_shell'):
line = 'shell ' + line[1:]
else:
return None, None, line
i, n = 0, len(line)
while i < n and line[i] in self.identchars: i = i+1
cmd, arg = line[:i], line[i:].strip()
return cmd, arg, line
def onecmd(self, line):
"""Interpret the argument as though it had been typed in response
to the prompt.
This may be overridden, but should not normally need to be;
see the precmd() and postcmd() methods for useful execution hooks.
The return value is a flag indicating whether interpretation of
commands by the interpreter should stop.
"""
cmd, arg, line = self.parseline(line)
if not line:
return self.emptyline()
if cmd is None:
return self.default(line)
self.lastcmd = line
if line == 'EOF' :
self.lastcmd = ''
if cmd == '':
return self.default(line)
else:
try:
func = getattr(self, 'do_' + cmd)
except AttributeError:
return self.default(line)
return func(arg)
def emptyline(self):
"""Called when an empty line is entered in response to the prompt.
If this method is not overridden, it repeats the last nonempty
command entered.
"""
if self.lastcmd:
return self.onecmd(self.lastcmd)
def default(self, line):
"""Called on an input line when the command prefix is not recognized.
If this method is not overridden, it prints an error message and
returns.
"""
self.stdout.write('*** Unknown syntax: %s\n'%line)
def completedefault(self, *ignored):
"""Method called to complete an input line when no command-specific
complete_*() method is available.
By default, it returns an empty list.
"""
return []
def completenames(self, text, *ignored):
dotext = 'do_'+text
return [a[3:] for a in self.get_names() if a.startswith(dotext)]
def complete(self, text, state):
"""Return the next possible completion for 'text'.
If a command has not been entered, then complete against command list.
Otherwise try to call complete_<command> to get list of completions.
"""
if state == 0:
import readline
origline = readline.get_line_buffer()
line = origline.lstrip()
stripped = len(origline) - len(line)
begidx = readline.get_begidx() - stripped
endidx = readline.get_endidx() - stripped
if begidx>0:
cmd, args, foo = self.parseline(line)
if cmd == '':
compfunc = self.completedefault
else:
try:
compfunc = getattr(self, 'complete_' + cmd)
except AttributeError:
compfunc = self.completedefault
else:
compfunc = self.completenames
self.completion_matches = compfunc(text, line, begidx, endidx)
try:
return self.completion_matches[state]
except IndexError:
return None
def get_names(self):
# This method used to pull in base class attributes
# at a time dir() didn't do it yet.
return dir(self.__class__)
def complete_help(self, *args):
commands = set(self.completenames(*args))
topics = set(a[5:] for a in self.get_names()
if a.startswith('help_' + args[0]))
return list(commands | topics)
def do_help(self, arg):
'List available commands with "help" or detailed help with "help cmd".'
if arg:
# XXX check arg syntax
try:
func = getattr(self, 'help_' + arg)
except AttributeError:
try:
doc=getattr(self, 'do_' + arg).__doc__
if doc:
self.stdout.write("%s\n"%str(doc))
return
except AttributeError:
pass
self.stdout.write("%s\n"%str(self.nohelp % (arg,)))
return
func()
else:
names = self.get_names()
cmds_doc = []
cmds_undoc = []
help = {}
for name in names:
if name[:5] == 'help_':
help[name[5:]]=1
names.sort()
# There can be duplicates if routines are overridden
prevname = ''
for name in names:
if name[:3] == 'do_':
if name == prevname:
continue
prevname = name
cmd=name[3:]
if cmd in help:
cmds_doc.append(cmd)
del help[cmd]
elif getattr(self, name).__doc__:
cmds_doc.append(cmd)
else:
cmds_undoc.append(cmd)
self.stdout.write("%s\n"%str(self.doc_leader))
self.print_topics(self.doc_header, cmds_doc, 15,80)
self.print_topics(self.misc_header, help.keys(),15,80)
self.print_topics(self.undoc_header, cmds_undoc, 15,80)
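The help lookup above can be exercised with a small subclass. This is a minimal sketch, assuming Python 3's standard `cmd` module behaves like the listing here; `DemoShell`, `run_help` and the command names are hypothetical and exist only for illustration.

```python
import cmd
import io


class DemoShell(cmd.Cmd):
    """Hypothetical shell used only for this illustration."""

    def do_greet(self, arg):
        "Greet the named person."
        self.stdout.write("hello %s\n" % arg)

    def do_quiet(self, arg):
        # Deliberately has no docstring, so it is listed as undocumented.
        pass

    def help_syntax(self):
        # A help_* method without a matching do_* becomes a misc topic.
        self.stdout.write("commands are single words\n")


def run_help(argument=""):
    """Run 'help <argument>' and return everything written to stdout."""
    out = io.StringIO()
    shell = DemoShell(stdout=out)
    shell.onecmd(("help " + argument).strip())
    return out.getvalue()


greet_help = run_help("greet")      # falls back to do_greet.__doc__
topic_help = run_help("syntax")     # calls help_syntax()
overview = run_help()               # lists documented, misc, undocumented
missing = run_help("nosuchcommand") # uses the nohelp template
```

Passing a `StringIO` as `stdout` keeps the output capturable, which is why every method writes to `self.stdout` rather than printing.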
def print_topics(self, header, cmds, cmdlen, maxcol):
if cmds:
self.stdout.write("%s\n"%str(header))
if self.ruler:
self.stdout.write("%s\n"%str(self.ruler * len(header)))
self.columnize(cmds, maxcol-1)
self.stdout.write("\n")
def columnize(self, list, displaywidth=80):
"""Display a list of strings as a compact set of columns.
Each column is only as wide as necessary.
Columns are separated by two spaces (one was not legible enough).
"""
if not list:
self.stdout.write("<empty>\n")
return
nonstrings = [i for i in range(len(list))
if not isinstance(list[i], str)]
if nonstrings:
raise TypeError("list[i] not a string for i in %s" %
", ".join(map(str, nonstrings)))
size = len(list)
if size == 1:
self.stdout.write('%s\n'%str(list[0]))
return
# Try every row count from 1 upwards
for nrows in range(1, len(list)):
ncols = (size+nrows-1) // nrows
colwidths = []
totwidth = -2
for col in range(ncols):
colwidth = 0
for row in range(nrows):
i = row + nrows*col
if i >= size:
break
x = list[i]
colwidth = max(colwidth, len(x))
colwidths.append(colwidth)
totwidth += colwidth + 2
if totwidth > displaywidth:
break
if totwidth <= displaywidth:
break
else:
nrows = len(list)
ncols = 1
colwidths = [0]
for row in range(nrows):
texts = []
for col in range(ncols):
i = row + nrows*col
if i >= size:
x = ""
else:
x = list[i]
texts.append(x)
while texts and not texts[-1]:
del texts[-1]
for col in range(len(texts)):
texts[col] = texts[col].ljust(colwidths[col])
self.stdout.write("%s\n"%str("  ".join(texts)))
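The column fitting performed by columnize() can be observed directly. The following is a sketch, assuming Python 3's `cmd.Cmd.columnize` matches the listing; `columnize_to_lines` is a hypothetical helper, not part of the module.

```python
import cmd
import io


def columnize_to_lines(items, displaywidth=80):
    """Return the lines that Cmd.columnize() would write for items."""
    out = io.StringIO()
    cmd.Cmd(stdout=out).columnize(list(items), displaywidth)
    return out.getvalue().splitlines()


# Four items do not fit on one 10-column row, so two rows are used and
# the items are laid out column by column (a, bb | ccc, dddd).
lines = columnize_to_lines(["a", "bb", "ccc", "dddd"], displaywidth=10)
cells = [line.split() for line in lines]
```

The method tries the smallest row count first, so the result is the flattest layout that fits within `displaywidth`.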
"""Utilities needed to emulate Python's interactive interpreter.
"""
# Inspired by similar code by Jeff Epler and Fredrik Lundh.
import sys
import traceback
from codeop import CommandCompiler, compile_command
__all__ = ["InteractiveInterpreter", "InteractiveConsole", "interact",
"compile_command"]
def softspace(file, newvalue):
oldvalue = 0
try:
oldvalue = file.softspace
except AttributeError:
pass
try:
file.softspace = newvalue
except (AttributeError, TypeError):
# "attribute-less object" or "read-only attributes"
pass
return oldvalue
class InteractiveInterpreter:
"""Base class for InteractiveConsole.
This class deals with parsing and interpreter state (the user's
namespace); it doesn't deal with input buffering or prompting or
input file naming (the filename is always passed in explicitly).
"""
def __init__(self, locals=None):
"""Constructor.
The optional 'locals' argument specifies the dictionary in
which code will be executed; it defaults to a newly created
dictionary with key "__name__" set to "__console__" and key
"__doc__" set to None.
"""
if locals is None:
locals = {"__name__": "__console__", "__doc__": None}
self.locals = locals
self.compile = CommandCompiler()
def runsource(self, source, filename="<input>", symbol="single"):
"""Compile and run some source in the interpreter.
Arguments are as for compile_command().
One of several things can happen:
1) The input is incorrect; compile_command() raised an
exception (SyntaxError or OverflowError). A syntax traceback
will be printed by calling the showsyntaxerror() method.
2) The input is incomplete, and more input is required;
compile_command() returned None. Nothing happens.
3) The input is complete; compile_command() returned a code
object. The code is executed by calling self.runcode() (which
also handles run-time exceptions, except for SystemExit).
The return value is True in case 2, False in the other cases (unless
an exception is raised). The return value can be used to
decide whether to use sys.ps1 or sys.ps2 to prompt the next
line.
"""
try:
code = self.compile(source, filename, symbol)
except (OverflowError, SyntaxError, ValueError):
# Case 1
self.showsyntaxerror(filename)
return False
if code is None:
# Case 2
return True
# Case 3
self.runcode(code)
return False
def runcode(self, code):
"""Execute a code object.
When an exception occurs, self.showtraceback() is called to
display a traceback. All exceptions are caught except
SystemExit, which is reraised.
A note about KeyboardInterrupt: this exception may occur
elsewhere in this code, and may not always be caught. The
caller should be prepared to deal with it.
"""
try:
exec code in self.locals
except SystemExit:
raise
except:
self.showtraceback()
else:
if softspace(sys.stdout, 0):
print
def showsyntaxerror(self, filename=None):
"""Display the syntax error that just occurred.
This doesn't display a stack trace because there isn't one.
If a filename is given, it is stuffed in the exception instead
of what was there before (because Python's parser always uses
"<string>" when reading from a string).
The output is written by self.write(), below.
"""
type, value, sys.last_traceback = sys.exc_info()
sys.last_type = type
sys.last_value = value
if filename and type is SyntaxError:
# Work hard to stuff the correct filename in the exception
try:
msg, (dummy_filename, lineno, offset, line) = value
except:
# Not the format we expect; leave it alone
pass
else:
# Stuff in the right filename
value = SyntaxError(msg, (filename, lineno, offset, line))
sys.last_value = value
list = traceback.format_exception_only(type, value)
map(self.write, list)
def showtraceback(self):
"""Display the exception that just occurred.
We remove the first stack item because it is our own code.
The output is written by self.write(), below.
"""
try:
type, value, tb = sys.exc_info()
sys.last_type = type
sys.last_value = value
sys.last_traceback = tb
tblist = traceback.extract_tb(tb)
del tblist[:1]
list = traceback.format_list(tblist)
if list:
list.insert(0, "Traceback (most recent call last):\n")
list[len(list):] = traceback.format_exception_only(type, value)
finally:
tblist = tb = None
map(self.write, list)
def write(self, data):
"""Write a string.
The base implementation writes to sys.stderr; a subclass may
replace this with a different implementation.
"""
sys.stderr.write(data)
class InteractiveConsole(InteractiveInterpreter):
"""Closely emulate the behavior of the interactive Python interpreter.
This class builds on InteractiveInterpreter and adds prompting
using the familiar sys.ps1 and sys.ps2, and input buffering.
"""
def __init__(self, locals=None, filename="<console>"):
"""Constructor.
The optional locals argument will be passed to the
InteractiveInterpreter base class.
The optional filename argument should specify the (file)name
of the input stream; it will show up in tracebacks.
"""
InteractiveInterpreter.__init__(self, locals)
self.filename = filename
self.resetbuffer()
def resetbuffer(self):
"""Reset the input buffer."""
self.buffer = []
def interact(self, banner=None):
"""Closely emulate the interactive Python console.
The optional banner argument specifies the banner to print
before the first interaction; by default it prints a banner
similar to the one printed by the real Python interpreter,
followed by the current class name in parentheses (so as not
to confuse this with the real interpreter -- since it's so
close!).
"""
try:
sys.ps1
except AttributeError:
sys.ps1 = ">>> "
try:
sys.ps2
except AttributeError:
sys.ps2 = "... "
cprt = 'Type "help", "copyright", "credits" or "license" for more information.'
if banner is None:
self.write("Python %s on %s\n%s\n(%s)\n" %
(sys.version, sys.platform, cprt,
self.__class__.__name__))
else:
self.write("%s\n" % str(banner))
more = 0
while 1:
try:
if more:
prompt = sys.ps2
else:
prompt = sys.ps1
try:
line = self.raw_input(prompt)
# Can be None if sys.stdin was redefined
encoding = getattr(sys.stdin, "encoding", None)
if encoding and not isinstance(line, unicode):
line = line.decode(encoding)
except EOFError:
self.write("\n")
break
else:
more = self.push(line)
except KeyboardInterrupt:
self.write("\nKeyboardInterrupt\n")
self.resetbuffer()
more = 0
def push(self, line):
"""Push a line to the interpreter.
The line should not have a trailing newline; it may have
internal newlines. The line is appended to a buffer and the
interpreter's runsource() method is called with the
concatenated contents of the buffer as source. If this
indicates that the command was executed or invalid, the buffer
is reset; otherwise, the command is incomplete, and the buffer
is left as it was after the line was appended. The return
value is 1 if more input is required, 0 if the line was dealt
with in some way (this is the same as runsource()).
"""
self.buffer.append(line)
source = "\n".join(self.buffer)
more = self.runsource(source, self.filename)
if not more:
self.resetbuffer()
return more
def raw_input(self, prompt=""):
"""Write a prompt and read a line.
The returned line does not include the trailing newline.
When the user enters the EOF key sequence, EOFError is raised.
The base implementation uses the built-in function
raw_input(); a subclass may replace this with a different
implementation.
"""
return raw_input(prompt)
def interact(banner=None, readfunc=None, local=None):
"""Closely emulate the interactive Python interpreter.
This is a backwards compatible interface to the InteractiveConsole
class. When readfunc is not specified, it attempts to import the
readline module to enable GNU readline if it is available.
Arguments (all optional, all default to None):
banner -- passed to InteractiveConsole.interact()
readfunc -- if not None, replaces InteractiveConsole.raw_input()
local -- passed to InteractiveInterpreter.__init__()
"""
console = InteractiveConsole(local)
if readfunc is not None:
console.raw_input = readfunc
else:
try:
import readline
except ImportError:
pass
console.interact(banner)
if __name__ == "__main__":
interact()
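The return value of push() tells the caller whether to show the continuation prompt. The sketch below assumes Python 3's standard `code` module, which keeps the same interface as this listing; the names `namespace` and `more_flags` are illustrative only.

```python
import code

namespace = {}
console = code.InteractiveConsole(locals=namespace)

# Each push() returns a true value while the buffered source is still
# incomplete, and a false value once it has been compiled and executed.
more_flags = [
    bool(console.push("def f(x):")),         # block opened: more input needed
    bool(console.push("    return x + 1")),  # still inside the block
    bool(console.push("")),                  # blank line completes the block
    bool(console.push("y = f(41)")),         # a complete single statement
]

result = namespace["y"]
```

The executed code runs in the dictionary passed as `locals`, so the function `f` and the variable `y` end up in `namespace`.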
""" codecs -- Python Codec Registry, API and helpers.
Written by Marc-Andre Lemburg.
(c) Copyright CNRI, All Rights Reserved. NO WARRANTY.
"""#"
import __builtin__, sys
### Registry and builtin stateless codec functions
try:
from _codecs import *
except ImportError, why:
raise SystemError('Failed to load the builtin codecs: %s' % why)
__all__ = ["register", "lookup", "open", "EncodedFile", "BOM", "BOM_BE",
"BOM_LE", "BOM32_BE", "BOM32_LE", "BOM64_BE", "BOM64_LE",
"BOM_UTF8", "BOM_UTF16", "BOM_UTF16_LE", "BOM_UTF16_BE",
"BOM_UTF32", "BOM_UTF32_LE", "BOM_UTF32_BE",
"CodecInfo", "Codec", "IncrementalEncoder", "IncrementalDecoder",
"StreamReader", "StreamWriter",
"StreamReaderWriter", "StreamRecoder",
"getencoder", "getdecoder", "getincrementalencoder",
"getincrementaldecoder", "getreader", "getwriter",
"encode", "decode", "iterencode", "iterdecode",
"strict_errors", "ignore_errors", "replace_errors",
"xmlcharrefreplace_errors", "backslashreplace_errors",
"register_error", "lookup_error"]
### Constants
#
# Byte Order Mark (BOM = ZERO WIDTH NO-BREAK SPACE = U+FEFF)
# and its possible byte string values
# for UTF8/UTF16/UTF32 output and little/big endian machines
#
# UTF-8
BOM_UTF8 = '\xef\xbb\xbf'
# UTF-16, little endian
BOM_LE = BOM_UTF16_LE = '\xff\xfe'
# UTF-16, big endian
BOM_BE = BOM_UTF16_BE = '\xfe\xff'
# UTF-32, little endian
BOM_UTF32_LE = '\xff\xfe\x00\x00'
# UTF-32, big endian
BOM_UTF32_BE = '\x00\x00\xfe\xff'
if sys.byteorder == 'little':
# UTF-16, native endianness
BOM = BOM_UTF16 = BOM_UTF16_LE
# UTF-32, native endianness
BOM_UTF32 = BOM_UTF32_LE
else:
# UTF-16, native endianness
BOM = BOM_UTF16 = BOM_UTF16_BE
# UTF-32, native endianness
BOM_UTF32 = BOM_UTF32_BE
# Old broken names (don't use in new code)
BOM32_LE = BOM_UTF16_LE
BOM32_BE = BOM_UTF16_BE
BOM64_LE = BOM_UTF32_LE
BOM64_BE = BOM_UTF32_BE
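One common use of these constants is sniffing the encoding of a byte string. This is a minimal sketch, assuming Python 3's `codecs` module; `sniff_bom` and the `_BOMS` table are hypothetical helpers. The UTF-32 marks are tested before the UTF-16 ones because BOM_UTF32_LE begins with the same two bytes as BOM_UTF16_LE.

```python
import codecs

# Order matters: longer marks that share a prefix come first.
_BOMS = [
    (codecs.BOM_UTF8, "utf-8"),
    (codecs.BOM_UTF32_LE, "utf-32-le"),
    (codecs.BOM_UTF32_BE, "utf-32-be"),
    (codecs.BOM_UTF16_LE, "utf-16-le"),
    (codecs.BOM_UTF16_BE, "utf-16-be"),
]


def sniff_bom(data):
    """Return (encoding_name, bom_length), or (None, 0) if no BOM is found."""
    for bom, name in _BOMS:
        if data.startswith(bom):
            return name, len(bom)
    return None, 0
```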
### Codec base classes (defining the API)
class CodecInfo(tuple):
"""Codec details when looking up the codec registry"""
# Private API to allow Python to blacklist the known non-Unicode
# codecs in the standard library. A more general mechanism to
# reliably distinguish text encodings from other codecs will hopefully
# be defined for Python 3.5
#
# See http://bugs.python.org/issue19619
_is_text_encoding = True # Assume codecs are text encodings by default
def __new__(cls, encode, decode, streamreader=None, streamwriter=None,
incrementalencoder=None, incrementaldecoder=None, name=None,
_is_text_encoding=None):
self = tuple.__new__(cls, (encode, decode, streamreader, streamwriter))
self.name = name
self.encode = encode
self.decode = decode
self.incrementalencoder = incrementalencoder
self.incrementaldecoder = incrementaldecoder
self.streamwriter = streamwriter
self.streamreader = streamreader
if _is_text_encoding is not None:
self._is_text_encoding = _is_text_encoding
return self
def __repr__(self):
return "<%s.%s object for encoding %s at 0x%x>" % (self.__class__.__module__, self.__class__.__name__, self.name, id(self))
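A CodecInfo is both a 4-tuple and an object with named attributes. The sketch below, assuming Python 3's `codecs.lookup`, shows the two views side by side; the variable names are illustrative.

```python
import codecs

info = codecs.lookup("UTF8")   # aliases resolve to the canonical codec

# Tuple view: (encode, decode, streamreader, streamwriter)
tuple_length = len(info)

# Attribute view: the stateless encoder returns (output, length consumed)
encoded, consumed = info.encode("hi")
canonical_name = info.name
```

The tuple form exists for backwards compatibility with code that unpacks the lookup result, while the incremental encoder and decoder are only reachable as attributes.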
class Codec:
""" Defines the interface for stateless encoders/decoders.
The .encode()/.decode() methods may use different error
handling schemes by providing the errors argument. These
string values are predefined:
'strict' - raise a ValueError error (or a subclass)
'ignore' - ignore the character and continue with the next
'replace' - replace with a suitable replacement character;
Python will use the official U+FFFD REPLACEMENT
CHARACTER for the builtin Unicode codecs on
decoding and '?' on encoding.
'xmlcharrefreplace' - Replace with the appropriate XML
character reference (only for encoding).
'backslashreplace' - Replace with backslashed escape sequences
(only for encoding).
The set of allowed values can be extended via register_error.
"""
def encode(self, input, errors='strict'):
""" Encodes the object input and returns a tuple (output
object, length consumed).
errors defines the error handling to apply. It defaults to
'strict' handling.
The method may not store state in the Codec instance. Use
StreamWriter for codecs which have to keep state in order to
make encoding efficient.
The encoder must be able to handle zero length input and
return an empty object of the output object type in this
situation.
"""
raise NotImplementedError
def decode(self, input, errors='strict'):
""" Decodes the object input and returns a tuple (output
object, length consumed).
input must be an object which provides the bf_getreadbuf
buffer slot. Python strings, buffer objects and memory
mapped files are examples of objects providing this slot.
errors defines the error handling to apply. It defaults to
'strict' handling.
The method may not store state in the Codec instance. Use
StreamReader for codecs which have to keep state in order to
make decoding efficient.
The decoder must be able to handle zero length input and
return an empty object of the output object type in this
situation.
"""
raise NotImplementedError
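The predefined error handling schemes are easiest to compare on one input. This is a small sketch, assuming Python 3 string semantics and the built-in ASCII codec; the `results` dictionary is illustrative only.

```python
text = "caf\u00e9"   # "café": the last character is not encodable as ASCII

results = {}
for scheme in ("ignore", "replace", "xmlcharrefreplace", "backslashreplace"):
    results[scheme] = text.encode("ascii", errors=scheme)

# 'strict' raises UnicodeEncodeError, which is a subclass of ValueError.
try:
    text.encode("ascii", errors="strict")
    strict_raised = False
except ValueError:
    strict_raised = True
```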
class IncrementalEncoder(object):
"""
An IncrementalEncoder encodes an input in multiple steps. The input can be
passed piece by piece to the encode() method. The IncrementalEncoder remembers
the state of the Encoding process between calls to encode().
"""
def __init__(self, errors='strict'):
"""
Creates an IncrementalEncoder instance.
The IncrementalEncoder may use different error handling schemes by
providing the errors keyword argument. See the module docstring
for a list of possible values.
"""
self.errors = errors
self.buffer = ""
def encode(self, input, final=False):
"""
Encodes input and returns the resulting object.
"""
raise NotImplementedError
def reset(self):
"""
Resets the encoder to the initial state.
"""
def getstate(self):
"""
Return the current state of the encoder.
"""
return 0
def setstate(self, state):
"""
Set the current state of the encoder. state must have been
returned by getstate().
"""
class BufferedIncrementalEncoder(IncrementalEncoder):
"""
This subclass of IncrementalEncoder can be used as the baseclass for an
incremental encoder if the encoder must keep some of the output in a
buffer between calls to encode().
"""
def __init__(self, errors='strict'):
IncrementalEncoder.__init__(self, errors)
self.buffer = "" # unencoded input that is kept between calls to encode()
def _buffer_encode(self, input, errors, final):
# Overwrite this method in subclasses: It must encode input
# and return an (output, length consumed) tuple
raise NotImplementedError
def encode(self, input, final=False):
# encode input (taking the buffer into account)
data = self.buffer + input
(result, consumed) = self._buffer_encode(data, self.errors, final)
# keep unencoded input until the next call
self.buffer = data[consumed:]
return result
def reset(self):
IncrementalEncoder.reset(self)
self.buffer = ""
def getstate(self):
return self.buffer or 0
def setstate(self, state):
self.buffer = state or ""
class IncrementalDecoder(object):
"""
An IncrementalDecoder decodes an input in multiple steps. The input can be
passed piece by piece to the decode() method. The IncrementalDecoder
remembers the state of the decoding process between calls to decode().
"""
def __init__(self, errors='strict'):
"""
Creates an IncrementalDecoder instance.
The IncrementalDecoder may use different error handling schemes by
providing the errors keyword argument. See the module docstring
for a list of possible values.
"""
self.errors = errors
def decode(self, input, final=False):
"""
Decodes input and returns the resulting object.
"""
raise NotImplementedError
def reset(self):
"""
Resets the decoder to the initial state.
"""
def getstate(self):
"""
Return the current state of the decoder.
This must be a (buffered_input, additional_state_info) tuple.
buffered_input must be a bytes object containing bytes that
were passed to decode() that have not yet been converted.
additional_state_info must be a non-negative integer
representing the state of the decoder WITHOUT yet having
processed the contents of buffered_input. In the initial state
and after reset(), getstate() must return (b"", 0).
"""
return (b"", 0)
def setstate(self, state):
"""
Set the current state of the decoder.
state must have been returned by getstate(). The effect of
setstate((b"", 0)) must be equivalent to reset().
"""
class BufferedIncrementalDecoder(IncrementalDecoder):
"""
This subclass of IncrementalDecoder can be used as the baseclass for an
incremental decoder if the decoder must be able to handle incomplete byte
sequences.
"""
def __init__(self, errors='strict'):
IncrementalDecoder.__init__(self, errors)
self.buffer = "" # undecoded input that is kept between calls to decode()
def _buffer_decode(self, input, errors, final):
# Overwrite this method in subclasses: It must decode input
# and return an (output, length consumed) tuple
raise NotImplementedError
def decode(self, input, final=False):
# decode input (taking the buffer into account)
data = self.buffer + input
(result, consumed) = self._buffer_decode(data, self.errors, final)
# keep undecoded input until the next call
self.buffer = data[consumed:]
return result
def reset(self):
IncrementalDecoder.reset(self)
self.buffer = ""
def getstate(self):
# additional state info is always 0
return (self.buffer, 0)
def setstate(self, state):
# ignore additional state info
self.buffer = state[0]
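The buffering described above is what lets a decoder accept input that is split in the middle of a multi-byte sequence. A minimal sketch, assuming Python 3's `codecs.getincrementaldecoder`; `decode_chunks` is a hypothetical helper.

```python
import codecs


def decode_chunks(chunks, encoding="utf-8"):
    """Decode an iterable of byte chunks that may split characters."""
    decoder = codecs.getincrementaldecoder(encoding)(errors="strict")
    parts = [decoder.decode(chunk) for chunk in chunks]
    # final=True makes the decoder report any incomplete trailing bytes.
    parts.append(decoder.decode(b"", final=True))
    return "".join(parts)


# U+20AC (euro sign) is three bytes in UTF-8; feed them one at a time.
euro = decode_chunks([b"\xe2", b"\x82", b"\xac"])
```

Each call returns only the characters that are complete so far, and the undecoded remainder is kept for the next call.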
#
# The StreamWriter and StreamReader classes provide generic working
# interfaces which can be used to implement new encoding submodules
# very easily. See encodings/utf_8.py for an example on how this is
# done.
#
class StreamWriter(Codec):
def __init__(self, stream, errors='strict'):
""" Creates a StreamWriter instance.
stream must be a file-like object open for writing
(binary) data.
The StreamWriter may use different error handling
schemes by providing the errors keyword argument. These
parameters are predefined:
'strict' - raise a ValueError (or a subclass)
'ignore' - ignore the character and continue with the next
'replace'- replace with a suitable replacement character
'xmlcharrefreplace' - Replace with the appropriate XML
character reference.
'backslashreplace' - Replace with backslashed escape
sequences (only for encoding).
The set of allowed parameter values can be extended via
register_error.
"""
self.stream = stream
self.errors = errors
def write(self, object):
""" Writes the object's contents encoded to self.stream.
"""
data, consumed = self.encode(object, self.errors)
self.stream.write(data)
def writelines(self, list):
""" Writes the concatenated list of strings to the stream
using .write().
"""
self.write(''.join(list))
def reset(self):
""" Flushes and resets the codec buffers used for keeping state.
Calling this method should ensure that the data on the
output is put into a clean state, that allows appending
of new fresh data without having to rescan the whole
stream to recover state.
"""
pass
def seek(self, offset, whence=0):
self.stream.seek(offset, whence)
if whence == 0 and offset == 0:
self.reset()
def __getattr__(self, name,
getattr=getattr):
""" Inherit all other methods from the underlying stream.
"""
return getattr(self.stream, name)
def __enter__(self):
return self
def __exit__(self, type, value, tb):
self.stream.close()
###
class StreamReader(Codec):
def __init__(self, stream, errors='strict'):
""" Creates a StreamReader instance.
stream must be a file-like object open for reading
(binary) data.
The StreamReader may use different error handling
schemes by providing the errors keyword argument. These
parameters are predefined:
'strict' - raise a ValueError (or a subclass)
'ignore' - ignore the character and continue with the next
'replace'- replace with a suitable replacement character;
The set of allowed parameter values can be extended via
register_error.
"""
self.stream = stream
self.errors = errors
self.bytebuffer = ""
# For str->str decoding this will stay a str
# For str->unicode decoding the first read will promote it to unicode
self.charbuffer = ""
self.linebuffer = None
def decode(self, input, errors='strict'):
raise NotImplementedError
def read(self, size=-1, chars=-1, firstline=False):
""" Decodes data from the stream self.stream and returns the
resulting object.
chars indicates the number of characters to read from the
stream. read() will never return more than chars
characters, but it might return less, if there are not enough
characters available.
size indicates the approximate maximum number of bytes to
read from the stream for decoding purposes. The decoder
can modify this setting as appropriate. The default value
-1 indicates to read and decode as much as possible. size
is intended to prevent having to decode huge files in one
step.
If firstline is true, and a UnicodeDecodeError happens
after the first line terminator in the input only the first line
will be returned, the rest of the input will be kept until the
next call to read().
The method should use a greedy read strategy meaning that
it should read as much data as is allowed within the
definition of the encoding and the given size, e.g. if
optional encoding endings or state markers are available
on the stream, these should be read too.
"""
# If we have lines cached, first merge them back into characters
if self.linebuffer:
self.charbuffer = "".join(self.linebuffer)
self.linebuffer = None
if chars < 0:
# For compatibility with other read() methods that take a
# single argument
chars = size
# read until we get the required number of characters (if available)
while True:
# can the request be satisfied from the character buffer?
if chars >= 0:
if len(self.charbuffer) >= chars:
break
# we need more data
if size < 0:
newdata = self.stream.read()
else:
newdata = self.stream.read(size)
# decode bytes (those remaining from the last call included)
data = self.bytebuffer + newdata
try:
newchars, decodedbytes = self.decode(data, self.errors)
except UnicodeDecodeError, exc:
if firstline:
newchars, decodedbytes = self.decode(data[:exc.start], self.errors)
lines = newchars.splitlines(True)
if len(lines)<=1:
raise
else:
raise
# keep undecoded bytes until the next call
self.bytebuffer = data[decodedbytes:]
# put new characters in the character buffer
self.charbuffer += newchars
# there was no data available
if not newdata:
break
if chars < 0:
# Return everything we've got
result = self.charbuffer
self.charbuffer = ""
else:
# Return the first chars characters
result = self.charbuffer[:chars]
self.charbuffer = self.charbuffer[chars:]
return result
def readline(self, size=None, keepends=True):
""" Read one line from the input stream and return the
decoded data.
size, if given, is passed as size argument to the
read() method.
"""
# If we have lines cached from an earlier read, return
# them unconditionally
if self.linebuffer:
line = self.linebuffer[0]
del self.linebuffer[0]
if len(self.linebuffer) == 1:
# revert to charbuffer mode; we might need more data
# next time
self.charbuffer = self.linebuffer[0]
self.linebuffer = None
if not keepends:
line = line.splitlines(False)[0]
return line
readsize = size or 72
line = ""
# If size is given, we call read() only once
while True:
data = self.read(readsize, firstline=True)
if data:
# If we're at a "\r" read one extra character (which might
# be a "\n") to get a proper line ending. If the stream is
# temporarily exhausted we return the wrong line ending.
if data.endswith("\r"):
data += self.read(size=1, chars=1)
line += data
lines = line.splitlines(True)
if lines:
if len(lines) > 1:
# More than one line result; the first line is a full line
# to return
line = lines[0]
del lines[0]
if len(lines) > 1:
# cache the remaining lines
lines[-1] += self.charbuffer
self.linebuffer = lines
self.charbuffer = None
else:
# only one remaining line, put it back into charbuffer
self.charbuffer = lines[0] + self.charbuffer
if not keepends:
line = line.splitlines(False)[0]
break
line0withend = lines[0]
line0withoutend = lines[0].splitlines(False)[0]
if line0withend != line0withoutend: # We really have a line end
# Put the rest back together and keep it until the next call
self.charbuffer = "".join(lines[1:]) + self.charbuffer
if keepends:
line = line0withend
else:
line = line0withoutend
break
# we didn't get anything or this was our only try
if not data or size is not None:
if line and not keepends:
line = line.splitlines(False)[0]
break
if readsize<8000:
readsize *= 2
return line
def readlines(self, sizehint=None, keepends=True):
""" Read all lines available on the input stream
and return them as list of lines.
Line breaks are implemented using the codec's decoder
method and are included in the list entries.
sizehint, if given, is ignored since there is no efficient
way of finding the true end-of-line.
"""
data = self.read()
return data.splitlines(keepends)
def reset(self):
""" Resets the codec buffers used for keeping state.
Note that no stream repositioning should take place.
This method is primarily intended to be able to recover
from decoding errors.
"""
self.bytebuffer = ""
self.charbuffer = u""
self.linebuffer = None
def seek(self, offset, whence=0):
""" Set the input stream's current position.
Resets the codec buffers used for keeping state.
"""
self.stream.seek(offset, whence)
self.reset()
def next(self):
""" Return the next decoded line from the input stream."""
line = self.readline()
if line:
return line
raise StopIteration
def __iter__(self):
return self
def __getattr__(self, name,
getattr=getattr):
""" Inherit all other methods from the underlying stream.
"""
return getattr(self.stream, name)
def __enter__(self):
return self
def __exit__(self, type, value, tb):
self.stream.close()
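A StreamReader is normally obtained through a factory instead of being instantiated directly. The sketch below assumes Python 3's `codecs.getreader` and an in-memory byte stream; the sample data and variable names are illustrative.

```python
import codecs
import io

raw = io.BytesIO("na\u00efve\nsecond\n".encode("utf-8"))
reader = codecs.getreader("utf-8")(raw)

first = reader.readline()   # decoded text, line ending kept by default
rest = list(reader)         # iteration calls readline() until it returns ""
```

Because the reader decodes in chunks and keeps leftovers in its own buffers, mixing calls on `reader` with direct reads from `raw` would skip data.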
###
class StreamReaderWriter:
""" StreamReaderWriter instances allow wrapping streams which
work in both read and write modes.
The design is such that one can use the factory functions
returned by the codecs.lookup() function to construct the
instance.
"""
# Optional attributes set by the file wrappers below
encoding = 'unknown'
def __init__(self, stream, Reader, Writer, errors='strict'):
""" Creates a StreamReaderWriter instance.
stream must be a Stream-like object.
Reader, Writer must be factory functions or classes
providing the StreamReader, StreamWriter interface resp.
Error handling is done in the same way as defined for the
StreamWriter/Readers.
"""
self.stream = stream
self.reader = Reader(stream, errors)
self.writer = Writer(stream, errors)
self.errors = errors
def read(self, size=-1):
return self.reader.read(size)
def readline(self, size=None):
return self.reader.readline(size)
def readlines(self, sizehint=None):
return self.reader.readlines(sizehint)
def next(self):
""" Return the next decoded line from the input stream."""
return self.reader.next()
def __iter__(self):
return self
def write(self, data):
return self.writer.write(data)
def writelines(self, list):
return self.writer.writelines(list)
def reset(self):
self.reader.reset()
self.writer.reset()
def seek(self, offset, whence=0):
self.stream.seek(offset, whence)
self.reader.reset()
if whence == 0 and offset == 0:
self.writer.reset()
def __getattr__(self, name,
getattr=getattr):
""" Inherit all other methods from the underlying stream.
"""
return getattr(self.stream, name)
# these are needed to make "with codecs.open(...)" work properly
def __enter__(self):
return self
def __exit__(self, type, value, tb):
self.stream.close()
###
class StreamRecoder:
""" StreamRecoder instances provide a frontend - backend
view of encoding data.
They use the complete set of APIs returned by the
codecs.lookup() function to implement their task.
Data written to the stream is first decoded into an
intermediate format (which is dependent on the given codec
combination) and then written to the stream using an instance
of the provided Writer class.
In the other direction, data is read from the stream using a
Reader instance and then returned to the caller as encoded data.
"""
# Optional attributes set by the file wrappers below
data_encoding = 'unknown'
file_encoding = 'unknown'
def __init__(self, stream, encode, decode, Reader, Writer,
errors='strict'):
""" Creates a StreamRecoder instance which implements a two-way
conversion: encode and decode work on the frontend (the
input to .read() and output of .write()) while
Reader and Writer work on the backend (reading and
writing to the stream).
You can use these objects to do transparent direct
recodings from e.g. latin-1 to utf-8 and back.
stream must be a file-like object.
encode, decode must adhere to the Codec interface, Reader,
Writer must be factory functions or classes providing the
StreamReader, StreamWriter interface resp.
encode and decode are needed for the frontend translation,
Reader and Writer for the backend translation. Unicode is
used as intermediate encoding.
Error handling is done in the same way as defined for the
StreamWriter/Readers.
"""
self.stream = stream
self.encode = encode
self.decode = decode
self.reader = Reader(stream, errors)
self.writer = Writer(stream, errors)
self.errors = errors
def read(self, size=-1):
data = self.reader.read(size)
data, bytesencoded = self.encode(data, self.errors)
return data
def readline(self, size=None):
if size is None:
data = self.reader.readline()
else:
data = self.reader.readline(size)
data, bytesencoded = self.encode(data, self.errors)
return data
def readlines(self, sizehint=None):
data = self.reader.read()
data, bytesencoded = self.encode(data, self.errors)
return data.splitlines(1)
def next(self):
""" Return the next decoded line from the input stream."""
data = self.reader.next()
data, bytesencoded = self.encode(data, self.errors)
return data
def __iter__(self):
return self
def write(self, data):
data, bytesdecoded = self.decode(data, self.errors)
return self.writer.write(data)
def writelines(self, list):
data = ''.join(list)
data, bytesdecoded = self.decode(data, self.errors)
return self.writer.write(data)
def reset(self):
self.reader.reset()
self.writer.reset()
def __getattr__(self, name,
getattr=getattr):
""" Inherit all other methods from the underlying stream.
"""
return getattr(self.stream, name)
def __enter__(self):
return self
def __exit__(self, type, value, tb):
self.stream.close()
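The two-way conversion of a StreamRecoder is usually set up through codecs.EncodedFile. This is a sketch, assuming Python 3's `codecs.EncodedFile` and in-memory streams; the variable names are illustrative. The caller works in UTF-8 while the underlying stream holds Latin-1.

```python
import codecs
import io

backend = io.BytesIO()
recoder = codecs.EncodedFile(backend, data_encoding="utf-8",
                             file_encoding="latin-1")

# Frontend bytes are UTF-8; they are decoded, then re-encoded as Latin-1.
recoder.write("caf\u00e9".encode("utf-8"))
stored = backend.getvalue()

# Reading reverses the path: Latin-1 on the stream, UTF-8 to the caller.
read_side = codecs.EncodedFile(io.BytesIO(stored), data_encoding="utf-8",
                               file_encoding="latin-1")
round_trip = read_side.read()
```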
### Shortcuts
def open(filename, mode='rb', encoding=None, errors='strict', buffering=1):
""" Open an encoded file using the given mode and return
a wrapped version providing transparent encoding/decoding.
Note: The wrapped version will only accept the object format
defined by the codecs, i.e. Unicode objects for most builtin
codecs. Output is also codec dependent and will usually be
Unicode as well.
Files are always opened in binary mode, even if no binary mode
was specified. This is done to avoid data loss due to encodings
using 8-bit values. The default file mode is 'rb' meaning to
open the file in binary read mode.
encoding specifies the encoding which is to be used for the
file.
errors may be given to define the error handling. It defaults
to 'strict' which causes ValueErrors to be raised in case an
encoding error occurs.
buffering has the same meaning as for the builtin open() API.
It defaults to line buffered.
The returned wrapped file object provides an extra attribute
.encoding which allows querying the used encoding. This
attribute is only available if an encoding was specified as
parameter.
"""
if encoding is not None:
if 'U' in mode:
# No automatic conversion of '\n' is done on reading and writing
mode = mode.strip().replace('U', '')
if mode[:1] not in set('rwa'):
mode = 'r' + mode
if 'b' not in mode:
# Force opening of the file in binary mode
mode = mode + 'b'
file = __builtin__.open(filename, mode, buffering)
if encoding is None:
return file
info = lookup(encoding)
srw = StreamReaderWriter(file, info.streamreader, info.streamwriter, errors)
# Add attributes to simplify introspection
srw.encoding = encoding
return srw
def EncodedFile(file, data_encoding, file_encoding=None, errors='strict'):
""" Return a wrapped version of file which provides transparent
encoding translation.
Strings written to the wrapped file are interpreted according
to the given data_encoding and then written to the original
file as string using file_encoding. The intermediate encoding
will usually be Unicode but depends on the specified codecs.
Strings are read from the file using file_encoding and then
passed back to the caller as string using data_encoding.
If file_encoding is not given, it defaults to data_encoding.
errors may be given to define the error handling. It defaults
to 'strict' which causes ValueErrors to be raised in case an
encoding error occurs.
The returned wrapped file object provides two extra attributes
.data_encoding and .file_encoding which reflect the given
parameters of the same name. The attributes can be used for
introspection by Python programs.
"""
if file_encoding is None:
file_encoding = data_encoding
data_info = lookup(data_encoding)
file_info = lookup(file_encoding)
sr = StreamRecoder(file, data_info.encode, data_info.decode,
file_info.streamreader, file_info.streamwriter, errors)
# Add attributes to simplify introspection
sr.data_encoding = data_encoding
sr.file_encoding = file_encoding
return sr
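The translation performed by EncodedFile can be sketched as follows. This is a minimal illustration, assuming an in-memory `io.BytesIO` object stands in for a real file; the sample bytes are made up for the example and are not part of the module.

```python
import codecs
import io

# Writing: bytes handed to the wrapper are interpreted as Latin-1
# (data_encoding) and stored in the underlying file as UTF-8 (file_encoding).
raw = io.BytesIO()
wrapped = codecs.EncodedFile(raw, 'latin-1', 'utf-8')
wrapped.write(b'caf\xe9')            # "cafe" with e-acute, Latin-1 encoded
stored = raw.getvalue()              # same text, now UTF-8 encoded

# Reading: UTF-8 bytes in the file come back to the caller as Latin-1.
source = io.BytesIO(b'caf\xc3\xa9')
read_back = codecs.EncodedFile(source, 'latin-1', 'utf-8').read()
```

The `.data_encoding` and `.file_encoding` attributes on `wrapped` reflect the two arguments, as the docstring describes.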
### Helpers for codec lookup
def getencoder(encoding):
""" Look up the codec for the given encoding and return
its encoder function.
Raises a LookupError in case the encoding cannot be found.
"""
return lookup(encoding).encode
def getdecoder(encoding):
""" Look up the codec for the given encoding and return
its decoder function.
Raises a LookupError in case the encoding cannot be found.
"""
return lookup(encoding).decode
def getincrementalencoder(encoding):
""" Look up the codec for the given encoding and return
its IncrementalEncoder class or factory function.
Raises a LookupError in case the encoding cannot be found
or the codec doesn't provide an incremental encoder.
"""
encoder = lookup(encoding).incrementalencoder
if encoder is None:
raise LookupError(encoding)
return encoder
def getincrementaldecoder(encoding):
""" Look up the codec for the given encoding and return
its IncrementalDecoder class or factory function.
Raises a LookupError in case the encoding cannot be found
or the codec doesn't provide an incremental decoder.
"""
decoder = lookup(encoding).incrementaldecoder
if decoder is None:
raise LookupError(encoding)
return decoder
def getreader(encoding):
""" Look up the codec for the given encoding and return
its StreamReader class or factory function.
Raises a LookupError in case the encoding cannot be found.
"""
return lookup(encoding).streamreader
def getwriter(encoding):
""" Look up the codec for the given encoding and return
its StreamWriter class or factory function.
Raises a LookupError in case the encoding cannot be found.
"""
return lookup(encoding).streamwriter
def iterencode(iterator, encoding, errors='strict', **kwargs):
"""
Encoding iterator.
Encodes the input strings from the iterator using an IncrementalEncoder.
errors and kwargs are passed through to the IncrementalEncoder
constructor.
"""
encoder = getincrementalencoder(encoding)(errors, **kwargs)
for input in iterator:
output = encoder.encode(input)
if output:
yield output
output = encoder.encode("", True)
if output:
yield output
def iterdecode(iterator, encoding, errors='strict', **kwargs):
"""
Decoding iterator.
Decodes the input strings from the iterator using an IncrementalDecoder.
errors and kwargs are passed through to the IncrementalDecoder
constructor.
"""
decoder = getincrementaldecoder(encoding)(errors, **kwargs)
for input in iterator:
output = decoder.decode(input)
if output:
yield output
output = decoder.decode("", True)
if output:
yield output
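A short sketch of how iterencode and iterdecode behave on chunked input. It assumes Python 3 style text and bytes literals; the sample chunks are invented for illustration.

```python
import codecs

# Encode text chunks lazily; chunks that produce no output are skipped.
encoded = list(codecs.iterencode([u'caf', u'\u00e9', u''], 'utf-8'))

# Decode byte chunks where a two-byte UTF-8 sequence is split across chunks.
# The incremental decoder holds the dangling first byte until the rest arrives,
# so nothing is yielded for the middle chunk.
decoded = list(codecs.iterdecode([b'caf', b'\xc3', b'\xa9'], 'utf-8'))
```

Because both helpers are generators, nothing is encoded or decoded until the result is iterated.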
### Helpers for charmap-based codecs
def make_identity_dict(rng):
""" make_identity_dict(rng) -> dict
Return a dictionary where elements of the rng sequence are
mapped to themselves.
"""
res = {}
for i in rng:
res[i]=i
return res
def make_encoding_map(decoding_map):
""" Creates an encoding map from a decoding map.
If a target mapping in the decoding map occurs multiple
times, then that target is mapped to None (undefined mapping),
causing an exception when encountered by the charmap codec
during translation.
One example where this happens is cp875.py which decodes
multiple characters to \\u001a.
"""
m = {}
for k,v in decoding_map.items():
if not v in m:
m[v] = k
else:
m[v] = None
return m
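The duplicate-target rule in make_encoding_map can be seen with a small, hypothetical decoding map (the byte values below are chosen only for the example and do not come from a real code page):

```python
import codecs

# Hypothetical decoding map: 0x41 and 0x61 both decode to u'A'.
decoding_map = {0x41: u'A', 0x61: u'A', 0x42: u'B'}
encoding_map = codecs.make_encoding_map(decoding_map)
# u'A' is ambiguous, so it maps to None; u'B' inverts cleanly to 0x42.

# make_identity_dict maps every element of the sequence to itself.
identity = codecs.make_identity_dict(range(3))
```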
### error handlers
try:
strict_errors = lookup_error("strict")
ignore_errors = lookup_error("ignore")
replace_errors = lookup_error("replace")
xmlcharrefreplace_errors = lookup_error("xmlcharrefreplace")
backslashreplace_errors = lookup_error("backslashreplace")
except LookupError:
# In --disable-unicode builds, these error handlers are missing
strict_errors = None
ignore_errors = None
replace_errors = None
xmlcharrefreplace_errors = None
backslashreplace_errors = None
# Tell modulefinder that using codecs probably needs the encodings
# package
_false = 0
if _false:
import encodings
### Tests
if __name__ == '__main__':
# Make stdout translate Latin-1 output into UTF-8 output
sys.stdout = EncodedFile(sys.stdout, 'latin-1', 'utf-8')
# Have stdin translate Latin-1 input into UTF-8 input
sys.stdin = EncodedFile(sys.stdin, 'utf-8', 'latin-1')
r"""Utilities to compile possibly incomplete Python source code.
This module provides two interfaces, broadly similar to the builtin
function compile(), which take program text, a filename and a 'mode'
and:
- Return code object if the command is complete and valid
- Return None if the command is incomplete
- Raise SyntaxError, ValueError or OverflowError if the command is a
syntax error (OverflowError and ValueError can be produced by
malformed literals).
Approach:
First, check if the source consists entirely of blank lines and
comments; if so, replace it with 'pass', because the built-in
parser doesn't always do the right thing for these.
Compile three times: as is, with \n, and with \n\n appended. If it
compiles as is, it's complete. If it compiles with one \n appended,
we expect more. If it doesn't compile either way, we compare the
error we get when compiling with \n or \n\n appended. If the errors
are the same, the code is broken. But if the errors are different, we
expect more. Not intuitive; not even guaranteed to hold in future
releases; but this matches the compiler's behavior from Python 1.4
through 2.2, at least.
Caveat:
It is possible (but not likely) that the parser stops parsing with a
successful outcome before reaching the end of the source; in this
case, trailing symbols may be ignored instead of causing an error.
For example, a backslash followed by two newlines may be followed by
arbitrary garbage. This will be fixed once the API for the parser is
better.
The two interfaces are:
compile_command(source, filename, symbol):
Compiles a single command in the manner described above.
CommandCompiler():
Instances of this class have __call__ methods identical in
signature to compile_command; the difference is that if the
instance compiles program text containing a __future__ statement,
the instance 'remembers' and compiles all subsequent program texts
with the statement in force.
The module also provides another class:
Compile():
Instances of this class act like the built-in function compile,
but with 'memory' in the sense described above.
"""
import __future__
_features = [getattr(__future__, fname)
for fname in __future__.all_feature_names]
__all__ = ["compile_command", "Compile", "CommandCompiler"]
PyCF_DONT_IMPLY_DEDENT = 0x200 # Matches pythonrun.h
def _maybe_compile(compiler, source, filename, symbol):
# Check for source consisting of only blank lines and comments
for line in source.split("\n"):
line = line.strip()
if line and line[0] != '#':
break # Leave it alone
else:
if symbol != "eval":
source = "pass" # Replace it with a 'pass' statement
err = err1 = err2 = None
code = code1 = code2 = None
try:
code = compiler(source, filename, symbol)
except SyntaxError, err:
pass
try:
code1 = compiler(source + "\n", filename, symbol)
except SyntaxError, err1:
pass
try:
code2 = compiler(source + "\n\n", filename, symbol)
except SyntaxError, err2:
pass
if code:
return code
if not code1 and repr(err1) == repr(err2):
raise SyntaxError, err1
def _compile(source, filename, symbol):
return compile(source, filename, symbol, PyCF_DONT_IMPLY_DEDENT)
def compile_command(source, filename="<input>", symbol="single"):
r"""Compile a command and determine whether it is incomplete.
Arguments:
source -- the source string; may contain \n characters
filename -- optional filename from which source was read; default
"<input>"
symbol -- optional grammar start symbol; "single" (default) or "eval"
Return value / exceptions raised:
- Return a code object if the command is complete and valid
- Return None if the command is incomplete
- Raise SyntaxError, ValueError or OverflowError if the command is a
syntax error (OverflowError and ValueError can be produced by
malformed literals).
"""
return _maybe_compile(_compile, source, filename, symbol)
class Compile:
"""Instances of this class behave much like the built-in compile
function, but if one is used to compile text containing a future
statement, it "remembers" and compiles all subsequent program texts
with the statement in force."""
def __init__(self):
self.flags = PyCF_DONT_IMPLY_DEDENT
def __call__(self, source, filename, symbol):
codeob = compile(source, filename, symbol, self.flags, 1)
for feature in _features:
if codeob.co_flags & feature.compiler_flag:
self.flags |= feature.compiler_flag
return codeob
class CommandCompiler:
"""Instances of this class have __call__ methods identical in
signature to compile_command; the difference is that if the
instance compiles program text containing a __future__ statement,
the instance 'remembers' and compiles all subsequent program texts
with the statement in force."""
def __init__(self,):
self.compiler = Compile()
def __call__(self, source, filename="<input>", symbol="single"):
r"""Compile a command and determine whether it is incomplete.
Arguments:
source -- the source string; may contain \n characters
filename -- optional filename from which source was read;
default "<input>"
symbol -- optional grammar start symbol; "single" (default) or
"eval"
Return value / exceptions raised:
- Return a code object if the command is complete and valid
- Return None if the command is incomplete
- Raise SyntaxError, ValueError or OverflowError if the command is a
syntax error (OverflowError and ValueError can be produced by
malformed literals).
"""
return _maybe_compile(self.compiler, source, filename, symbol)
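The three outcomes documented for compile_command (code object, None, exception) can be sketched as below. The `classify` helper is a hypothetical name introduced for this illustration; it is not part of the module.

```python
from codeop import compile_command

def classify(source):
    """Return 'complete', 'incomplete' or 'invalid' for one chunk of input."""
    try:
        code = compile_command(source)   # symbol defaults to "single"
    except (SyntaxError, ValueError, OverflowError):
        return 'invalid'
    # None signals that more input is expected, as in an interactive prompt.
    return 'incomplete' if code is None else 'complete'

results = [classify(s) for s in ('x = 1', 'if x:', 'x = = 1')]
```

An interactive loop would typically keep appending lines to the buffer while the result is 'incomplete'.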
'''This module implements specialized container datatypes providing
alternatives to Python's general purpose built-in containers, dict,
list, set, and tuple.
* namedtuple factory function for creating tuple subclasses with named fields
* deque list-like container with fast appends and pops on either end
* Counter dict subclass for counting hashable objects
* OrderedDict dict subclass that remembers the order entries were added
* defaultdict dict subclass that calls a factory function to supply missing values
'''
__all__ = ['Counter', 'deque', 'defaultdict', 'namedtuple', 'OrderedDict']
# For bootstrapping reasons, the collection ABCs are defined in _abcoll.py.
# They should however be considered an integral part of collections.py.
from _abcoll import *
import _abcoll
__all__ += _abcoll.__all__
from _collections import deque, defaultdict
from operator import itemgetter as _itemgetter, eq as _eq
from keyword import iskeyword as _iskeyword
import sys as _sys
import heapq as _heapq
from itertools import repeat as _repeat, chain as _chain, starmap as _starmap
from itertools import imap as _imap
try:
from thread import get_ident as _get_ident
except ImportError:
from dummy_thread import get_ident as _get_ident
################################################################################
### OrderedDict
################################################################################
class OrderedDict(dict):
'Dictionary that remembers insertion order'
# An inherited dict maps keys to values.
# The inherited dict provides __getitem__, __len__, __contains__, and get.
# The remaining methods are order-aware.
# Big-O running times for all methods are the same as regular dictionaries.
# The internal self.__map dict maps keys to links in a doubly linked list.
# The circular doubly linked list starts and ends with a sentinel element.
# The sentinel element never gets deleted (this simplifies the algorithm).
# Each link is stored as a list of length three: [PREV, NEXT, KEY].
def __init__(*args, **kwds):
'''Initialize an ordered dictionary. The signature is the same as
regular dictionaries, but keyword arguments are not recommended because
their insertion order is arbitrary.
'''
if not args:
raise TypeError("descriptor '__init__' of 'OrderedDict' object "
"needs an argument")
self = args[0]
args = args[1:]
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
try:
self.__root
except AttributeError:
self.__root = root = [] # sentinel node
root[:] = [root, root, None]
self.__map = {}
self.__update(*args, **kwds)
def __setitem__(self, key, value, dict_setitem=dict.__setitem__):
'od.__setitem__(i, y) <==> od[i]=y'
# Setting a new item creates a new link at the end of the linked list,
# and the inherited dictionary is updated with the new key/value pair.
if key not in self:
root = self.__root
last = root[0]
last[1] = root[0] = self.__map[key] = [last, root, key]
return dict_setitem(self, key, value)
def __delitem__(self, key, dict_delitem=dict.__delitem__):
'od.__delitem__(y) <==> del od[y]'
# Deleting an existing item uses self.__map to find the link which gets
# removed by updating the links in the predecessor and successor nodes.
dict_delitem(self, key)
link_prev, link_next, _ = self.__map.pop(key)
link_prev[1] = link_next # update link_prev[NEXT]
link_next[0] = link_prev # update link_next[PREV]
def __iter__(self):
'od.__iter__() <==> iter(od)'
# Traverse the linked list in order.
root = self.__root
curr = root[1] # start at the first node
while curr is not root:
yield curr[2] # yield the curr[KEY]
curr = curr[1] # move to next node
def __reversed__(self):
'od.__reversed__() <==> reversed(od)'
# Traverse the linked list in reverse order.
root = self.__root
curr = root[0] # start at the last node
while curr is not root:
yield curr[2] # yield the curr[KEY]
curr = curr[0] # move to previous node
def clear(self):
'od.clear() -> None. Remove all items from od.'
root = self.__root
root[:] = [root, root, None]
self.__map.clear()
dict.clear(self)
# -- the following methods do not depend on the internal structure --
def keys(self):
'od.keys() -> list of keys in od'
return list(self)
def values(self):
'od.values() -> list of values in od'
return [self[key] for key in self]
def items(self):
'od.items() -> list of (key, value) pairs in od'
return [(key, self[key]) for key in self]
def iterkeys(self):
'od.iterkeys() -> an iterator over the keys in od'
return iter(self)
def itervalues(self):
'od.itervalues() -> an iterator over the values in od'
for k in self:
yield self[k]
def iteritems(self):
'od.iteritems() -> an iterator over the (key, value) pairs in od'
for k in self:
yield (k, self[k])
update = MutableMapping.update
__update = update # let subclasses override update without breaking __init__
__marker = object()
def pop(self, key, default=__marker):
'''od.pop(k[,d]) -> v, remove specified key and return the corresponding
value. If key is not found, d is returned if given, otherwise KeyError
is raised.
'''
if key in self:
result = self[key]
del self[key]
return result
if default is self.__marker:
raise KeyError(key)
return default
def setdefault(self, key, default=None):
'od.setdefault(k[,d]) -> od.get(k,d), also set od[k]=d if k not in od'
if key in self:
return self[key]
self[key] = default
return default
def popitem(self, last=True):
'''od.popitem() -> (k, v), return and remove a (key, value) pair.
Pairs are returned in LIFO order if last is true or FIFO order if false.
'''
if not self:
raise KeyError('dictionary is empty')
key = next(reversed(self) if last else iter(self))
value = self.pop(key)
return key, value
def __repr__(self, _repr_running={}):
'od.__repr__() <==> repr(od)'
call_key = id(self), _get_ident()
if call_key in _repr_running:
return '...'
_repr_running[call_key] = 1
try:
if not self:
return '%s()' % (self.__class__.__name__,)
return '%s(%r)' % (self.__class__.__name__, self.items())
finally:
del _repr_running[call_key]
def __reduce__(self):
'Return state information for pickling'
items = [[k, self[k]] for k in self]
inst_dict = vars(self).copy()
for k in vars(OrderedDict()):
inst_dict.pop(k, None)
if inst_dict:
return (self.__class__, (items,), inst_dict)
return self.__class__, (items,)
def copy(self):
'od.copy() -> a shallow copy of od'
return self.__class__(self)
@classmethod
def fromkeys(cls, iterable, value=None):
'''OD.fromkeys(S[, v]) -> New ordered dictionary with keys from S.
If not specified, the value defaults to None.
'''
self = cls()
for key in iterable:
self[key] = value
return self
def __eq__(self, other):
'''od.__eq__(y) <==> od==y. Comparison to another OD is order-sensitive
while comparison to a regular mapping is order-insensitive.
'''
if isinstance(other, OrderedDict):
return dict.__eq__(self, other) and all(_imap(_eq, self, other))
return dict.__eq__(self, other)
def __ne__(self, other):
'od.__ne__(y) <==> od!=y'
return not self == other
# -- the following methods support python 3.x style dictionary views --
def viewkeys(self):
"od.viewkeys() -> a set-like object providing a view on od's keys"
return KeysView(self)
def viewvalues(self):
"od.viewvalues() -> an object providing a view on od's values"
return ValuesView(self)
def viewitems(self):
"od.viewitems() -> a set-like object providing a view on od's items"
return ItemsView(self)
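The bookkeeping described in the class comments (a circular doubly linked list with a sentinel, each link stored as [PREV, NEXT, KEY]) can be isolated in a small sketch. `KeyOrder` is a hypothetical helper written for this illustration; it mirrors the link updates above but is not how the class is exposed. The second half shows the documented comparison and popitem behavior through the public API.

```python
from collections import OrderedDict

class KeyOrder(object):
    """Hypothetical helper: tracks key order with a circular linked list."""
    def __init__(self):
        self.root = root = []            # sentinel node
        root[:] = [root, root, None]     # [PREV, NEXT, KEY], pointing at itself
        self.links = {}                  # key -> link

    def append(self, key):
        root = self.root
        last = root[0]
        # New link sits between the old last node and the sentinel.
        last[1] = root[0] = self.links[key] = [last, root, key]

    def remove(self, key):
        prev, nxt, _ = self.links.pop(key)
        prev[1] = nxt                    # predecessor skips the removed link
        nxt[0] = prev                    # successor points back past it

    def __iter__(self):
        curr = self.root[1]
        while curr is not self.root:
            yield curr[2]
            curr = curr[1]

order = KeyOrder()
for k in 'abc':
    order.append(k)
order.remove('b')
order.append('b')                        # re-inserted keys go to the end
keys = list(order)

od = OrderedDict([('a', 1), ('b', 2)])
same_order = od == OrderedDict([('a', 1), ('b', 2)])
diff_order = od == OrderedDict([('b', 2), ('a', 1)])   # order-sensitive
vs_dict = od == {'b': 2, 'a': 1}                       # order-insensitive
first = od.popitem(last=False)                         # FIFO removal
```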
################################################################################
### namedtuple
################################################################################
_class_template = '''\
class {typename}(tuple):
'{typename}({arg_list})'
__slots__ = ()
_fields = {field_names!r}
def __new__(_cls, {arg_list}):
'Create new instance of {typename}({arg_list})'
return _tuple.__new__(_cls, ({arg_list}))
@classmethod
def _make(cls, iterable, new=tuple.__new__, len=len):
'Make a new {typename} object from a sequence or iterable'
result = new(cls, iterable)
if len(result) != {num_fields:d}:
raise TypeError('Expected {num_fields:d} arguments, got %d' % len(result))
return result
def __repr__(self):
'Return a nicely formatted representation string'
return '{typename}({repr_fmt})' % self
def _asdict(self):
'Return a new OrderedDict which maps field names to their values'
return OrderedDict(zip(self._fields, self))
def _replace(_self, **kwds):
'Return a new {typename} object replacing specified fields with new values'
result = _self._make(map(kwds.pop, {field_names!r}, _self))
if kwds:
raise ValueError('Got unexpected field names: %r' % kwds.keys())
return result
def __getnewargs__(self):
'Return self as a plain tuple. Used by copy and pickle.'
return tuple(self)
__dict__ = _property(_asdict)
def __getstate__(self):
'Exclude the OrderedDict from pickling'
pass
{field_defs}
'''
_repr_template = '{name}=%r'
_field_template = '''\
{name} = _property(_itemgetter({index:d}), doc='Alias for field number {index:d}')
'''
def namedtuple(typename, field_names, verbose=False, rename=False):
"""Returns a new subclass of tuple with named fields.
>>> Point = namedtuple('Point', ['x', 'y'])
>>> Point.__doc__ # docstring for the new class
'Point(x, y)'
>>> p = Point(11, y=22) # instantiate with positional args or keywords
>>> p[0] + p[1] # indexable like a plain tuple
33
>>> x, y = p # unpack like a regular tuple
>>> x, y
(11, 22)
>>> p.x + p.y # fields also accessible by name
33
>>> d = p._asdict() # convert to a dictionary
>>> d['x']
11
>>> Point(**d) # convert from a dictionary
Point(x=11, y=22)
>>> p._replace(x=100) # _replace() is like str.replace() but targets named fields
Point(x=100, y=22)
"""
# Validate the field names. At the user's option, either generate an error
# message or automatically replace the field name with a valid name.
if isinstance(field_names, basestring):
field_names = field_names.replace(',', ' ').split()
field_names = map(str, field_names)
typename = str(typename)
if rename:
seen = set()
for index, name in enumerate(field_names):
if (not all(c.isalnum() or c=='_' for c in name)
or _iskeyword(name)
or not name
or name[0].isdigit()
or name.startswith('_')
or name in seen):
field_names[index] = '_%d' % index
seen.add(name)
for name in [typename] + field_names:
if type(name) != str:
raise TypeError('Type names and field names must be strings')
if not all(c.isalnum() or c=='_' for c in name):
raise ValueError('Type names and field names can only contain '
'alphanumeric characters and underscores: %r' % name)
if _iskeyword(name):
raise ValueError('Type names and field names cannot be a '
'keyword: %r' % name)
if name[0].isdigit():
raise ValueError('Type names and field names cannot start with '
'a number: %r' % name)
seen = set()
for name in field_names:
if name.startswith('_') and not rename:
raise ValueError('Field names cannot start with an underscore: '
'%r' % name)
if name in seen:
raise ValueError('Encountered duplicate field name: %r' % name)
seen.add(name)
# Fill in the class template
class_definition = _class_template.format(
typename = typename,
field_names = tuple(field_names),
num_fields = len(field_names),
arg_list = repr(tuple(field_names)).replace("'", "")[1:-1],
repr_fmt = ', '.join(_repr_template.format(name=name)
for name in field_names),
field_defs = '\n'.join(_field_template.format(index=index, name=name)
for index, name in enumerate(field_names))
)
if verbose:
print class_definition
# Execute the template string in a temporary namespace and support
# tracing utilities by setting a value for frame.f_globals['__name__']
namespace = dict(_itemgetter=_itemgetter, __name__='namedtuple_%s' % typename,
OrderedDict=OrderedDict, _property=property, _tuple=tuple)
try:
exec class_definition in namespace
except SyntaxError as e:
raise SyntaxError(e.message + ':\n' + class_definition)
result = namespace[typename]
# For pickling to work, the __module__ variable needs to be set to the frame
# where the named tuple is created. Bypass this step in environments where
# sys._getframe is not defined (Jython for example) or sys._getframe is not
# defined for arguments greater than 0 (IronPython).
try:
result.__module__ = _sys._getframe(1).f_globals.get('__name__', '__main__')
except (AttributeError, ValueError):
pass
return result
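A brief sketch of the field-name handling performed by namedtuple. The type and field names below are arbitrary examples.

```python
from collections import namedtuple

Point = namedtuple('Point', 'x, y')      # names are split on commas/whitespace
p = Point(11, y=22)
moved = p._replace(x=100)                # returns a new tuple, p is unchanged
as_dict = dict(p._asdict())

# With rename=True, keywords, duplicates and underscore-prefixed names are
# replaced by positional placeholders of the form _<index>.
Row = namedtuple('Row', ['id', 'class', 'id', '_hidden'], rename=True)

# Without rename, the same kind of name is rejected.
try:
    namedtuple('Bad', ['class'])
    rejected = False
except ValueError:
    rejected = True
```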
########################################################################
### Counter
########################################################################
class Counter(dict):
'''Dict subclass for counting hashable items. Sometimes called a bag
or multiset. Elements are stored as dictionary keys and their counts
are stored as dictionary values.
>>> c = Counter('abcdeabcdabcaba') # count elements from a string
>>> c.most_common(3) # three most common elements
[('a', 5), ('b', 4), ('c', 3)]
>>> sorted(c) # list all unique elements
['a', 'b', 'c', 'd', 'e']
>>> ''.join(sorted(c.elements())) # list elements with repetitions
'aaaaabbbbcccdde'
>>> sum(c.values()) # total of all counts
15
>>> c['a'] # count of letter 'a'
5
>>> for elem in 'shazam': # update counts from an iterable
... c[elem] += 1 # by adding 1 to each element's count
>>> c['a'] # now there are seven 'a'
7
>>> del c['b'] # remove all 'b'
>>> c['b'] # now there are zero 'b'
0
>>> d = Counter('simsalabim') # make another counter
>>> c.update(d) # add in the second counter
>>> c['a'] # now there are nine 'a'
9
>>> c.clear() # empty the counter
>>> c
Counter()
Note: If a count is set to zero or reduced to zero, it will remain
in the counter until the entry is deleted or the counter is cleared:
>>> c = Counter('aaabbc')
>>> c['b'] -= 2 # reduce the count of 'b' by two
>>> c.most_common() # 'b' is still in, but its count is zero
[('a', 3), ('c', 1), ('b', 0)]
'''
# References:
# http://en.wikipedia.org/wiki/Multiset
# http://www.gnu.org/software/smalltalk/manual-base/html_node/Bag.html
# http://www.demo2s.com/Tutorial/Cpp/0380__set-multiset/Catalog0380__set-multiset.htm
# http://code.activestate.com/recipes/259174/
# Knuth, TAOCP Vol. II section 4.6.3
def __init__(*args, **kwds):
'''Create a new, empty Counter object. And if given, count elements
from an input iterable. Or, initialize the count from another mapping
of elements to their counts.
>>> c = Counter() # a new, empty counter
>>> c = Counter('gallahad') # a new counter from an iterable
>>> c = Counter({'a': 4, 'b': 2}) # a new counter from a mapping
>>> c = Counter(a=4, b=2) # a new counter from keyword args
'''
if not args:
raise TypeError("descriptor '__init__' of 'Counter' object "
"needs an argument")
self = args[0]
args = args[1:]
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
super(Counter, self).__init__()
self.update(*args, **kwds)
def __missing__(self, key):
'The count of elements not in the Counter is zero.'
# Needed so that self[missing_item] does not raise KeyError
return 0
def most_common(self, n=None):
'''List the n most common elements and their counts from the most
common to the least. If n is None, then list all element counts.
>>> Counter('abcdeabcdabcaba').most_common(3)
[('a', 5), ('b', 4), ('c', 3)]
'''
# Emulate Bag.sortedByCount from Smalltalk
if n is None:
return sorted(self.iteritems(), key=_itemgetter(1), reverse=True)
return _heapq.nlargest(n, self.iteritems(), key=_itemgetter(1))
def elements(self):
'''Iterator over elements repeating each as many times as its count.
>>> c = Counter('ABCABC')
>>> sorted(c.elements())
['A', 'A', 'B', 'B', 'C', 'C']
# Knuth's example for prime factors of 1836: 2**2 * 3**3 * 17**1
>>> prime_factors = Counter({2: 2, 3: 3, 17: 1})
>>> product = 1
>>> for factor in prime_factors.elements(): # loop over factors
... product *= factor # and multiply them
>>> product
1836
Note, if an element's count has been set to zero or is a negative
number, elements() will ignore it.
'''
# Emulate Bag.do from Smalltalk and Multiset.begin from C++.
return _chain.from_iterable(_starmap(_repeat, self.iteritems()))
# Override dict methods where necessary
@classmethod
def fromkeys(cls, iterable, v=None):
# There is no equivalent method for counters because setting v=1
# means that no element can have a count greater than one.
raise NotImplementedError(
'Counter.fromkeys() is undefined. Use Counter(iterable) instead.')
def update(*args, **kwds):
'''Like dict.update() but add counts instead of replacing them.
Source can be an iterable, a dictionary, or another Counter instance.
>>> c = Counter('which')
>>> c.update('witch') # add elements from another iterable
>>> d = Counter('watch')
>>> c.update(d) # add elements from another counter
>>> c['h'] # four 'h' in which, witch, and watch
4
'''
# The regular dict.update() operation makes no sense here because the
# replace behavior results in some of the original untouched counts
# being mixed in with all of the other counts for a mishmash that
# doesn't have a straight-forward interpretation in most counting
# contexts. Instead, we implement straight-addition. Both the inputs
# and outputs are allowed to contain zero and negative counts.
if not args:
raise TypeError("descriptor 'update' of 'Counter' object "
"needs an argument")
self = args[0]
args = args[1:]
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
iterable = args[0] if args else None
if iterable is not None:
if isinstance(iterable, Mapping):
if self:
self_get = self.get
for elem, count in iterable.iteritems():
self[elem] = self_get(elem, 0) + count
else:
super(Counter, self).update(iterable) # fast path when counter is empty
else:
self_get = self.get
for elem in iterable:
self[elem] = self_get(elem, 0) + 1
if kwds:
self.update(kwds)
def subtract(*args, **kwds):
'''Like dict.update() but subtracts counts instead of replacing them.
Counts can be reduced below zero. Both the inputs and outputs are
allowed to contain zero and negative counts.
Source can be an iterable, a dictionary, or another Counter instance.
>>> c = Counter('which')
>>> c.subtract('witch') # subtract elements from another iterable
>>> c.subtract(Counter('watch')) # subtract elements from another counter
>>> c['h'] # 2 in which, minus 1 in witch, minus 1 in watch
0
>>> c['w'] # 1 in which, minus 1 in witch, minus 1 in watch
-1
'''
if not args:
raise TypeError("descriptor 'subtract' of 'Counter' object "
"needs an argument")
self = args[0]
args = args[1:]
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
iterable = args[0] if args else None
if iterable is not None:
self_get = self.get
if isinstance(iterable, Mapping):
for elem, count in iterable.items():
self[elem] = self_get(elem, 0) - count
else:
for elem in iterable:
self[elem] = self_get(elem, 0) - 1
if kwds:
self.subtract(kwds)
def copy(self):
'Return a shallow copy.'
return self.__class__(self)
def __reduce__(self):
return self.__class__, (dict(self),)
def __delitem__(self, elem):
'Like dict.__delitem__() but does not raise KeyError for missing values.'
if elem in self:
super(Counter, self).__delitem__(elem)
def __repr__(self):
if not self:
return '%s()' % self.__class__.__name__
items = ', '.join(map('%r: %r'.__mod__, self.most_common()))
return '%s({%s})' % (self.__class__.__name__, items)
# Multiset-style mathematical operations discussed in:
# Knuth TAOCP Volume II section 4.6.3 exercise 19
# and at http://en.wikipedia.org/wiki/Multiset
#
# Outputs guaranteed to only include positive counts.
#
# To strip negative and zero counts, add in an empty counter:
# c += Counter()
def __add__(self, other):
'''Add counts from two counters.
>>> Counter('abbb') + Counter('bcc')
Counter({'b': 4, 'c': 2, 'a': 1})
'''
if not isinstance(other, Counter):
return NotImplemented
result = Counter()
for elem, count in self.items():
newcount = count + other[elem]
if newcount > 0:
result[elem] = newcount
for elem, count in other.items():
if elem not in self and count > 0:
result[elem] = count
return result
def __sub__(self, other):
''' Subtract count, but keep only results with positive counts.
>>> Counter('abbbc') - Counter('bccd')
Counter({'b': 2, 'a': 1})
'''
if not isinstance(other, Counter):
return NotImplemented
result = Counter()
for elem, count in self.items():
newcount = count - other[elem]
if newcount > 0:
result[elem] = newcount
for elem, count in other.items():
if elem not in self and count < 0:
result[elem] = 0 - count
return result
def __or__(self, other):
'''Union is the maximum of the values in either of the input counters.
>>> Counter('abbb') | Counter('bcc')
Counter({'b': 3, 'c': 2, 'a': 1})
'''
if not isinstance(other, Counter):
return NotImplemented
result = Counter()
for elem, count in self.items():
other_count = other[elem]
newcount = other_count if count < other_count else count
if newcount > 0:
result[elem] = newcount
for elem, count in other.items():
if elem not in self and count > 0:
result[elem] = count
return result
def __and__(self, other):
''' Intersection is the minimum of corresponding counts.
>>> Counter('abbb') & Counter('bcc')
Counter({'b': 1})
'''
if not isinstance(other, Counter):
return NotImplemented
result = Counter()
for elem, count in self.items():
other_count = other[elem]
newcount = count if count < other_count else other_count
if newcount > 0:
result[elem] = newcount
return result
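The multiset operators defined above can be exercised with a short sketch. It is written in Python 3 syntax and uses the `collections.Counter` shipped with the interpreter, on the assumption that it behaves like the class listed here; the variable names are illustrative only.

```python
from collections import Counter

a = Counter('abbb')
b = Counter('bcc')

added = a + b                                   # counts are summed
diff = Counter('abbbc') - Counter('bccd')       # non-positive results are dropped
union = a | b                                   # element-wise maximum
inter = a & b                                   # element-wise minimum

# subtract() works in place and, unlike the '-' operator, keeps
# zero and negative counts.
c = Counter(a=3, b=1)
c.subtract(Counter(a=1, b=2))
b_after_subtract = c['b']

# Adding an empty counter strips the zero and negative entries again,
# as the comment above the operators suggests.
c += Counter()
```

The operators always build a new counter containing positive counts only, whereas `subtract()` mutates the receiver and can leave negative entries behind.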
if __name__ == '__main__':
    # verify that instances can be pickled
    from cPickle import loads, dumps
    Point = namedtuple('Point', 'x, y', True)
    p = Point(x=10, y=20)
    assert p == loads(dumps(p))

    # test and demonstrate ability to override methods
    class Point(namedtuple('Point', 'x y')):
        __slots__ = ()
        @property
        def hypot(self):
            return (self.x ** 2 + self.y ** 2) ** 0.5
        def __str__(self):
            return 'Point: x=%6.3f y=%6.3f hypot=%6.3f' % (self.x, self.y, self.hypot)

    for p in Point(3, 4), Point(14, 5/7.):
        print p

    class Point(namedtuple('Point', 'x y')):
        'Point class with optimized _make() and _replace() without error-checking'
        __slots__ = ()
        _make = classmethod(tuple.__new__)
        def _replace(self, _map=map, **kwds):
            return self._make(_map(kwds.get, ('x', 'y'), self))

    print Point(11, 22)._replace(x=100)

    Point3D = namedtuple('Point3D', Point._fields + ('z',))
    print Point3D.__doc__

    import doctest
    TestResults = namedtuple('TestResults', 'failed attempted')
    print TestResults(*doctest.testmod())
"""Conversion functions between RGB and other color systems.
This module provides two functions for each color system ABC:
rgb_to_abc(r, g, b) --> a, b, c
abc_to_rgb(a, b, c) --> r, g, b
All inputs and outputs are triples of floats in the range [0.0...1.0]
(with the exception of I and Q, which cover a slightly larger range).
Inputs outside the valid range may cause exceptions or invalid outputs.
Supported color systems:
RGB: Red, Green, Blue components
YIQ: Luminance, Chrominance (used by composite video signals)
HLS: Hue, Luminance, Saturation
HSV: Hue, Saturation, Value
"""
# References:
# http://en.wikipedia.org/wiki/YIQ
# http://en.wikipedia.org/wiki/HLS_color_space
# http://en.wikipedia.org/wiki/HSV_color_space
__all__ = ["rgb_to_yiq","yiq_to_rgb","rgb_to_hls","hls_to_rgb",
"rgb_to_hsv","hsv_to_rgb"]
# Some floating point constants
ONE_THIRD = 1.0/3.0
ONE_SIXTH = 1.0/6.0
TWO_THIRD = 2.0/3.0
# YIQ: used by composite video signals (linear combinations of RGB)
# Y: perceived grey level (0.0 == black, 1.0 == white)
# I, Q: color components
def rgb_to_yiq(r, g, b):
    y = 0.30*r + 0.59*g + 0.11*b
    i = 0.60*r - 0.28*g - 0.32*b
    q = 0.21*r - 0.52*g + 0.31*b
    return (y, i, q)

def yiq_to_rgb(y, i, q):
    r = y + 0.948262*i + 0.624013*q
    g = y - 0.276066*i - 0.639810*q
    b = y - 1.105450*i + 1.729860*q
    if r < 0.0:
        r = 0.0
    if g < 0.0:
        g = 0.0
    if b < 0.0:
        b = 0.0
    if r > 1.0:
        r = 1.0
    if g > 1.0:
        g = 1.0
    if b > 1.0:
        b = 1.0
    return (r, g, b)
# HLS: Hue, Luminance, Saturation
# H: position in the spectrum
# L: color lightness
# S: color saturation
def rgb_to_hls(r, g, b):
    maxc = max(r, g, b)
    minc = min(r, g, b)
    # XXX Can optimize (maxc+minc) and (maxc-minc)
    l = (minc+maxc)/2.0
    if minc == maxc:
        return 0.0, l, 0.0
    if l <= 0.5:
        s = (maxc-minc) / (maxc+minc)
    else:
        s = (maxc-minc) / (2.0-maxc-minc)
    rc = (maxc-r) / (maxc-minc)
    gc = (maxc-g) / (maxc-minc)
    bc = (maxc-b) / (maxc-minc)
    if r == maxc:
        h = bc-gc
    elif g == maxc:
        h = 2.0+rc-bc
    else:
        h = 4.0+gc-rc
    h = (h/6.0) % 1.0
    return h, l, s

def hls_to_rgb(h, l, s):
    if s == 0.0:
        return l, l, l
    if l <= 0.5:
        m2 = l * (1.0+s)
    else:
        m2 = l+s-(l*s)
    m1 = 2.0*l - m2
    return (_v(m1, m2, h+ONE_THIRD), _v(m1, m2, h), _v(m1, m2, h-ONE_THIRD))

def _v(m1, m2, hue):
    hue = hue % 1.0
    if hue < ONE_SIXTH:
        return m1 + (m2-m1)*hue*6.0
    if hue < 0.5:
        return m2
    if hue < TWO_THIRD:
        return m1 + (m2-m1)*(TWO_THIRD-hue)*6.0
    return m1
# HSV: Hue, Saturation, Value
# H: position in the spectrum
# S: color saturation ("purity")
# V: color brightness
def rgb_to_hsv(r, g, b):
    maxc = max(r, g, b)
    minc = min(r, g, b)
    v = maxc
    if minc == maxc:
        return 0.0, 0.0, v
    s = (maxc-minc) / maxc
    rc = (maxc-r) / (maxc-minc)
    gc = (maxc-g) / (maxc-minc)
    bc = (maxc-b) / (maxc-minc)
    if r == maxc:
        h = bc-gc
    elif g == maxc:
        h = 2.0+rc-bc
    else:
        h = 4.0+gc-rc
    h = (h/6.0) % 1.0
    return h, s, v

def hsv_to_rgb(h, s, v):
    if s == 0.0:
        return v, v, v
    i = int(h*6.0) # XXX assume int() truncates!
    f = (h*6.0) - i
    p = v*(1.0 - s)
    q = v*(1.0 - s*f)
    t = v*(1.0 - s*(1.0-f))
    i = i%6
    if i == 0:
        return v, t, p
    if i == 1:
        return q, v, p
    if i == 2:
        return p, v, t
    if i == 3:
        return p, q, v
    if i == 4:
        return t, p, v
    if i == 5:
        return v, p, q
    # Cannot get here
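A quick round trip illustrates how the HLS and HSV conversions above relate to each other. This is a minimal sketch in Python 3 syntax that calls the interpreter's own `colorsys` module, assuming it follows the same formulas; the helper `close` and the sample colour are hypothetical and exist only for this example.

```python
import colorsys

def close(a, b, tol=1e-9):
    """Compare two triples component-wise with a small tolerance."""
    return all(abs(x - y) <= tol for x, y in zip(a, b))

rgb = (0.2, 0.4, 0.4)                 # a muted cyan, all components in [0.0, 1.0]

hls = colorsys.rgb_to_hls(*rgb)       # (hue, lightness, saturation)
hsv = colorsys.rgb_to_hsv(*rgb)       # (hue, saturation, value)

# Converting back should reproduce the input up to rounding error.
rgb_from_hls = colorsys.hls_to_rgb(*hls)
rgb_from_hsv = colorsys.hsv_to_rgb(*hsv)

# Pure red sits at hue 0 with full saturation and full value.
red_hsv = colorsys.rgb_to_hsv(1.0, 0.0, 0.0)
```

Note that the two systems order their components differently (H, L, S versus H, S, V), which is easy to overlook when unpacking the results.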
"""Execute shell commands via os.popen() and return status, output.
Interface summary:
import commands
outtext = commands.getoutput(cmd)
(exitstatus, outtext) = commands.getstatusoutput(cmd)
outtext = commands.getstatus(file) # returns output of "ls -ld file"
A trailing newline is removed from the output string.
Encapsulates the basic operation:
pipe = os.popen('{ ' + cmd + '; } 2>&1', 'r')
text = pipe.read()
sts = pipe.close()
[Note: it would be nice to add functions to interpret the exit status.]
"""
from warnings import warnpy3k
warnpy3k("the commands module has been removed in Python 3.0; "
"use the subprocess module instead", stacklevel=2)
del warnpy3k
__all__ = ["getstatusoutput","getoutput","getstatus"]
# Module 'commands'
#
# Various tools for executing commands and looking at their output and status.
#
# NB This only works (and is only relevant) for UNIX.
# Get 'ls -l' status for an object into a string
#
def getstatus(file):
    """Return output of "ls -ld <file>" in a string."""
    import warnings
    warnings.warn("commands.getstatus() is deprecated", DeprecationWarning, 2)
    return getoutput('ls -ld' + mkarg(file))


# Get the output from a shell command into a string.
# The exit status is ignored; a trailing newline is stripped.
# Assume the command will work with '{ ... ; } 2>&1' around it..
#
def getoutput(cmd):
    """Return output (stdout or stderr) of executing cmd in a shell."""
    return getstatusoutput(cmd)[1]


# Ditto but preserving the exit status.
# Returns a pair (sts, output)
#
def getstatusoutput(cmd):
    """Return (status, output) of executing cmd in a shell."""
    import os
    pipe = os.popen('{ ' + cmd + '; } 2>&1', 'r')
    text = pipe.read()
    sts = pipe.close()
    if sts is None: sts = 0
    if text[-1:] == '\n': text = text[:-1]
    return sts, text


# Make command argument from directory and pathname (prefix space, add quotes).
#
def mk2arg(head, x):
    import os
    return mkarg(os.path.join(head, x))


# Make a shell command argument from a string.
# Return a string beginning with a space followed by a shell-quoted
# version of the argument.
# Two strategies: enclose in single quotes if it contains none;
# otherwise, enclose in double quotes and prefix quotable characters
# with backslash.
#
def mkarg(x):
    if '\'' not in x:
        return ' \'' + x + '\''
    s = ' "'
    for c in x:
        if c in '\\$"`':
            s = s + '\\'
        s = s + c
    s = s + '"'
    return s
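The deprecation warning above points to the subprocess module as the replacement. A rough Python 3 equivalent of getstatusoutput() might look like the sketch below; the name `get_status_output` is hypothetical, and the first element of the returned pair is the plain exit code from `returncode`, which is not necessarily the same value that `pipe.close()` yields in the code above.

```python
import subprocess

def get_status_output(cmd):
    """Run cmd through the shell and return (exit_code, combined_output)."""
    proc = subprocess.run(
        cmd,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,    # merge stderr into stdout, like '2>&1'
        universal_newlines=True,     # decode the output to text
    )
    text = proc.stdout
    if text.endswith("\n"):          # drop one trailing newline, as above
        text = text[:-1]
    return proc.returncode, text

status, output = get_status_output("echo hello")
```

Because the command still goes through a shell, untrusted arguments would need quoting; in Python 3 `shlex.quote` covers the role that mkarg() plays here.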
"""Module/script to byte-compile all .py files to .pyc (or .pyo) files.
When called as a script with arguments, this compiles the directories
given as arguments recursively; the -l option prevents it from
recursing into directories.
Without arguments, it compiles all modules on sys.path, without
recursing into subdirectories. (Even though it should do so for
packages -- for now, you'll have to deal with packages separately.)
See module py_compile for details of the actual byte-compilation.
"""
import os
import sys
import py_compile
import struct
import imp
__all__ = ["compile_dir","compile_file","compile_path"]
def compile_dir(dir, maxlevels=10, ddir=None,
force=0, rx=None, quiet=0):
"""Byte-compile all modules in the given directory tree.
Arguments (only dir is required):
dir: the directory to byte-compile
maxlevels: maximum recursion level (default 10)
ddir: the directory that will be prepended to the path to the
file as it is compiled into each byte-code file.
force: if 1, force compilation, even if timestamps are up-to-date
quiet: if 1, be quiet during compilation
"""
if not quiet:
print 'Listing', dir, '...'
try:
names = os.listdir(dir)
except os.error:
print "Can't list", dir
names = []
names.sort()
success = 1
for name in names:
fullname = os.path.join(dir, name)
if ddir is not None:
dfile = os.path.join(ddir, name)
else:
dfile = None
if not os.path.isdir(fullname):
if not compile_file(fullname, ddir, force, rx, quiet):
success = 0
elif maxlevels > 0 and \
name != os.curdir and name != os.pardir and \
os.path.isdir(fullname) and \
not os.path.islink(fullname):
if not compile_dir(fullname, maxlevels - 1, dfile, force, rx,
quiet):
success = 0
return success
def compile_file(fullname, ddir=None, force=0, rx=None, quiet=0):
"""Byte-compile one file.
Arguments (only fullname is required):
fullname: the file to byte-compile
ddir: if given, the directory name compiled in to the
byte-code file.
force: if 1, force compilation, even if timestamps are up-to-date
quiet: if 1, be quiet during compilation
"""
success = 1
name = os.path.basename(fullname)
if ddir is not None:
dfile = os.path.join(ddir, name)
else:
dfile = None
if rx is not None:
mo = rx.search(fullname)
if mo:
return success
if os.path.isfile(fullname):
head, tail = name[:-3], name[-3:]
if tail == '.py':
if not force:
try:
mtime = int(os.stat(fullname).st_mtime)
expect = struct.pack('<4sl', imp.get_magic(), mtime)
cfile = fullname + (__debug__ and 'c' or 'o')
with open(cfile, 'rb') as chandle:
actual = chandle.read(8)
if expect == actual:
return success
except IOError:
pass
if not quiet:
print 'Compiling', fullname, '...'
try:
ok = py_compile.compile(fullname, None, dfile, True)
except py_compile.PyCompileError,err:
if quiet:
print 'Compiling', fullname, '...'
print err.msg
success = 0
except IOError, e:
print "Sorry", e
success = 0
else:
if ok == 0:
success = 0
return success
def compile_path(skip_curdir=1, maxlevels=0, force=0, quiet=0):
    """Byte-compile all modules on sys.path.

    Arguments (all optional):

    skip_curdir: if true, skip current directory (default true)
    maxlevels: max recursion level (default 0)
    force: as for compile_dir() (default 0)
    quiet: as for compile_dir() (default 0)
    """
    success = 1
    for dir in sys.path:
        if (not dir or dir == os.curdir) and skip_curdir:
            print 'Skipping current directory'
        else:
            success = success and compile_dir(dir, maxlevels, None,
                                              force, quiet=quiet)
    return success
def expand_args(args, flist):
    """read names in flist and append to args"""
    expanded = args[:]
    if flist:
        try:
            if flist == '-':
                fd = sys.stdin
            else:
                fd = open(flist)
            while 1:
                line = fd.readline()
                if not line:
                    break
                expanded.append(line[:-1])
        except IOError:
            print "Error reading file list %s" % flist
            raise
    return expanded
def main():
"""Script main program."""
import getopt
try:
opts, args = getopt.getopt(sys.argv[1:], 'lfqd:x:i:')
except getopt.error, msg:
print msg
print "usage: python compileall.py [-l] [-f] [-q] [-d destdir] " \
"[-x regexp] [-i list] [directory|file ...]"
print
print "arguments: zero or more file and directory names to compile; " \
"if no arguments given, "
print " defaults to the equivalent of -l sys.path"
print
print "options:"
print "-l: don't recurse into subdirectories"
print "-f: force rebuild even if timestamps are up-to-date"
print "-q: output only error messages"
print "-d destdir: directory to prepend to file paths for use in " \
"compile-time tracebacks and in"
print " runtime tracebacks in cases where the source " \
"file is unavailable"
print "-x regexp: skip files matching the regular expression regexp; " \
"the regexp is searched for"
print " in the full path of each file considered for " \
"compilation"
print "-i file: add all the files and directories listed in file to " \
"the list considered for"
print ' compilation; if "-", names are read from stdin'
sys.exit(2)
maxlevels = 10
ddir = None
force = 0
quiet = 0
rx = None
flist = None
for o, a in opts:
if o == '-l': maxlevels = 0
if o == '-d': ddir = a
if o == '-f': force = 1
if o == '-q': quiet = 1
if o == '-x':
import re
rx = re.compile(a)
if o == '-i': flist = a
if ddir:
if len(args) != 1 or not os.path.isdir(args[0]):
print "-d destdir requires exactly one directory argument"
sys.exit(2)
success = 1
try:
if args or flist:
try:
if flist:
args = expand_args(args, flist)
except IOError:
success = 0
if success:
for arg in args:
if os.path.isdir(arg):
if not compile_dir(arg, maxlevels, ddir,
force, rx, quiet):
success = 0
else:
if not compile_file(arg, ddir, force, rx, quiet):
success = 0
else:
success = compile_path()
except KeyboardInterrupt:
print "\n[interrupted]"
success = 0
return success
if __name__ == '__main__':
exit_status = int(not main())
sys.exit(exit_status)
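Besides the command-line entry point, compile_dir() can be called directly. The following is a minimal sketch in Python 3 syntax, assuming the interpreter's `compileall` module keeps the `compile_dir(dir, maxlevels, ..., force, quiet)` parameters listed above; the helper `compile_tree` and the file name `hello.py` are hypothetical.

```python
import compileall
import os
import tempfile

def compile_tree(root):
    """Byte-compile every .py file under root; return True on success."""
    result = compileall.compile_dir(root, maxlevels=10, force=True, quiet=1)
    return bool(result)

# Build a throwaway directory holding a single trivial module.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "hello.py"), "w") as handle:
        handle.write("x = 1\n")
    ok = compile_tree(tmp)
```

Passing `force=True` skips the timestamp comparison, so the file is recompiled even when an up-to-date byte-code file already exists.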
"""Configuration file parser.
A setup file consists of sections, led by a "[section]" header,
and followed by "name: value" entries, with continuations and such in
the style of RFC 822.
The option values can contain format strings which refer to other values in
the same section, or values in a special [DEFAULT] section.
For example:
something: %(dir)s/whatever
would resolve the "%(dir)s" to the value of dir. All reference
expansions are done late, on demand.
Intrinsic defaults can be specified by passing them into the
ConfigParser constructor as a dictionary.
class:
ConfigParser -- responsible for parsing a list of
configuration files, and managing the parsed database.
methods:
__init__(defaults=None)
create the parser and specify a dictionary of intrinsic defaults. The
keys must be strings, the values must be appropriate for %()s string
interpolation. Note that `__name__' is always an intrinsic default;
its value is the section's name.
sections()
return all the configuration section names, sans DEFAULT
has_section(section)
return whether the given section exists
has_option(section, option)
return whether the given option exists in the given section
options(section)
return list of configuration options for the named section
read(filenames)
read and parse the list of named configuration files, given by
name. A single filename is also allowed. Non-existing files
are ignored. Return list of successfully read files.
readfp(fp, filename=None)
read and parse one configuration file, given as a file object.
The filename defaults to fp.name; it is only used in error
messages (if fp has no `name' attribute, the string `<???>' is used).
get(section, option, raw=False, vars=None)
return a string value for the named option. All % interpolations are
expanded in the return values, based on the defaults passed into the
constructor and the DEFAULT section. Additional substitutions may be
provided using the `vars' argument, which must be a dictionary whose
contents override any pre-existing defaults.
getint(section, option)
like get(), but convert value to an integer
getfloat(section, option)
like get(), but convert value to a float
getboolean(section, option)
like get(), but convert value to a boolean (currently case
insensitively defined as 0, false, no, off for False, and 1, true,
yes, on for True). Returns False or True.
items(section, raw=False, vars=None)
return a list of tuples with (name, value) for each option
in the section.
remove_section(section)
remove the given file section and all its options
remove_option(section, option)
remove the given option from the given section
set(section, option, value)
set the given option
write(fp)
write the configuration state in .ini format
"""
try:
from collections import OrderedDict as _default_dict
except ImportError:
# fallback for setup.py which hasn't yet built _collections
_default_dict = dict
import re
__all__ = ["NoSectionError", "DuplicateSectionError", "NoOptionError",
"InterpolationError", "InterpolationDepthError",
"InterpolationSyntaxError", "ParsingError",
"MissingSectionHeaderError",
"ConfigParser", "SafeConfigParser", "RawConfigParser",
"DEFAULTSECT", "MAX_INTERPOLATION_DEPTH"]
DEFAULTSECT = "DEFAULT"
MAX_INTERPOLATION_DEPTH = 10
# exception classes
class Error(Exception):
"""Base class for ConfigParser exceptions."""
def _get_message(self):
"""Getter for 'message'; needed only to override deprecation in
BaseException."""
return self.__message
def _set_message(self, value):
"""Setter for 'message'; needed only to override deprecation in
BaseException."""
self.__message = value
# BaseException.message has been deprecated since Python 2.6. To prevent
# DeprecationWarning from popping up over this pre-existing attribute, use
# a new property that takes lookup precedence.
message = property(_get_message, _set_message)
def __init__(self, msg=''):
self.message = msg
Exception.__init__(self, msg)
def __repr__(self):
return self.message
__str__ = __repr__
class NoSectionError(Error):
"""Raised when no section matches a requested option."""
def __init__(self, section):
Error.__init__(self, 'No section: %r' % (section,))
self.section = section
self.args = (section, )
class DuplicateSectionError(Error):
"""Raised when a section is multiply-created."""
def __init__(self, section):
Error.__init__(self, "Section %r already exists" % section)
self.section = section
self.args = (section, )
class NoOptionError(Error):
"""A requested option was not found."""
def __init__(self, option, section):
Error.__init__(self, "No option %r in section: %r" %
(option, section))
self.option = option
self.section = section
self.args = (option, section)
class InterpolationError(Error):
"""Base class for interpolation-related exceptions."""
def __init__(self, option, section, msg):
Error.__init__(self, msg)
self.option = option
self.section = section
self.args = (option, section, msg)
class InterpolationMissingOptionError(InterpolationError):
"""A string substitution required a setting which was not available."""
def __init__(self, option, section, rawval, reference):
msg = ("Bad value substitution:\n"
"\tsection: [%s]\n"
"\toption : %s\n"
"\tkey : %s\n"
"\trawval : %s\n"
% (section, option, reference, rawval))
InterpolationError.__init__(self, option, section, msg)
self.reference = reference
self.args = (option, section, rawval, reference)
class InterpolationSyntaxError(InterpolationError):
"""Raised when the source text into which substitutions are made
does not conform to the required syntax."""
class InterpolationDepthError(InterpolationError):
"""Raised when substitutions are nested too deeply."""
def __init__(self, option, section, rawval):
msg = ("Value interpolation too deeply recursive:\n"
"\tsection: [%s]\n"
"\toption : %s\n"
"\trawval : %s\n"
% (section, option, rawval))
InterpolationError.__init__(self, option, section, msg)
self.args = (option, section, rawval)
class ParsingError(Error):
"""Raised when a configuration file does not follow legal syntax."""
def __init__(self, filename):
Error.__init__(self, 'File contains parsing errors: %s' % filename)
self.filename = filename
self.errors = []
self.args = (filename, )
def append(self, lineno, line):
self.errors.append((lineno, line))
self.message += '\n\t[line %2d]: %s' % (lineno, line)
class MissingSectionHeaderError(ParsingError):
"""Raised when a key-value pair is found before any section header."""
def __init__(self, filename, lineno, line):
Error.__init__(
self,
'File contains no section headers.\nfile: %s, line: %d\n%r' %
(filename, lineno, line))
self.filename = filename
self.lineno = lineno
self.line = line
self.args = (filename, lineno, line)
class RawConfigParser:
def __init__(self, defaults=None, dict_type=_default_dict,
allow_no_value=False):
self._dict = dict_type
self._sections = self._dict()
self._defaults = self._dict()
if allow_no_value:
self._optcre = self.OPTCRE_NV
else:
self._optcre = self.OPTCRE
if defaults:
for key, value in defaults.items():
self._defaults[self.optionxform(key)] = value
def defaults(self):
return self._defaults
def sections(self):
"""Return a list of section names, excluding [DEFAULT]"""
# self._sections will never have [DEFAULT] in it
return self._sections.keys()
def add_section(self, section):
"""Create a new section in the configuration.
Raise DuplicateSectionError if a section by the specified name
already exists. Raise ValueError if name is DEFAULT or any of its
case-insensitive variants.
"""
if section.lower() == "default":
raise ValueError, 'Invalid section name: %s' % section
if section in self._sections:
raise DuplicateSectionError(section)
self._sections[section] = self._dict()
def has_section(self, section):
"""Indicate whether the named section is present in the configuration.
The DEFAULT section is not acknowledged.
"""
return section in self._sections
def options(self, section):
"""Return a list of option names for the given section name."""
try:
opts = self._sections[section].copy()
except KeyError:
raise NoSectionError(section)
opts.update(self._defaults)
if '__name__' in opts:
del opts['__name__']
return opts.keys()
def read(self, filenames):
"""Read and parse a filename or a list of filenames.
Files that cannot be opened are silently ignored; this is
designed so that you can specify a list of potential
configuration file locations (e.g. current directory, user's
home directory, systemwide directory), and all existing
configuration files in the list will be read. A single
filename may also be given.
Return list of successfully read files.
"""
if isinstance(filenames, basestring):
filenames = [filenames]
read_ok = []
for filename in filenames:
try:
fp = open(filename)
except IOError:
continue
self._read(fp, filename)
fp.close()
read_ok.append(filename)
return read_ok
def readfp(self, fp, filename=None):
"""Like read() but the argument must be a file-like object.
The `fp' argument must have a `readline' method. Optional
second argument is the `filename', which if not given, is
taken from fp.name. If fp has no `name' attribute, `<???>' is
used.
"""
if filename is None:
try:
filename = fp.name
except AttributeError:
filename = '<???>'
self._read(fp, filename)
def get(self, section, option):
opt = self.optionxform(option)
if section not in self._sections:
if section != DEFAULTSECT:
raise NoSectionError(section)
if opt in self._defaults:
return self._defaults[opt]
else:
raise NoOptionError(option, section)
elif opt in self._sections[section]:
return self._sections[section][opt]
elif opt in self._defaults:
return self._defaults[opt]
else:
raise NoOptionError(option, section)
def items(self, section):
try:
d2 = self._sections[section]
except KeyError:
if section != DEFAULTSECT:
raise NoSectionError(section)
d2 = self._dict()
d = self._defaults.copy()
d.update(d2)
if "__name__" in d:
del d["__name__"]
return d.items()
def _get(self, section, conv, option):
return conv(self.get(section, option))
def getint(self, section, option):
return self._get(section, int, option)
def getfloat(self, section, option):
return self._get(section, float, option)
_boolean_states = {'1': True, 'yes': True, 'true': True, 'on': True,
'0': False, 'no': False, 'false': False, 'off': False}
def getboolean(self, section, option):
v = self.get(section, option)
if v.lower() not in self._boolean_states:
raise ValueError, 'Not a boolean: %s' % v
return self._boolean_states[v.lower()]
def optionxform(self, optionstr):
return optionstr.lower()
def has_option(self, section, option):
"""Check for the existence of a given option in a given section."""
if not section or section == DEFAULTSECT:
option = self.optionxform(option)
return option in self._defaults
elif section not in self._sections:
return False
else:
option = self.optionxform(option)
return (option in self._sections[section]
or option in self._defaults)
def set(self, section, option, value=None):
"""Set an option."""
if not section or section == DEFAULTSECT:
sectdict = self._defaults
else:
try:
sectdict = self._sections[section]
except KeyError:
raise NoSectionError(section)
sectdict[self.optionxform(option)] = value
def write(self, fp):
"""Write an .ini-format representation of the configuration state."""
if self._defaults:
fp.write("[%s]\n" % DEFAULTSECT)
for (key, value) in self._defaults.items():
fp.write("%s = %s\n" % (key, str(value).replace('\n', '\n\t')))
fp.write("\n")
for section in self._sections:
fp.write("[%s]\n" % section)
for (key, value) in self._sections[section].items():
if key == "__name__":
continue
if (value is not None) or (self._optcre == self.OPTCRE):
key = " = ".join((key, str(value).replace('\n', '\n\t')))
fp.write("%s\n" % (key))
fp.write("\n")
def remove_option(self, section, option):
"""Remove an option."""
if not section or section == DEFAULTSECT:
sectdict = self._defaults
else:
try:
sectdict = self._sections[section]
except KeyError:
raise NoSectionError(section)
option = self.optionxform(option)
existed = option in sectdict
if existed:
del sectdict[option]
return existed
def remove_section(self, section):
"""Remove a file section."""
existed = section in self._sections
if existed:
del self._sections[section]
return existed
#
# Regular expressions for parsing section headers and options.
#
SECTCRE = re.compile(
r'\[' # [
r'(?P<header>[^]]+)' # very permissive!
r'\]' # ]
)
OPTCRE = re.compile(
r'(?P<option>[^:=\s][^:=]*)' # very permissive!
r'\s*(?P<vi>[:=])\s*' # any number of space/tab,
# followed by separator
# (either : or =), followed
# by any # space/tab
r'(?P<value>.*)$' # everything up to eol
)
OPTCRE_NV = re.compile(
r'(?P<option>[^:=\s][^:=]*)' # very permissive!
r'\s*(?:' # any number of space/tab,
r'(?P<vi>[:=])\s*' # optionally followed by
# separator (either : or
# =), followed by any #
# space/tab
r'(?P<value>.*))?$' # everything up to eol
)
def _read(self, fp, fpname):
"""Parse a sectioned setup file.
Each section in a setup file contains a title line at the top,
indicated by a name in square brackets (`[]'), plus key/value
options lines, indicated by `name: value' format lines.
Continuations are represented by an embedded newline then
leading whitespace. Blank lines, lines beginning with a '#',
and just about everything else are ignored.
"""
cursect = None # None, or a dictionary
optname = None
lineno = 0
e = None # None, or an exception
while True:
line = fp.readline()
if not line:
break
lineno = lineno + 1
# comment or blank line?
if line.strip() == '' or line[0] in '#;':
continue
if line.split(None, 1)[0].lower() == 'rem' and line[0] in "rR":
# no leading whitespace
continue
# continuation line?
if line[0].isspace() and cursect is not None and optname:
value = line.strip()
if value:
cursect[optname].append(value)
# a section header or option header?
else:
# is it a section header?
mo = self.SECTCRE.match(line)
if mo:
sectname = mo.group('header')
if sectname in self._sections:
cursect = self._sections[sectname]
elif sectname == DEFAULTSECT:
cursect = self._defaults
else:
cursect = self._dict()
cursect['__name__'] = sectname
self._sections[sectname] = cursect
# So sections can't start with a continuation line
optname = None
# no section header in the file?
elif cursect is None:
raise MissingSectionHeaderError(fpname, lineno, line)
# an option line?
else:
mo = self._optcre.match(line)
if mo:
optname, vi, optval = mo.group('option', 'vi', 'value')
optname = self.optionxform(optname.rstrip())
# This check is fine because the OPTCRE cannot
# match if it would set optval to None
if optval is not None:
if vi in ('=', ':') and ';' in optval:
# ';' is a comment delimiter only if it follows
# a spacing character
pos = optval.find(';')
if pos != -1 and optval[pos-1].isspace():
optval = optval[:pos]
optval = optval.strip()
# allow empty values
if optval == '""':
optval = ''
cursect[optname] = [optval]
else:
# valueless option handling
cursect[optname] = optval
else:
# a non-fatal parsing error occurred. set up the
# exception but keep going. the exception will be
# raised at the end of the file and will contain a
# list of all bogus lines
if not e:
e = ParsingError(fpname)
e.append(lineno, repr(line))
# if any parsing errors occurred, raise an exception
if e:
raise e
# join the multi-line values collected while reading
all_sections = [self._defaults]
all_sections.extend(self._sections.values())
for options in all_sections:
for name, val in options.items():
if isinstance(val, list):
options[name] = '\n'.join(val)
import UserDict as _UserDict

class _Chainmap(_UserDict.DictMixin):
    """Combine multiple mappings for successive lookups.

    For example, to emulate Python's normal lookup sequence:

        import __builtin__
        pylookup = _Chainmap(locals(), globals(), vars(__builtin__))
    """

    def __init__(self, *maps):
        self._maps = maps

    def __getitem__(self, key):
        for mapping in self._maps:
            try:
                return mapping[key]
            except KeyError:
                pass
        raise KeyError(key)

    def keys(self):
        result = []
        seen = set()
        for mapping in self._maps:
            for key in mapping:
                if key not in seen:
                    result.append(key)
                    seen.add(key)
        return result
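The successive-lookup idea behind _Chainmap can be tried out with `collections.ChainMap`, which Python 3 provides publicly. The sketch below assumes the same first-match-wins behaviour as the class above; the three dictionaries are made-up stand-ins for the `vars`, section and defaults mappings that ConfigParser.get() chains together.

```python
from collections import ChainMap

# Hypothetical stand-ins for the three layers consulted by get().
call_vars = {"user": "override"}
section = {"user": "from-section", "port": "80"}
defaults = {"port": "8080", "host": "localhost"}

lookup = ChainMap(call_vars, section, defaults)

user = lookup["user"]    # found in the first mapping
port = lookup["port"]    # first mapping lacks it, the second supplies it
host = lookup["host"]    # only the last mapping has it

# Every key appears once, however many layers define it.
all_keys = sorted(lookup.keys())
```

Earlier mappings shadow later ones without copying any data, which is why the lookup order passed to the constructor matters.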
class ConfigParser(RawConfigParser):
def get(self, section, option, raw=False, vars=None):
"""Get an option value for a given section.
If `vars' is provided, it must be a dictionary. The option is looked up
in `vars' (if provided), `section', and in `defaults' in that order.
All % interpolations are expanded in the return values, unless the
optional argument `raw' is true. Values for interpolation keys are
looked up in the same manner as the option.
The section DEFAULT is special.
"""
sectiondict = {}
try:
sectiondict = self._sections[section]
except KeyError:
if section != DEFAULTSECT:
raise NoSectionError(section)
# Update with the entry specific variables
vardict = {}
if vars:
for key, value in vars.items():
vardict[self.optionxform(key)] = value
d = _Chainmap(vardict, sectiondict, self._defaults)
option = self.optionxform(option)
try:
value = d[option]
except KeyError:
raise NoOptionError(option, section)
if raw or value is None:
return value
else:
return self._interpolate(section, option, value, d)
def items(self, section, raw=False, vars=None):
"""Return a list of tuples with (name, value) for each option
in the section.
All % interpolations are expanded in the return values, based on the
defaults passed into the constructor, unless the optional argument
`raw' is true. Additional substitutions may be provided using the
`vars' argument, which must be a dictionary whose contents override
any pre-existing defaults.
The section DEFAULT is special.
"""
d = self._defaults.copy()
try:
d.update(self._sections[section])
except KeyError:
if section != DEFAULTSECT:
raise NoSectionError(section)
# Update with the entry specific variables
if vars:
for key, value in vars.items():
d[self.optionxform(key)] = value
options = d.keys()
if "__name__" in options:
options.remove("__name__")
if raw:
return [(option, d[option])
for option in options]
else:
return [(option, self._interpolate(section, option, d[option], d))
for option in options]
def _interpolate(self, section, option, rawval, vars):
# do the string interpolation
value = rawval
depth = MAX_INTERPOLATION_DEPTH
while depth: # Loop through this until it's done
depth -= 1
if value and "%(" in value:
value = self._KEYCRE.sub(self._interpolation_replace, value)
try:
value = value % vars
except KeyError, e:
raise InterpolationMissingOptionError(
option, section, rawval, e.args[0])
else:
break
if value and "%(" in value:
raise InterpolationDepthError(option, section, rawval)
return value
_KEYCRE = re.compile(r"%\(([^)]*)\)s|.")
def _interpolation_replace(self, match):
s = match.group(1)
if s is None:
return match.group()
else:
return "%%(%s)s" % self.optionxform(s)
class SafeConfigParser(ConfigParser):
def _interpolate(self, section, option, rawval, vars):
# do the string interpolation
L = []
self._interpolate_some(option, L, rawval, section, vars, 1)
return ''.join(L)
_interpvar_re = re.compile(r"%\(([^)]+)\)s")
def _interpolate_some(self, option, accum, rest, section, map, depth):
if depth > MAX_INTERPOLATION_DEPTH:
raise InterpolationDepthError(option, section, rest)
while rest:
p = rest.find("%")
if p < 0:
accum.append(rest)
return
if p > 0:
accum.append(rest[:p])
rest = rest[p:]
# p is no longer used
c = rest[1:2]
if c == "%":
accum.append("%")
rest = rest[2:]
elif c == "(":
m = self._interpvar_re.match(rest)
if m is None:
raise InterpolationSyntaxError(option, section,
"bad interpolation variable reference %r" % rest)
var = self.optionxform(m.group(1))
rest = rest[m.end():]
try:
v = map[var]
except KeyError:
raise InterpolationMissingOptionError(
option, section, rest, var)
if "%" in v:
self._interpolate_some(option, accum, v,
section, map, depth + 1)
else:
accum.append(v)
else:
raise InterpolationSyntaxError(
option, section,
"'%%' must be followed by '%%' or '(', found: %r" % (rest,))
def set(self, section, option, value=None):
"""Set an option. Extend ConfigParser.set: check for string values."""
# The only legal non-string value if we allow valueless
# options is None, so we need to check if the value is a
# string if:
# - we do not allow valueless options, or
# - we allow valueless options but the value is not None
if self._optcre is self.OPTCRE or value:
if not isinstance(value, basestring):
raise TypeError("option values must be strings")
if value is not None:
# check for bad percent signs:
# first, replace all "good" interpolations
tmp_value = value.replace('%%', '')
tmp_value = self._interpvar_re.sub('', tmp_value)
# then, check if there's a lone percent sign left
if '%' in tmp_value:
raise ValueError("invalid interpolation syntax in %r at "
"position %d" % (value, tmp_value.find('%')))
ConfigParser.set(self, section, option, value)
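# Illustrative sketch (not part of the module): how the interpolation
# rules above look from calling code. The section and option names are
# made up for the example, and the import falls back so the same lines
# run where the module is named configparser.

```python
# Use whichever parser class is available; both apply the %(name)s rules.
try:
    from configparser import ConfigParser as Parser  # Python 3 name
except ImportError:
    from ConfigParser import SafeConfigParser as Parser  # Python 2 name

parser = Parser()
parser.add_section("paths")
parser.set("paths", "root", "/srv")
parser.set("paths", "logs", "%(root)s/logs")  # reference to another option
parser.set("paths", "ratio", "100%%")         # '%%' is a literal percent sign

expanded = parser.get("paths", "logs")
unexpanded = parser.get("paths", "logs", raw=True)
ratio = parser.get("paths", "ratio")

# A lone '%' is rejected by set() before it can break interpolation later.
try:
    parser.set("paths", "bad", "50%")
    rejected = False
except ValueError:
    rejected = True
```

# With raw=True the stored text comes back untouched; without it, the
# reference is expanded on every get().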
"""Utilities for with-statement contexts. See PEP 343."""
import sys
from functools import wraps
from warnings import warn
__all__ = ["contextmanager", "nested", "closing"]
class GeneratorContextManager(object):
"""Helper for @contextmanager decorator."""
def __init__(self, gen):
self.gen = gen
def __enter__(self):
try:
return self.gen.next()
except StopIteration:
raise RuntimeError("generator didn't yield")
def __exit__(self, type, value, traceback):
if type is None:
try:
self.gen.next()
except StopIteration:
return
else:
raise RuntimeError("generator didn't stop")
else:
if value is None:
# Need to force instantiation so we can reliably
# tell if we get the same exception back
value = type()
try:
self.gen.throw(type, value, traceback)
raise RuntimeError("generator didn't stop after throw()")
except StopIteration, exc:
# Suppress the exception *unless* it's the same exception that
# was passed to throw(). This prevents a StopIteration
# raised inside the "with" statement from being suppressed
return exc is not value
except:
# only re-raise if it's *not* the exception that was
# passed to throw(), because __exit__() must not raise
# an exception unless __exit__() itself failed. But throw()
# has to raise the exception to signal propagation, so this
# fixes the impedance mismatch between the throw() protocol
# and the __exit__() protocol.
#
if sys.exc_info()[1] is not value:
raise
def contextmanager(func):
    """@contextmanager decorator.

    Typical usage:

        @contextmanager
        def some_generator(<arguments>):
            <setup>
            try:
                yield <value>
            finally:
                <cleanup>

    This makes this:

        with some_generator(<arguments>) as <variable>:
            <body>

    equivalent to this:

        <setup>
        try:
            <variable> = <value>
            <body>
        finally:
            <cleanup>

    """
    @wraps(func)
    def helper(*args, **kwds):
        return GeneratorContextManager(func(*args, **kwds))
    return helper
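# Illustrative sketch (not part of the module): a generator-based context
# manager in use. The names `tracked` and `events` are invented for the
# example; it shows that the cleanup step runs both on normal exit and
# when the body raises, and that the exception still propagates.

```python
from contextlib import contextmanager

events = []

@contextmanager
def tracked(name):
    events.append("setup " + name)
    try:
        yield name.upper()          # value bound by "as"
    finally:
        events.append("cleanup " + name)

# Normal exit: setup, body, cleanup.
with tracked("job") as label:
    events.append("body " + label)

# Exceptional exit: cleanup still runs, then the error reaches the caller.
try:
    with tracked("failing"):
        raise KeyError("boom")
except KeyError:
    events.append("propagated")
```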
@contextmanager
def nested(*managers):
"""Combine multiple context managers into a single nested context manager.
This function has been deprecated in favour of the multiple manager form
of the with statement.
The one advantage of this function over the multiple manager form of the
with statement is that argument unpacking allows it to be
used with a variable number of context managers as follows:
with nested(*managers):
do_something()
"""
warn("With-statements now directly support multiple context managers",
DeprecationWarning, 3)
exits = []
vars = []
exc = (None, None, None)
try:
for mgr in managers:
exit = mgr.__exit__
enter = mgr.__enter__
vars.append(enter())
exits.append(exit)
yield vars
except:
exc = sys.exc_info()
finally:
while exits:
exit = exits.pop()
try:
if exit(*exc):
exc = (None, None, None)
except:
exc = sys.exc_info()
if exc != (None, None, None):
# Don't rely on sys.exc_info() still containing
# the right information. Another exception may
# have been raised and caught by an exit method
raise exc[0], exc[1], exc[2]
class closing(object):
    """Context to automatically close something at the end of a block.

    Code like this:

        with closing(<module>.open(<arguments>)) as f:
            <block>

    is equivalent to this:

        f = <module>.open(<arguments>)
        try:
            <block>
        finally:
            f.close()

    """
    def __init__(self, thing):
        self.thing = thing
    def __enter__(self):
        return self.thing
    def __exit__(self, *exc_info):
        self.thing.close()
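# Illustrative sketch (not part of the module): closing() works with any
# object that has a close() method. The Connection class is a stand-in
# invented for the example.

```python
from contextlib import closing

class Connection(object):
    """Hypothetical resource that only knows how to close itself."""
    def __init__(self):
        self.is_open = True
    def close(self):
        self.is_open = False

conn = Connection()
try:
    with closing(conn) as same:
        bound_same_object = same is conn   # __enter__ returns the wrapped object
        raise RuntimeError("body failed")
except RuntimeError:
    pass  # close() has already run by the time the error gets here
```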
####
# Copyright 2000 by Timothy O'Malley <[email protected]>
#
# All Rights Reserved
#
# Permission to use, copy, modify, and distribute this software
# and its documentation for any purpose and without fee is hereby
# granted, provided that the above copyright notice appear in all
# copies and that both that copyright notice and this permission
# notice appear in supporting documentation, and that the name of
# Timothy O'Malley not be used in advertising or publicity
# pertaining to distribution of the software without specific, written
# prior permission.
#
# Timothy O'Malley DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS
# SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
# AND FITNESS, IN NO EVENT SHALL Timothy O'Malley BE LIABLE FOR
# ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
# PERFORMANCE OF THIS SOFTWARE.
#
####
#
# Id: Cookie.py,v 2.29 2000/08/23 05:28:49 timo Exp
# by Timothy O'Malley <[email protected]>
#
# Cookie.py is a Python module for the handling of HTTP
# cookies as a Python dictionary. See RFC 2109 for more
# information on cookies.
#
# The original idea to treat Cookies as a dictionary came from
# Dave Mitchell ([email protected]) in 1995, when he released the
# first version of nscookie.py.
#
####
r"""
Here's a sample session to show how to use this module.
At the moment, this is the only documentation.
The Basics
----------
Importing is easy.
>>> import Cookie
Most of the time you start by creating a cookie. Cookies come in
three flavors, each with slightly different encoding semantics, but
more on that later.
>>> C = Cookie.SimpleCookie()
>>> C = Cookie.SerialCookie()
>>> C = Cookie.SmartCookie()
[Note: Long-time users of Cookie.py will remember using
Cookie.Cookie() to create a Cookie object. Although deprecated, it
is still supported by the code. See the Backward Compatibility notes
for more information.]
Once you've created your Cookie, you can add values just as if it were
a dictionary.
>>> C = Cookie.SmartCookie()
>>> C["fig"] = "newton"
>>> C["sugar"] = "wafer"
>>> C.output()
'Set-Cookie: fig=newton\r\nSet-Cookie: sugar=wafer'
Notice that the printable representation of a Cookie is the
appropriate format for a Set-Cookie: header. This is the
default behavior. You can change the header and printed
attributes by using the .output() method.
>>> C = Cookie.SmartCookie()
>>> C["rocky"] = "road"
>>> C["rocky"]["path"] = "/cookie"
>>> print C.output(header="Cookie:")
Cookie: rocky=road; Path=/cookie
>>> print C.output(attrs=[], header="Cookie:")
Cookie: rocky=road
The load() method of a Cookie extracts cookies from a string. In a
CGI script, you would use this method to extract the cookies from the
HTTP_COOKIE environment variable.
>>> C = Cookie.SmartCookie()
>>> C.load("chips=ahoy; vienna=finger")
>>> C.output()
'Set-Cookie: chips=ahoy\r\nSet-Cookie: vienna=finger'
The load() method is darn-tootin smart about identifying cookies
within a string. Escaped quotation marks, nested semicolons, and other
such trickeries do not confuse it.
>>> C = Cookie.SmartCookie()
>>> C.load('keebler="E=everybody; L=\\"Loves\\"; fudge=\\012;";')
>>> print C
Set-Cookie: keebler="E=everybody; L=\"Loves\"; fudge=\012;"
Each element of the Cookie also supports all of the RFC 2109
Cookie attributes. Here's an example which sets the Path
attribute.
>>> C = Cookie.SmartCookie()
>>> C["oreo"] = "doublestuff"
>>> C["oreo"]["path"] = "/"
>>> print C
Set-Cookie: oreo=doublestuff; Path=/
Each dictionary element has a 'value' attribute, which gives you
back the value associated with the key.
>>> C = Cookie.SmartCookie()
>>> C["twix"] = "none for you"
>>> C["twix"].value
'none for you'
A Bit More Advanced
-------------------
As mentioned before, there are three different flavors of Cookie
objects, each with different encoding/decoding semantics. This
section briefly discusses the differences.
SimpleCookie
The SimpleCookie expects that all values should be standard strings.
Just to be sure, SimpleCookie invokes the str() builtin to convert
the value to a string, when the values are set dictionary-style.
>>> C = Cookie.SimpleCookie()
>>> C["number"] = 7
>>> C["string"] = "seven"
>>> C["number"].value
'7'
>>> C["string"].value
'seven'
>>> C.output()
'Set-Cookie: number=7\r\nSet-Cookie: string=seven'
SerialCookie
The SerialCookie expects that all values should be serialized using
cPickle (or pickle, if cPickle isn't available). As a result of
serializing, SerialCookie can save almost any Python object to a
value, and recover the exact same object when the cookie has been
returned. (SerialCookie can yield some strange-looking cookie
values, however.)
>>> C = Cookie.SerialCookie()
>>> C["number"] = 7
>>> C["string"] = "seven"
>>> C["number"].value
7
>>> C["string"].value
'seven'
>>> C.output()
'Set-Cookie: number="I7\\012."\r\nSet-Cookie: string="S\'seven\'\\012p1\\012."'
Be warned, however, if SerialCookie cannot de-serialize a value (because
it isn't a valid pickle'd object), IT WILL RAISE AN EXCEPTION.
SmartCookie
The SmartCookie combines aspects of each of the other two flavors.
When setting a value in a dictionary-fashion, the SmartCookie will
serialize (ala cPickle) the value *if and only if* it isn't a
Python string. String objects are *not* serialized. Similarly,
when the load() method parses out values, it attempts to de-serialize
the value. If it fails, then it falls back to treating the value
as a string.
>>> C = Cookie.SmartCookie()
>>> C["number"] = 7
>>> C["string"] = "seven"
>>> C["number"].value
7
>>> C["string"].value
'seven'
>>> C.output()
'Set-Cookie: number="I7\\012."\r\nSet-Cookie: string=seven'
Backwards Compatibility
-----------------------
In order to keep compatibility with earlier versions of Cookie.py,
it is still possible to use Cookie.Cookie() to create a Cookie. In
fact, this simply returns a SmartCookie.
>>> C = Cookie.Cookie()
>>> print C.__class__.__name__
SmartCookie
Finis.
""" #"
# ^
# |----helps out font-lock
#
# Import our required modules
#
import string
try:
from cPickle import dumps, loads
except ImportError:
from pickle import dumps, loads
import re, warnings
__all__ = ["CookieError","BaseCookie","SimpleCookie","SerialCookie",
"SmartCookie","Cookie"]
_nulljoin = ''.join
_semispacejoin = '; '.join
_spacejoin = ' '.join
#
# Define an exception visible to external modules
#
class CookieError(Exception):
pass
# These quoting routines conform to the RFC2109 specification, which in
# turn references the character definitions from RFC2068. They provide
# a two-way quoting algorithm. Any non-text character is translated
# into a 4 character sequence: a backslash followed by the
# three-digit octal equivalent of the character. Any '\' or '"' is
# quoted with a preceding backslash.
#
# These are taken from RFC2068 and RFC2109.
# _LegalChars is the list of chars which don't require "'s
# _Translator hash-table for fast quoting
#
_LegalChars = string.ascii_letters + string.digits + "!#$%&'*+-.^_`|~"
_Translator = {
'\000' : '\\000', '\001' : '\\001', '\002' : '\\002',
'\003' : '\\003', '\004' : '\\004', '\005' : '\\005',
'\006' : '\\006', '\007' : '\\007', '\010' : '\\010',
'\011' : '\\011', '\012' : '\\012', '\013' : '\\013',
'\014' : '\\014', '\015' : '\\015', '\016' : '\\016',
'\017' : '\\017', '\020' : '\\020', '\021' : '\\021',
'\022' : '\\022', '\023' : '\\023', '\024' : '\\024',
'\025' : '\\025', '\026' : '\\026', '\027' : '\\027',
'\030' : '\\030', '\031' : '\\031', '\032' : '\\032',
'\033' : '\\033', '\034' : '\\034', '\035' : '\\035',
'\036' : '\\036', '\037' : '\\037',
# Because of the way browsers really handle cookies (as opposed
# to what the RFC says) we also encode , and ;
',' : '\\054', ';' : '\\073',
'"' : '\\"', '\\' : '\\\\',
'\177' : '\\177', '\200' : '\\200', '\201' : '\\201',
'\202' : '\\202', '\203' : '\\203', '\204' : '\\204',
'\205' : '\\205', '\206' : '\\206', '\207' : '\\207',
'\210' : '\\210', '\211' : '\\211', '\212' : '\\212',
'\213' : '\\213', '\214' : '\\214', '\215' : '\\215',
'\216' : '\\216', '\217' : '\\217', '\220' : '\\220',
'\221' : '\\221', '\222' : '\\222', '\223' : '\\223',
'\224' : '\\224', '\225' : '\\225', '\226' : '\\226',
'\227' : '\\227', '\230' : '\\230', '\231' : '\\231',
'\232' : '\\232', '\233' : '\\233', '\234' : '\\234',
'\235' : '\\235', '\236' : '\\236', '\237' : '\\237',
'\240' : '\\240', '\241' : '\\241', '\242' : '\\242',
'\243' : '\\243', '\244' : '\\244', '\245' : '\\245',
'\246' : '\\246', '\247' : '\\247', '\250' : '\\250',
'\251' : '\\251', '\252' : '\\252', '\253' : '\\253',
'\254' : '\\254', '\255' : '\\255', '\256' : '\\256',
'\257' : '\\257', '\260' : '\\260', '\261' : '\\261',
'\262' : '\\262', '\263' : '\\263', '\264' : '\\264',
'\265' : '\\265', '\266' : '\\266', '\267' : '\\267',
'\270' : '\\270', '\271' : '\\271', '\272' : '\\272',
'\273' : '\\273', '\274' : '\\274', '\275' : '\\275',
'\276' : '\\276', '\277' : '\\277', '\300' : '\\300',
'\301' : '\\301', '\302' : '\\302', '\303' : '\\303',
'\304' : '\\304', '\305' : '\\305', '\306' : '\\306',
'\307' : '\\307', '\310' : '\\310', '\311' : '\\311',
'\312' : '\\312', '\313' : '\\313', '\314' : '\\314',
'\315' : '\\315', '\316' : '\\316', '\317' : '\\317',
'\320' : '\\320', '\321' : '\\321', '\322' : '\\322',
'\323' : '\\323', '\324' : '\\324', '\325' : '\\325',
'\326' : '\\326', '\327' : '\\327', '\330' : '\\330',
'\331' : '\\331', '\332' : '\\332', '\333' : '\\333',
'\334' : '\\334', '\335' : '\\335', '\336' : '\\336',
'\337' : '\\337', '\340' : '\\340', '\341' : '\\341',
'\342' : '\\342', '\343' : '\\343', '\344' : '\\344',
'\345' : '\\345', '\346' : '\\346', '\347' : '\\347',
'\350' : '\\350', '\351' : '\\351', '\352' : '\\352',
'\353' : '\\353', '\354' : '\\354', '\355' : '\\355',
'\356' : '\\356', '\357' : '\\357', '\360' : '\\360',
'\361' : '\\361', '\362' : '\\362', '\363' : '\\363',
'\364' : '\\364', '\365' : '\\365', '\366' : '\\366',
'\367' : '\\367', '\370' : '\\370', '\371' : '\\371',
'\372' : '\\372', '\373' : '\\373', '\374' : '\\374',
'\375' : '\\375', '\376' : '\\376', '\377' : '\\377'
}
_idmap = ''.join(chr(x) for x in xrange(256))
def _quote(str, LegalChars=_LegalChars,
idmap=_idmap, translate=string.translate):
#
# If the string does not need to be double-quoted,
# then just return the string. Otherwise, surround
# the string in double quotes and escape special
# characters with a preceding backslash.
#
if "" == translate(str, idmap, LegalChars):
return str
else:
return '"' + _nulljoin( map(_Translator.get, str, str) ) + '"'
# end _quote
_OctalPatt = re.compile(r"\\[0-3][0-7][0-7]")
_QuotePatt = re.compile(r"[\\].")
def _unquote(str):
# If there aren't any doublequotes,
# then there can't be any special characters. See RFC 2109.
if len(str) < 2:
return str
if str[0] != '"' or str[-1] != '"':
return str
# We have to assume that we must decode this string.
# Down to work.
# Remove the "s
str = str[1:-1]
# Check for special sequences. Examples:
# \012 --> \n
# \" --> "
#
i = 0
n = len(str)
res = []
while 0 <= i < n:
Omatch = _OctalPatt.search(str, i)
Qmatch = _QuotePatt.search(str, i)
if not Omatch and not Qmatch: # Neither matched
res.append(str[i:])
break
# else:
j = k = -1
if Omatch: j = Omatch.start(0)
if Qmatch: k = Qmatch.start(0)
if Qmatch and ( not Omatch or k < j ): # QuotePatt matched
res.append(str[i:k])
res.append(str[k+1])
i = k+2
else: # OctalPatt matched
res.append(str[i:j])
res.append( chr( int(str[j+1:j+4], 8) ) )
i = j+4
return _nulljoin(res)
# end _unquote
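# Illustrative sketch (not part of the module): the quoting and unquoting
# routines above are private, but their effect is visible through
# SimpleCookie. A value containing ';', a space and double quotes is
# escaped on the way out and restored on the way in. The key name "k" is
# arbitrary, and the import falls back so the lines also run where the
# module is named http.cookies.

```python
try:
    from http.cookies import SimpleCookie  # Python 3 name
except ImportError:
    from Cookie import SimpleCookie        # Python 2 name

original = 'a;b "c"'

sender = SimpleCookie()
sender["k"] = original
coded = sender["k"].coded_value   # quoted, escaped form sent over the wire

receiver = SimpleCookie()
receiver.load("k=" + coded)
decoded = receiver["k"].value     # unquoted back to the original text
```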
# The _getdate() routine is used to set the expiration time in
# the cookie's HTTP header. By default, _getdate() returns the
# current time in the appropriate "expires" format for a
# Set-Cookie header. The one optional argument is an offset from
# now, in seconds. For example, an offset of -3600 means "one hour ago".
# The offset may be a floating point number.
#
_weekdayname = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
_monthname = [None,
'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
def _getdate(future=0, weekdayname=_weekdayname, monthname=_monthname):
    from time import gmtime, time
    now = time()
    year, month, day, hh, mm, ss, wd, y, z = gmtime(now + future)
    return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % \
           (weekdayname[wd], day, monthname[month], year, hh, mm, ss)
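# Illustrative sketch (not part of the module): _getdate() formats
# "now plus an offset", which makes its output time-dependent. The
# hypothetical helper below applies the same format string to a fixed
# timestamp so the layout of the "expires" value is easy to see.

```python
from time import gmtime

WEEKDAYS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
MONTHS = [None,
          'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
          'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

def format_expires(epoch_seconds):
    """Format an absolute UTC timestamp the way _getdate() formats now+offset."""
    year, month, day, hh, mm, ss, wd, _yday, _dst = gmtime(epoch_seconds)
    return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        WEEKDAYS[wd], day, MONTHS[month], year, hh, mm, ss)

epoch_text = format_expires(0)
sample_text = format_expires(784111777)
```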
#
# A class to hold ONE key,value pair.
# In a cookie, each such pair may have several attributes.
# so this class is used to keep the attributes associated
# with the appropriate key,value pair.
# This class also includes a coded_value attribute, which
# is used to hold the network representation of the
# value. This is most useful when Python objects are
# pickled for network transit.
#
class Morsel(dict):
# RFC 2109 lists these attributes as reserved:
# path comment domain
# max-age secure version
#
# For historical reasons, these attributes are also reserved:
# expires
#
# This is an extension from Microsoft:
# httponly
#
# This dictionary provides a mapping from the lowercase
# variant on the left to the appropriate traditional
# formatting on the right.
_reserved = { "expires" : "expires",
"path" : "Path",
"comment" : "Comment",
"domain" : "Domain",
"max-age" : "Max-Age",
"secure" : "secure",
"httponly" : "httponly",
"version" : "Version",
}
_flags = {'secure', 'httponly'}
def __init__(self):
# Set defaults
self.key = self.value = self.coded_value = None
# Set default attributes
for K in self._reserved:
dict.__setitem__(self, K, "")
# end __init__
def __setitem__(self, K, V):
K = K.lower()
if not K in self._reserved:
raise CookieError("Invalid Attribute %s" % K)
dict.__setitem__(self, K, V)
# end __setitem__
def isReservedKey(self, K):
return K.lower() in self._reserved
# end isReservedKey
def set(self, key, val, coded_val,
LegalChars=_LegalChars,
idmap=_idmap, translate=string.translate):
# First we verify that the key isn't a reserved word
# Second we make sure it only contains legal characters
if key.lower() in self._reserved:
raise CookieError("Attempt to set a reserved key: %s" % key)
if "" != translate(key, idmap, LegalChars):
raise CookieError("Illegal key value: %s" % key)
# It's a good key, so save it.
self.key = key
self.value = val
self.coded_value = coded_val
# end set
def output(self, attrs=None, header = "Set-Cookie:"):
return "%s %s" % ( header, self.OutputString(attrs) )
__str__ = output
def __repr__(self):
return '<%s: %s=%s>' % (self.__class__.__name__,
self.key, repr(self.value) )
def js_output(self, attrs=None):
# Print javascript
return """
<script type="text/javascript">
<!-- begin hiding
document.cookie = \"%s\";
// end hiding -->
</script>
""" % ( self.OutputString(attrs).replace('"',r'\"'), )
# end js_output()
def OutputString(self, attrs=None):
# Build up our result
#
result = []
RA = result.append
# First, the key=value pair
RA("%s=%s" % (self.key, self.coded_value))
# Now add any defined attributes
if attrs is None:
attrs = self._reserved
items = self.items()
items.sort()
for K,V in items:
if V == "": continue
if K not in attrs: continue
if K == "expires" and type(V) == type(1):
RA("%s=%s" % (self._reserved[K], _getdate(V)))
elif K == "max-age" and type(V) == type(1):
RA("%s=%d" % (self._reserved[K], V))
elif K == "secure":
RA(str(self._reserved[K]))
elif K == "httponly":
RA(str(self._reserved[K]))
else:
RA("%s=%s" % (self._reserved[K], V))
# Return the result
return _semispacejoin(result)
# end OutputString
# end Morsel class
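# Illustrative sketch (not part of the module): setting cookie-attributes
# on a Morsel. Only reserved attribute names are accepted; anything else
# raises CookieError. The cookie name and values are invented for the
# example.

```python
try:
    from http.cookies import SimpleCookie, CookieError  # Python 3 name
except ImportError:
    from Cookie import SimpleCookie, CookieError        # Python 2 name

C = SimpleCookie()
C["session"] = "abc123"

morsel = C["session"]
morsel["path"] = "/app"       # reserved attribute, rendered as Path=
morsel["max-age"] = 3600      # integer values are rendered with %d
header = morsel.output()

# "flavor" is not a reserved attribute, so the assignment is refused.
try:
    morsel["flavor"] = "mint"
    rejected = False
except CookieError:
    rejected = True
```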
#
# Pattern for finding cookie
#
# This used to be strict parsing based on the RFC2109 and RFC2068
# specifications. I have since discovered that MSIE 3.0x doesn't
# follow the character rules outlined in those specs. As a
# result, the parsing rules here are less strict.
#
_LegalKeyChars = r"\w\d!#%&'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\="
_LegalValueChars = _LegalKeyChars + r"\[\]"
_CookiePattern = re.compile(
r"(?x)" # This is a Verbose pattern
r"\s*" # Optional whitespace at start of cookie
r"(?P<key>" # Start of group 'key'
"["+ _LegalKeyChars +"]+?" # Any word of at least one letter, nongreedy
r")" # End of group 'key'
r"(" # Optional group: there may not be a value.
r"\s*=\s*" # Equal Sign
r"(?P<val>" # Start of group 'val'
r'"(?:[^\\"]|\\.)*"' # Any doublequoted string
r"|" # or
r"\w{3},\s[\s\w\d-]{9,11}\s[\d:]{8}\sGMT" # Special case for "expires" attr
r"|" # or
"["+ _LegalValueChars +"]*" # Any word or empty string
r")" # End of group 'val'
r")?" # End of optional value group
r"\s*" # Any number of spaces.
r"(\s+|;|$)" # Ending either at space, semicolon, or EOS.
)
# At long last, here is the cookie class.
# Using this class is almost just like using a dictionary.
# See this module's docstring for example usage.
#
class BaseCookie(dict):
# A container class for a set of Morsels
#
def value_decode(self, val):
"""real_value, coded_value = value_decode(STRING)
Called prior to setting a cookie's value from the network
representation. The VALUE is the value read from HTTP
header.
Override this function to modify the behavior of cookies.
"""
return val, val
# end value_decode
def value_encode(self, val):
"""real_value, coded_value = value_encode(VALUE)
Called prior to setting a cookie's value from the dictionary
representation. The VALUE is the value being assigned.
Override this function to modify the behavior of cookies.
"""
strval = str(val)
return strval, strval
# end value_encode
def __init__(self, input=None):
if input: self.load(input)
# end __init__
def __set(self, key, real_value, coded_value):
"""Private method for setting a cookie's value"""
M = self.get(key, Morsel())
M.set(key, real_value, coded_value)
dict.__setitem__(self, key, M)
# end __set
def __setitem__(self, key, value):
"""Dictionary style assignment."""
if isinstance(value, Morsel):
# allow assignment of constructed Morsels (e.g. for pickling)
dict.__setitem__(self, key, value)
else:
rval, cval = self.value_encode(value)
self.__set(key, rval, cval)
# end __setitem__
def output(self, attrs=None, header="Set-Cookie:", sep="\015\012"):
"""Return a string suitable for HTTP."""
result = []
items = self.items()
items.sort()
for K,V in items:
result.append( V.output(attrs, header) )
return sep.join(result)
# end output
__str__ = output
def __repr__(self):
L = []
items = self.items()
items.sort()
for K,V in items:
L.append( '%s=%s' % (K,repr(V.value) ) )
return '<%s: %s>' % (self.__class__.__name__, _spacejoin(L))
def js_output(self, attrs=None):
"""Return a string suitable for JavaScript."""
result = []
items = self.items()
items.sort()
for K,V in items:
result.append( V.js_output(attrs) )
return _nulljoin(result)
# end js_output
def load(self, rawdata):
"""Load cookies from a string (presumably HTTP_COOKIE) or
from a dictionary. Loading cookies from a dictionary 'd'
is equivalent to calling:
map(Cookie.__setitem__, d.keys(), d.values())
"""
if type(rawdata) == type(""):
self.__ParseString(rawdata)
else:
# self.update() wouldn't call our custom __setitem__
for k, v in rawdata.items():
self[k] = v
return
# end load()
def __ParseString(self, str, patt=_CookiePattern):
i = 0 # Our starting point
n = len(str) # Length of string
M = None # current morsel
while 0 <= i < n:
# Start looking for a cookie
match = patt.match(str, i)
if not match: break # No more cookies
K,V = match.group("key"), match.group("val")
i = match.end(0)
# Parse the key, value in case it's metainfo
if K[0] == "$":
# We ignore attributes which pertain to the cookie
# mechanism as a whole. See RFC 2109.
# (Does anyone care?)
if M:
M[ K[1:] ] = V
elif K.lower() in Morsel._reserved:
if M:
if V is None:
if K.lower() in Morsel._flags:
M[K] = True
else:
M[K] = _unquote(V)
elif V is not None:
rval, cval = self.value_decode(V)
self.__set(K, rval, cval)
M = self[K]
# end __ParseString
# end BaseCookie class
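# Illustrative sketch (not part of the module): load() accepts either a
# header-style string or a dictionary, and a reserved name that follows a
# cookie is stored as an attribute of that cookie rather than as a new
# cookie. The cookie names are invented for the example.

```python
try:
    from http.cookies import SimpleCookie  # Python 3 name
except ImportError:
    from Cookie import SimpleCookie        # Python 2 name

jar = SimpleCookie()
jar.load("chips=ahoy; vienna=finger")   # string form, e.g. HTTP_COOKIE
jar.load({"oreo": "doublestuff"})       # dictionary form

names = sorted(jar.keys())
values = dict((name, jar[name].value) for name in names)

# "Path" is reserved, so it attaches to the preceding cookie "id".
with_attr = SimpleCookie()
with_attr.load("id=7; Path=/x")
attr_names = sorted(with_attr.keys())
attr_path = with_attr["id"]["path"]
```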
class SimpleCookie(BaseCookie):
"""SimpleCookie
SimpleCookie supports strings as cookie values. When setting
the value using the dictionary assignment notation, SimpleCookie
calls the builtin str() to convert the value to a string. Values
received from HTTP are kept as strings.
"""
def value_decode(self, val):
return _unquote( val ), val
def value_encode(self, val):
strval = str(val)
return strval, _quote( strval )
# end SimpleCookie
class SerialCookie(BaseCookie):
"""SerialCookie
SerialCookie supports arbitrary objects as cookie values. All
values are serialized (using cPickle) before being sent to the
client. All incoming values are assumed to be valid Pickle
representations. IF AN INCOMING VALUE IS NOT IN A VALID PICKLE
FORMAT, THEN AN EXCEPTION WILL BE RAISED.
Note: Large cookie values add overhead because they must be
retransmitted on every HTTP transaction.
Note: HTTP has a 2k limit on the size of a cookie. This class
does not check for this limit, so be careful!!!
"""
def __init__(self, input=None):
warnings.warn("SerialCookie class is insecure; do not use it",
DeprecationWarning)
BaseCookie.__init__(self, input)
# end __init__
def value_decode(self, val):
# This could raise an exception!
return loads( _unquote(val) ), val
def value_encode(self, val):
return val, _quote( dumps(val) )
# end SerialCookie
class SmartCookie(BaseCookie):
"""SmartCookie
SmartCookie supports arbitrary objects as cookie values. If the
object is a string, then it is quoted. If the object is not a
string, however, then SmartCookie will use cPickle to serialize
the object into a string representation.
Note: Large cookie values add overhead because they must be
retransmitted on every HTTP transaction.
Note: HTTP has a 2k limit on the size of a cookie. This class
does not check for this limit, so be careful!!!
"""
def __init__(self, input=None):
warnings.warn("Cookie/SmartCookie class is insecure; do not use it",
DeprecationWarning)
BaseCookie.__init__(self, input)
# end __init__
def value_decode(self, val):
strval = _unquote(val)
try:
return loads(strval), val
except:
return strval, val
def value_encode(self, val):
if type(val) == type(""):
return val, _quote(val)
else:
return val, _quote( dumps(val) )
# end SmartCookie
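# Illustrative sketch (not part of the module): SerialCookie and
# SmartCookie warn that they are insecure because they unpickle data
# received from the client. One possible alternative, assuming the values
# are JSON-compatible, is to override value_encode()/value_decode() as
# the BaseCookie docstrings suggest. JSONCookie is a hypothetical class
# invented for this example.

```python
import json

try:
    from http.cookies import SimpleCookie  # Python 3 name
except ImportError:
    from Cookie import SimpleCookie        # Python 2 name


class JSONCookie(SimpleCookie):
    """Hypothetical subclass: values travel as quoted JSON text, not pickles."""

    def value_encode(self, val):
        # Reuse SimpleCookie's quoting for the JSON text.
        _, coded = SimpleCookie.value_encode(self, json.dumps(val))
        return val, coded

    def value_decode(self, val):
        # Undo the quoting first, then parse; malformed JSON raises ValueError.
        text, coded = SimpleCookie.value_decode(self, val)
        return json.loads(text), coded


sent = JSONCookie()
sent["prefs"] = {"n": 7, "tags": ["a", "b"]}
wire = sent["prefs"].OutputString()   # the part that follows "Set-Cookie: "

received = JSONCookie()
received.load(wire)
prefs = received["prefs"].value
```

# Unlike SmartCookie, this sketch does not fall back to a plain string
# when decoding fails, so callers need to handle ValueError.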
###########################################################
# Backwards Compatibility: Don't break any existing code!
# We provide Cookie() as an alias for SmartCookie()
Cookie = SmartCookie
#
###########################################################
def _test():
import doctest, Cookie
return doctest.testmod(Cookie)
if __name__ == "__main__":
_test()
#Local Variables:
#tab-width: 4
#end:
r"""HTTP cookie handling for web clients.
This module has (now fairly distant) origins in Gisle Aas' Perl module
HTTP::Cookies, from the libwww-perl library.
Docstrings, comments and debug strings in this code refer to the
attributes of the HTTP cookie system as cookie-attributes, to distinguish
them clearly from Python attributes.
Class diagram (note that BSDDBCookieJar and the MSIE* classes are not
distributed with the Python standard library, but are available from
http://wwwsearch.sf.net/):
                        CookieJar____
                        /     \      \
            FileCookieJar      \      \
             /    |   \         \      \
 MozillaCookieJar | LWPCookieJar \      \
                  |               |      \
                  |   ---MSIEBase |       \
                  |  /      |     |        \
                  | /   MSIEDBCookieJar BSDDBCookieJar
                  |/
               MSIECookieJar
"""
__all__ = ['Cookie', 'CookieJar', 'CookiePolicy', 'DefaultCookiePolicy',
'FileCookieJar', 'LWPCookieJar', 'lwp_cookie_str', 'LoadError',
'MozillaCookieJar']
import re, urlparse, copy, time, urllib
try:
import threading as _threading
except ImportError:
import dummy_threading as _threading
import httplib # only for the default HTTP port
from calendar import timegm
debug = False # set to True to enable debugging via the logging module
logger = None
def _debug(*args):
if not debug:
return
global logger
if not logger:
import logging
logger = logging.getLogger("cookielib")
return logger.debug(*args)
DEFAULT_HTTP_PORT = str(httplib.HTTP_PORT)
MISSING_FILENAME_TEXT = ("a filename was not supplied (nor was the CookieJar "
"instance initialised with one)")
def _warn_unhandled_exception():
# There are a few catch-all except: statements in this module, for
# catching input that's bad in unexpected ways. Warn if any
# exceptions are caught there.
import warnings, traceback, StringIO
f = StringIO.StringIO()
traceback.print_exc(None, f)
msg = f.getvalue()
warnings.warn("cookielib bug!\n%s" % msg, stacklevel=2)
# Date/time conversion
# -----------------------------------------------------------------------------
EPOCH_YEAR = 1970
def _timegm(tt):
    year, month, mday, hour, min, sec = tt[:6]
    if ((year >= EPOCH_YEAR) and (1 <= month <= 12) and (1 <= mday <= 31) and
        (0 <= hour <= 24) and (0 <= min <= 59) and (0 <= sec <= 61)):
        return timegm(tt)
    else:
        return None
DAYS = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
MONTHS = ["Jan", "Feb", "Mar", "Apr", "May", "Jun",
"Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]
MONTHS_LOWER = []
for month in MONTHS: MONTHS_LOWER.append(month.lower())
def time2isoz(t=None):
"""Return a string representing time in seconds since epoch, t.
If the function is called without an argument, it will use the current
time.
The format of the returned string is like "YYYY-MM-DD hh:mm:ssZ",
representing Universal Time (UTC, aka GMT). An example of this format is:
1994-11-24 08:49:37Z
"""
if t is None: t = time.time()
year, mon, mday, hour, min, sec = time.gmtime(t)[:6]
return "%04d-%02d-%02d %02d:%02d:%02dZ" % (
year, mon, mday, hour, min, sec)
def time2netscape(t=None):
"""Return a string representing time in seconds since epoch, t.
If the function is called without an argument, it will use the current
time.
The format of the returned string is like this:
Wed, DD-Mon-YYYY HH:MM:SS GMT
"""
if t is None: t = time.time()
year, mon, mday, hour, min, sec, wday = time.gmtime(t)[:7]
return "%s, %02d-%s-%04d %02d:%02d:%02d GMT" % (
DAYS[wday], mday, MONTHS[mon-1], year, hour, min, sec)
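# Illustrative sketch (not part of the module): the two formatters above
# applied to one fixed timestamp. The timestamp is chosen for the example,
# and the import falls back so the lines also run where the module is
# named http.cookiejar.

```python
try:
    from http.cookiejar import time2isoz, time2netscape  # Python 3 name
except ImportError:
    from cookielib import time2isoz, time2netscape       # Python 2 name

stamp = 784111777  # seconds since the epoch: 1994-11-06 08:49:37 UTC

iso_text = time2isoz(stamp)            # "YYYY-MM-DD hh:mm:ssZ" layout
netscape_text = time2netscape(stamp)   # weekday-first layout ending in GMT
```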
UTC_ZONES = {"GMT": None, "UTC": None, "UT": None, "Z": None}
TIMEZONE_RE = re.compile(r"^([-+])?(\d\d?):?(\d\d)?$")
def offset_from_tz_string(tz):
    offset = None
    if tz in UTC_ZONES:
        offset = 0
    else:
        m = TIMEZONE_RE.search(tz)
        if m:
            offset = 3600 * int(m.group(2))
            if m.group(3):
                offset = offset + 60 * int(m.group(3))
            if m.group(1) == '-':
                offset = -offset
    return offset
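# Illustrative sketch (not part of the module): what
# offset_from_tz_string() returns for a few inputs. Numeric offsets are
# converted to seconds, the UTC aliases map to zero, and any other zone
# name is unknown and yields None. The import falls back so the lines
# also run where the module is named http.cookiejar.

```python
try:
    from http.cookiejar import offset_from_tz_string  # Python 3 name
except ImportError:
    from cookielib import offset_from_tz_string       # Python 2 name

utc = offset_from_tz_string("UTC")        # alias for zero offset
west = offset_from_tz_string("-0800")     # 8 hours behind UTC, in seconds
east = offset_from_tz_string("+0100")     # 1 hour ahead of UTC, in seconds
half = offset_from_tz_string("+05:30")    # the colon is optional
named = offset_from_tz_string("EST")      # not a UTC alias, so unknown
```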
def _str2time(day, mon, yr, hr, min, sec, tz):
# translate month name to number
# month numbers start with 1 (January)
try:
mon = MONTHS_LOWER.index(mon.lower())+1
except ValueError:
# maybe it's already a number
try:
imon = int(mon)
except ValueError:
return None
if 1 <= imon <= 12:
mon = imon
else:
return None
# make sure clock elements are defined
if hr is None: hr = 0
if min is None: min = 0
if sec is None: sec = 0
yr = int(yr)
day = int(day)
hr = int(hr)
min = int(min)
sec = int(sec)
if yr < 1000:
# find "obvious" year
cur_yr = time.localtime(time.time())[0]
m = cur_yr % 100
tmp = yr
yr = yr + cur_yr - m
m = m - tmp
if abs(m) > 50:
if m > 0: yr = yr + 100
else: yr = yr - 100
# convert UTC time tuple to seconds since epoch (not timezone-adjusted)
t = _timegm((yr, mon, day, hr, min, sec, tz))
if t is not None:
# adjust time using timezone string, to get absolute time since epoch
if tz is None:
tz = "UTC"
tz = tz.upper()
offset = offset_from_tz_string(tz)
if offset is None:
return None
t = t - offset
return t
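# The "obvious year" branch of _str2time (yr < 1000) can be isolated as below.
# This is a sketch only: expand_short_year is a hypothetical helper, not part
# of the module, and it takes the current year as an explicit argument so the
# result does not depend on the clock.

```python
def expand_short_year(yr, cur_yr):
    """Pick the century that puts a short year closest to cur_yr."""
    m = cur_yr % 100
    tmp = yr
    yr = yr + cur_yr - m       # place yr in the current century first
    m = m - tmp
    if abs(m) > 50:            # more than 50 years off: shift one century
        if m > 0:
            yr = yr + 100
        else:
            yr = yr - 100
    return yr


print(expand_short_year(94, 2024))  # 1994
print(expand_short_year(5, 2024))   # 2005
print(expand_short_year(10, 1995))  # 2010
```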
STRICT_DATE_RE = re.compile(
r"^[SMTWF][a-z][a-z], (\d\d) ([JFMASOND][a-z][a-z]) "
r"(\d\d\d\d) (\d\d):(\d\d):(\d\d) GMT$")
WEEKDAY_RE = re.compile(
r"^(?:Sun|Mon|Tue|Wed|Thu|Fri|Sat)[a-z]*,?\s*", re.I)
LOOSE_HTTP_DATE_RE = re.compile(
r"""^
(\d\d?) # day
(?:\s+|[-\/])
(\w+) # month
(?:\s+|[-\/])
(\d+) # year
(?:
(?:\s+|:) # separator before clock
(\d\d?):(\d\d) # hour:min
(?::(\d\d))? # optional seconds
)? # optional clock
\s*
(?:
([-+]?\d{2,4}|(?![APap][Mm]\b)[A-Za-z]+) # timezone
\s*
)?
(?:
\(\w+\) # ASCII representation of timezone in parens.
\s*
)?$""", re.X)
def http2time(text):
"""Returns time in seconds since epoch of time represented by a string.
Return value is an integer.
None is returned if the format of text is unrecognized, the time is outside
the representable range, or the timezone string is not recognized. If the
string contains no timezone, UTC is assumed.
The timezone in the string may be numerical (like "-0800" or "+0100") or a
string timezone (like "UTC", "GMT", "BST" or "EST"). Currently, only the
timezone strings equivalent to UTC (zero offset) are known to the function.
The function loosely parses the following formats:
Wed, 09 Feb 1994 22:23:32 GMT -- HTTP format
Tuesday, 08-Feb-94 14:15:29 GMT -- old rfc850 HTTP format
Tuesday, 08-Feb-1994 14:15:29 GMT -- broken rfc850 HTTP format
09 Feb 1994 22:23:32 GMT -- HTTP format (no weekday)
08-Feb-94 14:15:29 GMT -- rfc850 format (no weekday)
08-Feb-1994 14:15:29 GMT -- broken rfc850 format (no weekday)
The parser ignores leading and trailing whitespace. The time may be
absent.
If the year is given with only 2 digits, the function will select the
century that makes the year closest to the current date.
"""
# fast exit for strictly conforming string
m = STRICT_DATE_RE.search(text)
if m:
g = m.groups()
mon = MONTHS_LOWER.index(g[1].lower()) + 1
tt = (int(g[2]), mon, int(g[0]),
int(g[3]), int(g[4]), float(g[5]))
return _timegm(tt)
# No, we need some messy parsing...
# clean up
text = text.lstrip()
text = WEEKDAY_RE.sub("", text, 1) # Useless weekday
# tz is time zone specifier string
day, mon, yr, hr, min, sec, tz = [None]*7
# loose regexp parse
m = LOOSE_HTTP_DATE_RE.search(text)
if m is not None:
day, mon, yr, hr, min, sec, tz = m.groups()
else:
return None # bad format
return _str2time(day, mon, yr, hr, min, sec, tz)
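# A usage sketch for http2time, assuming Python 3's http.cookiejar (which
# carries the same function). The reference value is computed with
# calendar.timegm instead of being hard-coded; the date strings are examples.

```python
import calendar
from http.cookiejar import http2time

expected = calendar.timegm((1994, 2, 9, 22, 23, 32))

strict = http2time("Wed, 09 Feb 1994 22:23:32 GMT")    # strict fast path
no_weekday = http2time("09 Feb 1994 22:23:32 GMT")     # loose regexp parse
with_offset = http2time("09 Feb 1994 14:23:32 -0800")  # numeric zone
unknown_zone = http2time("09 Feb 1994 22:23:32 EST")   # unrecognised zone name

print(strict == expected, no_weekday == expected, with_offset == expected)
print(unknown_zone)
```

# The last call returns None because only the UTC-equivalent zone names are
# known to the parser.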
ISO_DATE_RE = re.compile(
r"""^
(\d{4}) # year
[-\/]?
(\d\d?) # numerical month
[-\/]?
(\d\d?) # day
(?:
(?:\s+|[-:Tt]) # separator before clock
(\d\d?):?(\d\d) # hour:min
(?::?(\d\d(?:\.\d*)?))? # optional seconds (and fractional)
)? # optional clock
\s*
(?:
([-+]?\d\d?:?(:?\d\d)?
|Z|z) # timezone (Z is "zero meridian", i.e. GMT)
\s*
)?$""", re.X)
def iso2time(text):
"""
As for http2time, but parses the ISO 8601 formats:
1994-02-03 14:15:29 -0100 -- ISO 8601 format
1994-02-03 14:15:29 -- zone is optional
1994-02-03 -- only date
1994-02-03T14:15:29 -- Use T as separator
19940203T141529Z -- ISO 8601 compact format
19940203 -- only date
"""
# clean up
text = text.lstrip()
# tz is time zone specifier string
day, mon, yr, hr, min, sec, tz = [None]*7
# loose regexp parse
m = ISO_DATE_RE.search(text)
if m is not None:
# XXX there's an extra bit of the timezone I'm ignoring here: is
# this the right thing to do?
yr, mon, day, hr, min, sec, tz, _ = m.groups()
else:
return None # bad format
return _str2time(day, mon, yr, hr, min, sec, tz)
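# A usage sketch for iso2time, again assuming Python 3's http.cookiejar. It
# compares several spellings of the same instant instead of quoting absolute
# values; the sample strings come from the docstring above.

```python
from http.cookiejar import iso2time

base = iso2time("1994-02-03 14:15:29")           # no zone: UTC is assumed
compact = iso2time("19940203T141529Z")           # compact form, explicit Z
shifted = iso2time("1994-02-03 14:15:29 -0100")  # one hour west of UTC
date_only = iso2time("1994-02-03")               # clock defaults to 00:00:00

print(compact == base)
print(shifted - base)      # seconds between the two instants
print(base - date_only)    # seconds elapsed since midnight
```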
# Header parsing
# -----------------------------------------------------------------------------
def unmatched(match):
"""Return unmatched part of re.Match object."""
start, end = match.span(0)
return match.string[:start]+match.string[end:]
HEADER_TOKEN_RE = re.compile(r"^\s*([^=\s;,]+)")
HEADER_QUOTED_VALUE_RE = re.compile(r"^\s*=\s*\"([^\"\\]*(?:\\.[^\"\\]*)*)\"")
HEADER_VALUE_RE = re.compile(r"^\s*=\s*([^\s;,]*)")
HEADER_ESCAPE_RE = re.compile(r"\\(.)")
def split_header_words(header_values):
r"""Parse header values into a list of lists containing key,value pairs.
The function knows how to deal with ",", ";" and "=" as well as quoted
values after "=". A list of space-separated tokens is parsed as if the tokens
were separated by ";".
If the header_values passed as argument contains multiple values, then they
are treated as if they were a single value separated by comma ",".
This means that this function is useful for parsing header fields that
follow this syntax (BNF as from the HTTP/1.1 specification, but we relax
the requirement for tokens).
headers = #header
header = (token | parameter) *( [";"] (token | parameter))
token = 1*<any CHAR except CTLs or separators>
separators = "(" | ")" | "<" | ">" | "@"
| "," | ";" | ":" | "\" | <">
| "/" | "[" | "]" | "?" | "="
| "{" | "}" | SP | HT
quoted-string = ( <"> *(qdtext | quoted-pair ) <"> )
qdtext = <any TEXT except <">>
quoted-pair = "\" CHAR
parameter = attribute "=" value
attribute = token
value = token | quoted-string
Each header is represented by a list of key/value pairs. The value for a
simple token (not part of a parameter) is None. Syntactically incorrect
headers will not necessarily be parsed as you would want.
This is easier to describe with some examples:
>>> split_header_words(['foo="bar"; port="80,81"; discard, bar=baz'])
[[('foo', 'bar'), ('port', '80,81'), ('discard', None)], [('bar', 'baz')]]
>>> split_header_words(['text/html; charset="iso-8859-1"'])
[[('text/html', None), ('charset', 'iso-8859-1')]]
>>> split_header_words([r'Basic realm="\"foo\bar\""'])
[[('Basic', None), ('realm', '"foobar"')]]
"""
assert not isinstance(header_values, basestring)
result = []
for text in header_values:
orig_text = text
pairs = []
while text:
m = HEADER_TOKEN_RE.search(text)
if m:
text = unmatched(m)
name = m.group(1)
m = HEADER_QUOTED_VALUE_RE.search(text)
if m: # quoted value
text = unmatched(m)
value = m.group(1)
value = HEADER_ESCAPE_RE.sub(r"\1", value)
else:
m = HEADER_VALUE_RE.search(text)
if m: # unquoted value
text = unmatched(m)
value = m.group(1)
value = value.rstrip()
else:
# no value, a lone token
value = None
pairs.append((name, value))
elif text.lstrip().startswith(","):
# concatenated headers, as per RFC 2616 section 4.2
text = text.lstrip()[1:]
if pairs: result.append(pairs)
pairs = []
else:
# skip junk
non_junk, nr_junk_chars = re.subn(r"^[=\s;]*", "", text)
assert nr_junk_chars > 0, (
"split_header_words bug: '%s', '%s', %s" %
(orig_text, text, pairs))
text = non_junk
if pairs: result.append(pairs)
return result
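# A short sketch of split_header_words with more than one header value,
# assuming Python 3's http.cookiejar. The header strings are made up for the
# example.

```python
from http.cookiejar import split_header_words

# Each input string is parsed separately; a top-level comma starts a new list.
parsed = split_header_words([
    'text/html; charset="iso-8859-1"',
    'foo=bar, baz',
])

for pairs in parsed:
    print(pairs)
```

# A lone token such as "baz" is reported with the value None.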
HEADER_JOIN_ESCAPE_RE = re.compile(r"([\"\\])")
def join_header_words(lists):
"""Do the inverse (almost) of the conversion done by split_header_words.
Takes a list of lists of (key, value) pairs and produces a single header
value. Attribute values are quoted if needed.
>>> join_header_words([[("text/plain", None), ("charset", "iso-8859/1")]])
'text/plain; charset="iso-8859/1"'
>>> join_header_words([[("text/plain", None)], [("charset", "iso-8859/1")]])
'text/plain, charset="iso-8859/1"'
"""
headers = []
for pairs in lists:
attr = []
for k, v in pairs:
if v is not None:
if not re.search(r"^\w+$", v):
v = HEADER_JOIN_ESCAPE_RE.sub(r"\\\1", v) # escape " and \
v = '"%s"' % v
k = "%s=%s" % (k, v)
attr.append(k)
if attr: headers.append("; ".join(attr))
return ", ".join(headers)
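# A round-trip sketch: join_header_words followed by split_header_words,
# assuming Python 3's http.cookiejar. The sample pairs are hypothetical.

```python
from http.cookiejar import join_header_words, split_header_words

words = [[("text/plain", None), ("charset", "iso-8859/1")], [("q", "0.5")]]

# Values that are not purely word characters get quoted on the way out.
header = join_header_words(words)
print(header)

# Splitting the joined header recovers the original structure.
round_trip = split_header_words([header])
print(round_trip == words)
```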
def _strip_quotes(text):
    if text.startswith('"'):
        text = text[1:]
    if text.endswith('"'):
        text = text[:-1]
    return text
def parse_ns_headers(ns_headers):
"""Ad-hoc parser for Netscape protocol cookie-attributes.
The old Netscape cookie format for Set-Cookie can for instance contain
an unquoted "," in the expires field, so we have to use this ad-hoc
parser instead of split_header_words.
XXX This may not make the best possible effort to parse all the crap
that Netscape Cookie headers contain. Ronald Tschalar's HTTPClient
parser is probably better, so could do worse than following that if
this ever gives any trouble.
Currently, this is also used for parsing RFC 2109 cookies.
"""
known_attrs = ("expires", "domain", "path", "secure",
# RFC 2109 attrs (may turn up in Netscape cookies, too)
"version", "port", "max-age")
result = []
for ns_header in ns_headers:
pairs = []
version_set = False
# XXX: The following does not strictly adhere to RFCs in that empty
# names and values are legal (the former will only appear once and will
# be overwritten if multiple occurrences are present). This is
# mostly to deal with backwards compatibility.
for ii, param in enumerate(ns_header.split(';')):
param = param.strip()
key, sep, val = param.partition('=')
key = key.strip()
if not key:
if ii == 0:
break
else:
continue
# allow for a distinction between present and empty and missing
# altogether
val = val.strip() if sep else None
if ii != 0:
lc = key.lower()
if lc in known_attrs:
key = lc
if key == "version":
# This is an RFC 2109 cookie.
if val is not None:
val = _strip_quotes(val)
version_set = True
elif key == "expires":
# convert expires date to seconds since epoch
if val is not None:
val = http2time(_strip_quotes(val)) # None if invalid
pairs.append((key, val))
if pairs:
if not version_set:
pairs.append(("version", "0"))
result.append(pairs)
return result
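# A sketch of what parse_ns_headers produces for a simple Netscape-style
# Set-Cookie value, assuming Python 3's http.cookiejar. The cookie name and
# value are invented for the example.

```python
from http.cookiejar import parse_ns_headers

parsed = parse_ns_headers(["sid=abc; Path=/; Secure"])

# Known attribute names are lowercased; the first pair (the cookie itself)
# keeps its case. A ("version", "0") pair is appended when none was given.
for key, value in parsed[0]:
    print(key, value)
```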
IPV4_RE = re.compile(r"\.\d+$")
def is_HDN(text):
    """Return True if text is a host domain name."""
    # XXX
    # This may well be wrong. Which RFC is HDN defined in, if any (for
    # the purposes of RFC 2965)?
    # For the current implementation, what about IPv6? Remember to look
    # at other uses of IPV4_RE also, if you change this.
    if IPV4_RE.search(text):
        return False
    if text == "":
        return False
    if text[0] == "." or text[-1] == ".":
        return False
    return True
def domain_match(A, B):
"""Return True if domain A domain-matches domain B, according to RFC 2965.
A and B may be host domain names or IP addresses.
RFC 2965, section 1:
Host names can be specified either as an IP address or a HDN string.
Sometimes we compare one host name with another. (Such comparisons SHALL
be case-insensitive.) Host A's name domain-matches host B's if
* their host name strings string-compare equal; or
* A is a HDN string and has the form NB, where N is a non-empty
name string, B has the form .B', and B' is a HDN string. (So,
x.y.com domain-matches .Y.com but not Y.com.)
Note that domain-match is not a commutative operation: a.b.c.com
domain-matches .c.com, but not the reverse.
"""
# Note that, if A or B are IP addresses, the only relevant part of the
# definition of the domain-match algorithm is the direct string-compare.
A = A.lower()
B = B.lower()
if A == B:
return True
if not is_HDN(A):
return False
i = A.rfind(B)
if i == -1 or i == 0:
# A does not have form NB, or N is the empty string
return False
if not B.startswith("."):
return False
if not is_HDN(B[1:]):
return False
return True
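# The docstring's examples of RFC 2965 domain-matching can be checked directly.
# This sketch assumes Python 3's http.cookiejar; the host names are the ones
# used in the docstring plus an illustrative IP address.

```python
from http.cookiejar import domain_match

checks = {
    ("x.y.com", ".Y.com"): domain_match("x.y.com", ".Y.com"),
    ("x.y.com", "Y.com"): domain_match("x.y.com", "Y.com"),
    ("a.b.c.com", ".c.com"): domain_match("a.b.c.com", ".c.com"),
    (".c.com", "a.b.c.com"): domain_match(".c.com", "a.b.c.com"),
    ("192.168.1.1", ".168.1.1"): domain_match("192.168.1.1", ".168.1.1"),
}

for (a, b), result in checks.items():
    print(a, b, result)
```

# IP addresses only ever match by direct string comparison, and the operation
# is not symmetric.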
def liberal_is_HDN(text):
    """Return True if text is sort-of-like a host domain name.
    For accepting/blocking domains.
    """
    if IPV4_RE.search(text):
        return False
    return True
def user_domain_match(A, B):
"""For blocking/accepting domains.
A and B may be host domain names or IP addresses.
"""
A = A.lower()
B = B.lower()
if not (liberal_is_HDN(A) and liberal_is_HDN(B)):
if A == B:
# equal IP addresses
return True
return False
initial_dot = B.startswith(".")
if initial_dot and A.endswith(B):
return True
if not initial_dot and A == B:
return True
return False
cut_port_re = re.compile(r":\d+$")
def request_host(request):
"""Return request-host, as defined by RFC 2965.
Variation from RFC: returned value is lowercased, for convenient
comparison.
"""
url = request.get_full_url()
host = urlparse.urlparse(url)[1]
if host == "":
host = request.get_header("Host", "")
# remove port, if present
host = cut_port_re.sub("", host, 1)
return host.lower()
def eff_request_host(request):
"""Return a tuple (request-host, effective request-host name).
As defined by RFC 2965, except both are lowercased.
"""
erhn = req_host = request_host(request)
if req_host.find(".") == -1 and not IPV4_RE.search(req_host):
erhn = req_host + ".local"
return req_host, erhn
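# A sketch of request-host versus effective request-host, assuming Python 3,
# where urllib.request.Request plays the role of urllib2.Request and the
# helper lives in http.cookiejar. The URLs are hypothetical and no network
# access takes place.

```python
from http.cookiejar import eff_request_host
from urllib.request import Request

# A dotless host name gets ".local" appended to form the effective name.
intranet = eff_request_host(Request("http://intranet/page"))

# The port is stripped and the name lowercased; a dotted name is unchanged.
public = eff_request_host(Request("http://WWW.Example.com:8080/"))

# An IP address is left alone.
ip_host = eff_request_host(Request("http://127.0.0.1/"))

print(intranet, public, ip_host)
```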
def request_path(request):
"""Path component of request-URI, as defined by RFC 2965."""
url = request.get_full_url()
parts = urlparse.urlsplit(url)
path = escape_path(parts.path)
if not path.startswith("/"):
# fix bad RFC 2396 absoluteURI
path = "/" + path
return path
def request_port(request):
host = request.get_host()
i = host.find(':')
if i >= 0:
port = host[i+1:]
try:
int(port)
except ValueError:
_debug("nonnumeric port: '%s'", port)
return None
else:
port = DEFAULT_HTTP_PORT
return port
# Characters in addition to A-Z, a-z, 0-9, '_', '.', and '-' that don't
# need to be escaped to form a valid HTTP URL (RFCs 2396 and 1738).
HTTP_PATH_SAFE = "%/;:@&=+$,!~*'()"
ESCAPED_CHAR_RE = re.compile(r"%([0-9a-fA-F][0-9a-fA-F])")
def uppercase_escaped_char(match):
return "%%%s" % match.group(1).upper()
def escape_path(path):
"""Escape any invalid characters in HTTP URL, and uppercase all escapes."""
# There's no knowing what character encoding was used to create URLs
# containing %-escapes, but since we have to pick one to escape invalid
# path characters, we pick UTF-8, as recommended in the HTML 4.0
# specification:
# http://www.w3.org/TR/REC-html40/appendix/notes.html#h-B.2.1
# And here, kind of: draft-fielding-uri-rfc2396bis-03
# (And in draft IRI specification: draft-duerst-iri-05)
# (And here, for new URI schemes: RFC 2718)
if isinstance(path, unicode):
path = path.encode("utf-8")
path = urllib.quote(path, HTTP_PATH_SAFE)
path = ESCAPED_CHAR_RE.sub(uppercase_escaped_char, path)
return path
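# A sketch of escape_path's two effects: quoting unsafe characters as UTF-8
# and uppercasing existing %-escapes. It assumes Python 3's http.cookiejar,
# where str paths are handled directly; the sample paths are invented.

```python
from http.cookiejar import escape_path

spaced = escape_path("/foo bar/%7euser")   # space quoted, %7e uppercased
accented = escape_path("/caf\u00e9")       # non-ASCII quoted as UTF-8 bytes
untouched = escape_path("/a/b;c=d")        # already-safe characters kept

print(spaced)
print(accented)
print(untouched)
```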
def reach(h):
"""Return reach of host h, as defined by RFC 2965, section 1.
The reach R of a host name H is defined as follows:
* If
- H is the host domain name of a host; and,
- H has the form A.B; and
- A has no embedded (that is, interior) dots; and
- B has at least one embedded dot, or B is the string "local".
then the reach of H is .B.
* Otherwise, the reach of H is H.
>>> reach("www.acme.com")
'.acme.com'
>>> reach("acme.com")
'acme.com'
>>> reach("acme.local")
'.local'
"""
i = h.find(".")
if i >= 0:
#a = h[:i] # this line is only here to show what a is
b = h[i+1:]
i = b.find(".")
if is_HDN(h) and (i >= 0 or b == "local"):
return "."+b
return h
def is_third_party(request):
"""
RFC 2965, section 3.3.6:
An unverifiable transaction is to a third-party host if its request-
host U does not domain-match the reach R of the request-host O in the
origin transaction.
"""
req_host = request_host(request)
if not domain_match(req_host, reach(request.get_origin_req_host())):
return True
else:
return False
class Cookie:
"""HTTP Cookie.
This class represents both Netscape and RFC 2965 cookies.
This is deliberately a very simple class. It just holds attributes. It's
possible to construct Cookie instances that don't comply with the cookie
standards. CookieJar.make_cookies is the factory function for Cookie
objects -- it deals with cookie parsing, supplying defaults, and
normalising to the representation used in this class. CookiePolicy is
responsible for checking them to see whether they should be accepted from
and returned to the server.
Note that the port may be present in the headers, but unspecified ("Port"
rather than "Port=80", for example); if this is the case, port is None.
"""
def __init__(self, version, name, value,
port, port_specified,
domain, domain_specified, domain_initial_dot,
path, path_specified,
secure,
expires,
discard,
comment,
comment_url,
rest,
rfc2109=False,
):
if version is not None: version = int(version)
if expires is not None: expires = int(expires)
if port is None and port_specified is True:
raise ValueError("if port is None, port_specified must be false")
self.version = version
self.name = name
self.value = value
self.port = port
self.port_specified = port_specified
# normalise case, as per RFC 2965 section 3.3.3
self.domain = domain.lower()
self.domain_specified = domain_specified
# Sigh. We need to know whether the domain given in the
# cookie-attribute had an initial dot, in order to follow RFC 2965
# (as clarified in draft errata). Needed for the returned $Domain
# value.
self.domain_initial_dot = domain_initial_dot
self.path = path
self.path_specified = path_specified
self.secure = secure
self.expires = expires
self.discard = discard
self.comment = comment
self.comment_url = comment_url
self.rfc2109 = rfc2109
self._rest = copy.copy(rest)
def has_nonstandard_attr(self, name):
return name in self._rest
def get_nonstandard_attr(self, name, default=None):
return self._rest.get(name, default)
def set_nonstandard_attr(self, name, value):
self._rest[name] = value
def is_expired(self, now=None):
if now is None: now = time.time()
if (self.expires is not None) and (self.expires <= now):
return True
return False
def __str__(self):
if self.port is None: p = ""
else: p = ":"+self.port
limit = self.domain + p + self.path
if self.value is not None:
namevalue = "%s=%s" % (self.name, self.value)
else:
namevalue = self.name
return "<Cookie %s for %s>" % (namevalue, limit)
def __repr__(self):
args = []
for name in ("version", "name", "value",
"port", "port_specified",
"domain", "domain_specified", "domain_initial_dot",
"path", "path_specified",
"secure", "expires", "discard", "comment", "comment_url",
):
attr = getattr(self, name)
args.append("%s=%s" % (name, repr(attr)))
args.append("rest=%s" % repr(self._rest))
args.append("rfc2109=%s" % repr(self.rfc2109))
return "Cookie(%s)" % ", ".join(args)
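# A sketch of building a Cookie by hand and querying it, assuming Python 3's
# http.cookiejar.Cookie, which takes the same constructor arguments. All
# attribute values here are hypothetical.

```python
from http.cookiejar import Cookie

cookie = Cookie(
    version=0, name="sid", value="abc",
    port=None, port_specified=False,
    domain="Example.com", domain_specified=False, domain_initial_dot=False,
    path="/", path_specified=True,
    secure=False,
    expires=100,            # seconds since the epoch
    discard=False,
    comment=None,
    comment_url=None,
    rest={"HttpOnly": None},
)

print(cookie.domain)                        # domain is lowercased on storage
print(str(cookie))
print(cookie.is_expired(now=50))            # before the expiry time
print(cookie.is_expired(now=100))           # expiry is inclusive (<=)
print(cookie.has_nonstandard_attr("HttpOnly"))
```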
class CookiePolicy:
"""Defines which cookies get accepted from and returned to server.
May also modify cookies, though this is probably a bad idea.
The subclass DefaultCookiePolicy defines the standard rules for Netscape
and RFC 2965 cookies -- override that if you want a customised policy.
"""
def set_ok(self, cookie, request):
"""Return true if (and only if) cookie should be accepted from server.
Currently, pre-expired cookies never get this far -- the CookieJar
class deletes such cookies itself.
"""
raise NotImplementedError()
def return_ok(self, cookie, request):
"""Return true if (and only if) cookie should be returned to server."""
raise NotImplementedError()
def domain_return_ok(self, domain, request):
"""Return false if cookies should not be returned, given cookie domain.
"""
return True
def path_return_ok(self, path, request):
"""Return false if cookies should not be returned, given cookie path.
"""
return True
class DefaultCookiePolicy(CookiePolicy):
"""Implements the standard rules for accepting and returning cookies."""
DomainStrictNoDots = 1
DomainStrictNonDomain = 2
DomainRFC2965Match = 4
DomainLiberal = 0
DomainStrict = DomainStrictNoDots|DomainStrictNonDomain
def __init__(self,
blocked_domains=None, allowed_domains=None,
netscape=True, rfc2965=False,
rfc2109_as_netscape=None,
hide_cookie2=False,
strict_domain=False,
strict_rfc2965_unverifiable=True,
strict_ns_unverifiable=False,
strict_ns_domain=DomainLiberal,
strict_ns_set_initial_dollar=False,
strict_ns_set_path=False,
):
"""Constructor arguments should be passed as keyword arguments only."""
self.netscape = netscape
self.rfc2965 = rfc2965
self.rfc2109_as_netscape = rfc2109_as_netscape
self.hide_cookie2 = hide_cookie2
self.strict_domain = strict_domain
self.strict_rfc2965_unverifiable = strict_rfc2965_unverifiable
self.strict_ns_unverifiable = strict_ns_unverifiable
self.strict_ns_domain = strict_ns_domain
self.strict_ns_set_initial_dollar = strict_ns_set_initial_dollar
self.strict_ns_set_path = strict_ns_set_path
if blocked_domains is not None:
self._blocked_domains = tuple(blocked_domains)
else:
self._blocked_domains = ()
if allowed_domains is not None:
allowed_domains = tuple(allowed_domains)
self._allowed_domains = allowed_domains
def blocked_domains(self):
"""Return the sequence of blocked domains (as a tuple)."""
return self._blocked_domains
def set_blocked_domains(self, blocked_domains):
"""Set the sequence of blocked domains."""
self._blocked_domains = tuple(blocked_domains)
def is_blocked(self, domain):
for blocked_domain in self._blocked_domains:
if user_domain_match(domain, blocked_domain):
return True
return False
def allowed_domains(self):
"""Return None, or the sequence of allowed domains (as a tuple)."""
return self._allowed_domains
def set_allowed_domains(self, allowed_domains):
"""Set the sequence of allowed domains, or None."""
if allowed_domains is not None:
allowed_domains = tuple(allowed_domains)
self._allowed_domains = allowed_domains
def is_not_allowed(self, domain):
if self._allowed_domains is None:
return False
for allowed_domain in self._allowed_domains:
if user_domain_match(domain, allowed_domain):
return False
return True
def set_ok(self, cookie, request):
"""
If you override .set_ok(), be sure to call this method. If it returns
false, so should your subclass (assuming your subclass wants to be more
strict about which cookies to accept).
"""
_debug(" - checking cookie %s=%s", cookie.name, cookie.value)
assert cookie.name is not None
for n in "version", "verifiability", "name", "path", "domain", "port":
fn_name = "set_ok_"+n
fn = getattr(self, fn_name)
if not fn(cookie, request):
return False
return True
def set_ok_version(self, cookie, request):
if cookie.version is None:
# Version is always set to 0 by parse_ns_headers if it's a Netscape
# cookie, so this must be an invalid RFC 2965 cookie.
_debug(" Set-Cookie2 without version attribute (%s=%s)",
cookie.name, cookie.value)
return False
if cookie.version > 0 and not self.rfc2965:
_debug(" RFC 2965 cookies are switched off")
return False
elif cookie.version == 0 and not self.netscape:
_debug(" Netscape cookies are switched off")
return False
return True
def set_ok_verifiability(self, cookie, request):
if request.is_unverifiable() and is_third_party(request):
if cookie.version > 0 and self.strict_rfc2965_unverifiable:
_debug(" third-party RFC 2965 cookie during "
"unverifiable transaction")
return False
elif cookie.version == 0 and self.strict_ns_unverifiable:
_debug(" third-party Netscape cookie during "
"unverifiable transaction")
return False
return True
def set_ok_name(self, cookie, request):
# Try and stop servers setting V0 cookies designed to hack other
# servers that know both V0 and V1 protocols.
if (cookie.version == 0 and self.strict_ns_set_initial_dollar and
cookie.name.startswith("$")):
_debug(" illegal name (starts with '$'): '%s'", cookie.name)
return False
return True
def set_ok_path(self, cookie, request):
if cookie.path_specified:
req_path = request_path(request)
if ((cookie.version > 0 or
(cookie.version == 0 and self.strict_ns_set_path)) and
not self.path_return_ok(cookie.path, request)):
_debug(" path attribute %s is not a prefix of request "
"path %s", cookie.path, req_path)
return False
return True
def set_ok_domain(self, cookie, request):
if self.is_blocked(cookie.domain):
_debug(" domain %s is in user block-list", cookie.domain)
return False
if self.is_not_allowed(cookie.domain):
_debug(" domain %s is not in user allow-list", cookie.domain)
return False
if cookie.domain_specified:
req_host, erhn = eff_request_host(request)
domain = cookie.domain
if self.strict_domain and (domain.count(".") >= 2):
# XXX This should probably be compared with the Konqueror
# (kcookiejar.cpp) and Mozilla implementations, but it's a
# losing battle.
i = domain.rfind(".")
j = domain.rfind(".", 0, i)
if j == 0: # domain like .foo.bar
tld = domain[i+1:]
sld = domain[j+1:i]
if sld.lower() in ("co", "ac", "com", "edu", "org", "net",
"gov", "mil", "int", "aero", "biz", "cat", "coop",
"info", "jobs", "mobi", "museum", "name", "pro",
"travel", "eu") and len(tld) == 2:
# domain like .co.uk
_debug(" country-code second level domain %s", domain)
return False
if domain.startswith("."):
undotted_domain = domain[1:]
else:
undotted_domain = domain
embedded_dots = (undotted_domain.find(".") >= 0)
if not embedded_dots and domain != ".local":
_debug(" non-local domain %s contains no embedded dot",
domain)
return False
if cookie.version == 0:
if (not erhn.endswith(domain) and
(not erhn.startswith(".") and
not ("."+erhn).endswith(domain))):
_debug(" effective request-host %s (even with added "
"initial dot) does not end with %s",
erhn, domain)
return False
if (cookie.version > 0 or
(self.strict_ns_domain & self.DomainRFC2965Match)):
if not domain_match(erhn, domain):
_debug(" effective request-host %s does not domain-match "
"%s", erhn, domain)
return False
if (cookie.version > 0 or
(self.strict_ns_domain & self.DomainStrictNoDots)):
host_prefix = req_host[:-len(domain)]
if (host_prefix.find(".") >= 0 and
not IPV4_RE.search(req_host)):
_debug(" host prefix %s for domain %s contains a dot",
host_prefix, domain)
return False
return True
def set_ok_port(self, cookie, request):
if cookie.port_specified:
req_port = request_port(request)
if req_port is None:
req_port = "80"
else:
req_port = str(req_port)
for p in cookie.port.split(","):
try:
int(p)
except ValueError:
_debug(" bad port %s (not numeric)", p)
return False
if p == req_port:
break
else:
_debug(" request port (%s) not found in %s",
req_port, cookie.port)
return False
return True
def return_ok(self, cookie, request):
"""
If you override .return_ok(), be sure to call this method. If it
returns false, so should your subclass (assuming your subclass wants to
be more strict about which cookies to return).
"""
# Path has already been checked by .path_return_ok(), and domain
# blocking done by .domain_return_ok().
_debug(" - checking cookie %s=%s", cookie.name, cookie.value)
for n in "version", "verifiability", "secure", "expires", "port", "domain":
fn_name = "return_ok_"+n
fn = getattr(self, fn_name)
if not fn(cookie, request):
return False
return True
def return_ok_version(self, cookie, request):
if cookie.version > 0 and not self.rfc2965:
_debug(" RFC 2965 cookies are switched off")
return False
elif cookie.version == 0 and not self.netscape:
_debug(" Netscape cookies are switched off")
return False
return True
def return_ok_verifiability(self, cookie, request):
if request.is_unverifiable() and is_third_party(request):
if cookie.version > 0 and self.strict_rfc2965_unverifiable:
_debug(" third-party RFC 2965 cookie during unverifiable "
"transaction")
return False
elif cookie.version == 0 and self.strict_ns_unverifiable:
_debug(" third-party Netscape cookie during unverifiable "
"transaction")
return False
return True
def return_ok_secure(self, cookie, request):
if cookie.secure and request.get_type() != "https":
_debug(" secure cookie with non-secure request")
return False
return True
def return_ok_expires(self, cookie, request):
if cookie.is_expired(self._now):
_debug(" cookie expired")
return False
return True
def return_ok_port(self, cookie, request):
if cookie.port:
req_port = request_port(request)
if req_port is None:
req_port = "80"
for p in cookie.port.split(","):
if p == req_port:
break
else:
_debug(" request port %s does not match cookie port %s",
req_port, cookie.port)
return False
return True
def return_ok_domain(self, cookie, request):
req_host, erhn = eff_request_host(request)
domain = cookie.domain
if domain and not domain.startswith("."):
dotdomain = "." + domain
else:
dotdomain = domain
# strict check of non-domain cookies: Mozilla does this, MSIE5 doesn't
if (cookie.version == 0 and
(self.strict_ns_domain & self.DomainStrictNonDomain) and
not cookie.domain_specified and domain != erhn):
_debug(" cookie with unspecified domain does not string-compare "
"equal to request domain")
return False
if cookie.version > 0 and not domain_match(erhn, domain):
_debug(" effective request-host name %s does not domain-match "
"RFC 2965 cookie domain %s", erhn, domain)
return False
if cookie.version == 0 and not ("."+erhn).endswith(dotdomain):
_debug(" request-host %s does not match Netscape cookie domain "
"%s", req_host, domain)
return False
return True
def domain_return_ok(self, domain, request):
# Liberal check of domain. This is here as an optimization to avoid
# having to load lots of MSIE cookie files unless necessary.
req_host, erhn = eff_request_host(request)
if not req_host.startswith("."):
req_host = "."+req_host
if not erhn.startswith("."):
erhn = "."+erhn
if domain and not domain.startswith("."):
dotdomain = "." + domain
else:
dotdomain = domain
if not (req_host.endswith(dotdomain) or erhn.endswith(dotdomain)):
#_debug(" request domain %s does not match cookie domain %s",
# req_host, domain)
return False
if self.is_blocked(domain):
_debug(" domain %s is in user block-list", domain)
return False
if self.is_not_allowed(domain):
_debug(" domain %s is not in user allow-list", domain)
return False
return True
def path_return_ok(self, path, request):
_debug("- checking cookie path=%s", path)
req_path = request_path(request)
pathlen = len(path)
if req_path == path:
return True
elif (req_path.startswith(path) and
(path.endswith("/") or req_path[pathlen:pathlen+1] == "/")):
return True
_debug(" %s does not path-match %s", req_path, path)
return False
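# A sketch of the block-list and allow-list helpers of DefaultCookiePolicy,
# assuming Python 3's http.cookiejar. The domain names are hypothetical. An
# entry with a leading dot matches any host ending in it; an entry without
# one must match exactly.

```python
from http.cookiejar import DefaultCookiePolicy

policy = DefaultCookiePolicy(
    blocked_domains=[".ads.example"],
    allowed_domains=["example.com", ".example.org"],
)

blocked_sub = policy.is_blocked("tracker.ads.example")   # ends with the entry
blocked_bare = policy.is_blocked("ads.example")          # lacks the leading dot
exact_allowed = policy.is_not_allowed("example.com")     # exact entry
sub_not_allowed = policy.is_not_allowed("www.example.com")
dotted_allowed = policy.is_not_allowed("www.example.org")

print(blocked_sub, blocked_bare)
print(exact_allowed, sub_not_allowed, dotted_allowed)
```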
def vals_sorted_by_key(adict):
    keys = adict.keys()
    keys.sort()
    return map(adict.get, keys)

def deepvalues(mapping):
    """Iterates over nested mapping, depth-first, in sorted order by key."""
    values = vals_sorted_by_key(mapping)
    for obj in values:
        mapping = False
        try:
            obj.items
        except AttributeError:
            pass
        else:
            mapping = True
            for subobj in deepvalues(obj):
                yield subobj
        if not mapping:
            yield obj
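# The traversal above relies on Python 2 behaviour (keys() returning a list).
# Below is a sketch of the same depth-first, key-sorted walk written for
# Python 3; deepvalues_py3 is a hypothetical name, and the nested dict only
# mimics the domain -> path -> name layout used by CookieJar.

```python
def deepvalues_py3(mapping):
    """Yield leaf values of a nested mapping, depth-first, sorted by key."""
    for key in sorted(mapping):
        obj = mapping[key]
        if hasattr(obj, "items"):
            # Nested mapping: recurse instead of yielding the mapping itself.
            yield from deepvalues_py3(obj)
        else:
            yield obj


nested = {
    "b.example": {"/": {"y": "cookie-3", "x": "cookie-2"}},
    "a.example": {"/": {"z": "cookie-1"}},
}
flat = list(deepvalues_py3(nested))
print(flat)  # ['cookie-1', 'cookie-2', 'cookie-3']
```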
# Used as second parameter to dict.get() method, to distinguish absent
# dict key from one with a None value.
class Absent: pass
class CookieJar:
"""Collection of HTTP cookies.
You may not need to know about this class: try
urllib2.build_opener(HTTPCookieProcessor).open(url).
"""
non_word_re = re.compile(r"\W")
quote_re = re.compile(r"([\"\\])")
strict_domain_re = re.compile(r"\.?[^.]*")
domain_re = re.compile(r"[^.]*")
dots_re = re.compile(r"^\.+")
magic_re = r"^\#LWP-Cookies-(\d+\.\d+)"
def __init__(self, policy=None):
if policy is None:
policy = DefaultCookiePolicy()
self._policy = policy
self._cookies_lock = _threading.RLock()
self._cookies = {}
def set_policy(self, policy):
self._policy = policy
def _cookies_for_domain(self, domain, request):
cookies = []
if not self._policy.domain_return_ok(domain, request):
return []
_debug("Checking %s for cookies to return", domain)
cookies_by_path = self._cookies[domain]
for path in cookies_by_path.keys():
if not self._policy.path_return_ok(path, request):
continue
cookies_by_name = cookies_by_path[path]
for cookie in cookies_by_name.values():
if not self._policy.return_ok(cookie, request):
_debug(" not returning cookie")
continue
_debug(" it's a match")
cookies.append(cookie)
return cookies
def _cookies_for_request(self, request):
"""Return a list of cookies to be returned to server."""
cookies = []
for domain in self._cookies.keys():
cookies.extend(self._cookies_for_domain(domain, request))
return cookies
def _cookie_attrs(self, cookies):
"""Return a list of cookie-attributes to be returned to server.
like ['foo="bar"; $Path="/"', ...]
The $Version attribute is also added when appropriate (currently only
once per request).
"""
# add cookies in order of most specific (i.e. longest) path first
cookies.sort(key=lambda arg: len(arg.path), reverse=True)
version_set = False
attrs = []
for cookie in cookies:
# set version of Cookie header
# XXX
# What should it be if multiple matching Set-Cookie headers have
# different versions themselves?
# Answer: there is no answer; was supposed to be settled by
# RFC 2965 errata, but that may never appear...
version = cookie.version
if not version_set:
version_set = True
if version > 0:
attrs.append("$Version=%s" % version)
# quote cookie value if necessary
# (not for Netscape protocol, which already has any quotes
# intact, due to the poorly-specified Netscape Cookie: syntax)
if ((cookie.value is not None) and
self.non_word_re.search(cookie.value) and version > 0):
value = self.quote_re.sub(r"\\\1", cookie.value)
else:
value = cookie.value
# add cookie-attributes to be returned in Cookie header
if cookie.value is None:
attrs.append(cookie.name)
else:
attrs.append("%s=%s" % (cookie.name, value))
if version > 0:
if cookie.path_specified:
attrs.append('$Path="%s"' % cookie.path)
if cookie.domain.startswith("."):
domain = cookie.domain
if (not cookie.domain_initial_dot and
domain.startswith(".")):
domain = domain[1:]
attrs.append('$Domain="%s"' % domain)
if cookie.port is not None:
p = "$Port"
if cookie.port_specified:
p = p + ('="%s"' % cookie.port)
attrs.append(p)
return attrs
    def add_cookie_header(self, request):
        """Add correct Cookie: header to request (urllib2.Request object).

        The Cookie2 header is also added unless policy.hide_cookie2 is true.

        """
        _debug("add_cookie_header")
        self._cookies_lock.acquire()
        try:
            self._policy._now = self._now = int(time.time())

            cookies = self._cookies_for_request(request)

            attrs = self._cookie_attrs(cookies)
            if attrs:
                if not request.has_header("Cookie"):
                    request.add_unredirected_header(
                        "Cookie", "; ".join(attrs))

            # if necessary, advertise that we know RFC 2965
            if (self._policy.rfc2965 and not self._policy.hide_cookie2 and
                not request.has_header("Cookie2")):
                for cookie in cookies:
                    if cookie.version != 1:
                        request.add_unredirected_header("Cookie2", '$Version="1"')
                        break
        finally:
            self._cookies_lock.release()

        self.clear_expired_cookies()
    def _normalized_cookie_tuples(self, attrs_set):
        """Return list of tuples containing normalised cookie information.

        attrs_set is the list of lists of key,value pairs extracted from
        the Set-Cookie or Set-Cookie2 headers.

        Tuples are name, value, standard, rest, where name and value are the
        cookie name and value, standard is a dictionary containing the standard
        cookie-attributes (discard, secure, version, expires or max-age,
        domain, path and port) and rest is a dictionary containing the rest of
        the cookie-attributes.

        """
        cookie_tuples = []

        boolean_attrs = "discard", "secure"
        value_attrs = ("version",
                       "expires", "max-age",
                       "domain", "path", "port",
                       "comment", "commenturl")

        for cookie_attrs in attrs_set:
            name, value = cookie_attrs[0]

            # Build dictionary of standard cookie-attributes (standard) and
            # dictionary of other cookie-attributes (rest).

            # Note: expiry time is normalised to seconds since epoch. V0
            # cookies should have the Expires cookie-attribute, and V1 cookies
            # should have Max-Age, but since V1 includes RFC 2109 cookies (and
            # since V0 cookies may be a mish-mash of Netscape and RFC 2109), we
            # accept either (but prefer Max-Age).
            max_age_set = False

            bad_cookie = False

            standard = {}
            rest = {}
            for k, v in cookie_attrs[1:]:
                lc = k.lower()
                # don't lose case distinction for unknown fields
                if lc in value_attrs or lc in boolean_attrs:
                    k = lc
                if k in boolean_attrs and v is None:
                    # boolean cookie-attribute is present, but has no value
                    # (like "discard", rather than "port=80")
                    v = True
                if k in standard:
                    # only first value is significant
                    continue
                if k == "domain":
                    if v is None:
                        _debug(" missing value for domain attribute")
                        bad_cookie = True
                        break
                    # RFC 2965 section 3.3.3
                    v = v.lower()
                if k == "expires":
                    if max_age_set:
                        # Prefer max-age to expires (like Mozilla)
                        continue
                    if v is None:
                        _debug(" missing or invalid value for expires "
                               "attribute: treating as session cookie")
                        continue
                if k == "max-age":
                    max_age_set = True
                    try:
                        v = int(v)
                    except ValueError:
                        _debug(" missing or invalid (non-numeric) value for "
                               "max-age attribute")
                        bad_cookie = True
                        break
                    # convert RFC 2965 Max-Age to seconds since epoch
                    # XXX Strictly you're supposed to follow RFC 2616
                    # age-calculation rules. Remember that zero Max-Age
                    # is a request to discard (old and new) cookie, though.
                    k = "expires"
                    v = self._now + v
                if (k in value_attrs) or (k in boolean_attrs):
                    if (v is None and
                        k not in ("port", "comment", "commenturl")):
                        _debug(" missing value for %s attribute" % k)
                        bad_cookie = True
                        break
                    standard[k] = v
                else:
                    rest[k] = v

            if bad_cookie:
                continue

            cookie_tuples.append((name, value, standard, rest))

        return cookie_tuples
    def _cookie_from_cookie_tuple(self, tup, request):
        # standard is dict of standard cookie-attributes, rest is dict of the
        # rest of them
        name, value, standard, rest = tup

        domain = standard.get("domain", Absent)
        path = standard.get("path", Absent)
        port = standard.get("port", Absent)
        expires = standard.get("expires", Absent)

        # set the easy defaults
        version = standard.get("version", None)
        if version is not None:
            try:
                version = int(version)
            except ValueError:
                return None # invalid version, ignore cookie
        secure = standard.get("secure", False)
        # (discard is also set if expires is Absent)
        discard = standard.get("discard", False)
        comment = standard.get("comment", None)
        comment_url = standard.get("commenturl", None)

        # set default path
        if path is not Absent and path != "":
            path_specified = True
            path = escape_path(path)
        else:
            path_specified = False
            path = request_path(request)
            i = path.rfind("/")
            if i != -1:
                if version == 0:
                    # Netscape spec parts company from reality here
                    path = path[:i]
                else:
                    path = path[:i+1]
            if len(path) == 0: path = "/"

        # set default domain
        domain_specified = domain is not Absent
        # but first we have to remember whether it starts with a dot
        domain_initial_dot = False
        if domain_specified:
            domain_initial_dot = bool(domain.startswith("."))
        if domain is Absent:
            req_host, erhn = eff_request_host(request)
            domain = erhn
        elif not domain.startswith("."):
            domain = "."+domain

        # set default port
        port_specified = False
        if port is not Absent:
            if port is None:
                # Port attr present, but has no value: default to request port.
                # Cookie should then only be sent back on that port.
                port = request_port(request)
            else:
                port_specified = True
                port = re.sub(r"\s+", "", port)
        else:
            # No port attr present. Cookie can be sent back on any port.
            port = None

        # set default expires and discard
        if expires is Absent:
            expires = None
            discard = True
        elif expires <= self._now:
            # Expiry date in past is request to delete cookie. This can't be
            # in DefaultCookiePolicy, because can't delete cookies there.
            try:
                self.clear(domain, path, name)
            except KeyError:
                pass
            _debug("Expiring cookie, domain='%s', path='%s', name='%s'",
                   domain, path, name)
            return None

        return Cookie(version,
                      name, value,
                      port, port_specified,
                      domain, domain_specified, domain_initial_dot,
                      path, path_specified,
                      secure,
                      expires,
                      discard,
                      comment,
                      comment_url,
                      rest)
    def _cookies_from_attrs_set(self, attrs_set, request):
        cookie_tuples = self._normalized_cookie_tuples(attrs_set)

        cookies = []
        for tup in cookie_tuples:
            cookie = self._cookie_from_cookie_tuple(tup, request)
            if cookie: cookies.append(cookie)
        return cookies

    def _process_rfc2109_cookies(self, cookies):
        rfc2109_as_ns = getattr(self._policy, 'rfc2109_as_netscape', None)
        if rfc2109_as_ns is None:
            rfc2109_as_ns = not self._policy.rfc2965
        for cookie in cookies:
            if cookie.version == 1:
                cookie.rfc2109 = True
                if rfc2109_as_ns:
                    # treat 2109 cookies as Netscape cookies rather than
                    # as RFC2965 cookies
                    cookie.version = 0
    def make_cookies(self, response, request):
        """Return sequence of Cookie objects extracted from response object."""
        # get cookie-attributes for RFC 2965 and Netscape protocols
        headers = response.info()
        rfc2965_hdrs = headers.getheaders("Set-Cookie2")
        ns_hdrs = headers.getheaders("Set-Cookie")

        rfc2965 = self._policy.rfc2965
        netscape = self._policy.netscape

        if ((not rfc2965_hdrs and not ns_hdrs) or
            (not ns_hdrs and not rfc2965) or
            (not rfc2965_hdrs and not netscape) or
            (not netscape and not rfc2965)):
            return [] # no relevant cookie headers: quick exit

        try:
            cookies = self._cookies_from_attrs_set(
                split_header_words(rfc2965_hdrs), request)
        except Exception:
            _warn_unhandled_exception()
            cookies = []

        if ns_hdrs and netscape:
            try:
                # RFC 2109 and Netscape cookies
                ns_cookies = self._cookies_from_attrs_set(
                    parse_ns_headers(ns_hdrs), request)
            except Exception:
                _warn_unhandled_exception()
                ns_cookies = []
            self._process_rfc2109_cookies(ns_cookies)

            # Look for Netscape cookies (from Set-Cookie headers) that match
            # corresponding RFC 2965 cookies (from Set-Cookie2 headers).
            # For each match, keep the RFC 2965 cookie and ignore the Netscape
            # cookie (RFC 2965 section 9.1). Actually, RFC 2109 cookies are
            # bundled in with the Netscape cookies for this purpose, which is
            # reasonable behaviour.
            if rfc2965:
                lookup = {}
                for cookie in cookies:
                    lookup[(cookie.domain, cookie.path, cookie.name)] = None

                def no_matching_rfc2965(ns_cookie, lookup=lookup):
                    key = ns_cookie.domain, ns_cookie.path, ns_cookie.name
                    return key not in lookup
                ns_cookies = filter(no_matching_rfc2965, ns_cookies)

            if ns_cookies:
                cookies.extend(ns_cookies)

        return cookies
    def set_cookie_if_ok(self, cookie, request):
        """Set a cookie if policy says it's OK to do so."""
        self._cookies_lock.acquire()
        try:
            self._policy._now = self._now = int(time.time())

            if self._policy.set_ok(cookie, request):
                self.set_cookie(cookie)
        finally:
            self._cookies_lock.release()

    def set_cookie(self, cookie):
        """Set a cookie, without checking whether or not it should be set."""
        c = self._cookies
        self._cookies_lock.acquire()
        try:
            if cookie.domain not in c: c[cookie.domain] = {}
            c2 = c[cookie.domain]
            if cookie.path not in c2: c2[cookie.path] = {}
            c3 = c2[cookie.path]
            c3[cookie.name] = cookie
        finally:
            self._cookies_lock.release()

    def extract_cookies(self, response, request):
        """Extract cookies from response, where allowable given the request."""
        _debug("extract_cookies: %s", response.info())
        self._cookies_lock.acquire()
        try:
            self._policy._now = self._now = int(time.time())

            for cookie in self.make_cookies(response, request):
                if self._policy.set_ok(cookie, request):
                    _debug(" setting cookie: %s", cookie)
                    self.set_cookie(cookie)
        finally:
            self._cookies_lock.release()
    def clear(self, domain=None, path=None, name=None):
        """Clear some cookies.

        Invoking this method without arguments will clear all cookies. If
        given a single argument, only cookies belonging to that domain will be
        removed. If given two arguments, cookies belonging to the specified
        path within that domain are removed. If given three arguments, then
        the cookie with the specified name, path and domain is removed.

        Raises KeyError if no matching cookie exists.

        """
        if name is not None:
            if (domain is None) or (path is None):
                raise ValueError(
                    "domain and path must be given to remove a cookie by name")
            del self._cookies[domain][path][name]
        elif path is not None:
            if domain is None:
                raise ValueError(
                    "domain must be given to remove cookies by path")
            del self._cookies[domain][path]
        elif domain is not None:
            del self._cookies[domain]
        else:
            self._cookies = {}

    def clear_session_cookies(self):
        """Discard all session cookies.

        Note that the .save() method won't save session cookies anyway, unless
        you ask otherwise by passing a true ignore_discard argument.

        """
        self._cookies_lock.acquire()
        try:
            for cookie in self:
                if cookie.discard:
                    self.clear(cookie.domain, cookie.path, cookie.name)
        finally:
            self._cookies_lock.release()

    def clear_expired_cookies(self):
        """Discard all expired cookies.

        You probably don't need to call this method: expired cookies are never
        sent back to the server (provided you're using DefaultCookiePolicy),
        this method is called by CookieJar itself every so often, and the
        .save() method won't save expired cookies anyway (unless you ask
        otherwise by passing a true ignore_expires argument).

        """
        self._cookies_lock.acquire()
        try:
            now = time.time()
            for cookie in self:
                if cookie.is_expired(now):
                    self.clear(cookie.domain, cookie.path, cookie.name)
        finally:
            self._cookies_lock.release()
    def __iter__(self):
        return deepvalues(self._cookies)

    def __len__(self):
        """Return number of contained cookies."""
        i = 0
        for cookie in self: i = i + 1
        return i

    def __repr__(self):
        r = []
        for cookie in self: r.append(repr(cookie))
        return "<%s[%s]>" % (self.__class__.__name__, ", ".join(r))

    def __str__(self):
        r = []
        for cookie in self: r.append(str(cookie))
        return "<%s[%s]>" % (self.__class__.__name__, ", ".join(r))


# derives from IOError for backwards-compatibility with Python 2.4.0
class LoadError(IOError): pass
class FileCookieJar(CookieJar):
    """CookieJar that can be loaded from and saved to a file."""

    def __init__(self, filename=None, delayload=False, policy=None):
        """
        Cookies are NOT loaded from the named file until either the .load() or
        .revert() method is called.

        """
        CookieJar.__init__(self, policy)
        if filename is not None:
            try:
                filename+""
            except:
                raise ValueError("filename must be string-like")
        self.filename = filename
        self.delayload = bool(delayload)

    def save(self, filename=None, ignore_discard=False, ignore_expires=False):
        """Save cookies to a file."""
        raise NotImplementedError()

    def load(self, filename=None, ignore_discard=False, ignore_expires=False):
        """Load cookies from a file."""
        if filename is None:
            if self.filename is not None: filename = self.filename
            else: raise ValueError(MISSING_FILENAME_TEXT)

        f = open(filename)
        try:
            self._really_load(f, filename, ignore_discard, ignore_expires)
        finally:
            f.close()

    def revert(self, filename=None,
               ignore_discard=False, ignore_expires=False):
        """Clear all cookies and reload cookies from a saved file.

        Raises LoadError (or IOError) if reversion is not successful; the
        object's state will not be altered if this happens.

        """
        if filename is None:
            if self.filename is not None: filename = self.filename
            else: raise ValueError(MISSING_FILENAME_TEXT)

        self._cookies_lock.acquire()
        try:
            old_state = copy.deepcopy(self._cookies)
            self._cookies = {}
            try:
                self.load(filename, ignore_discard, ignore_expires)
            except (LoadError, IOError):
                self._cookies = old_state
                raise
        finally:
            self._cookies_lock.release()

from _LWPCookieJar import LWPCookieJar, lwp_cookie_str
from _MozillaCookieJar import MozillaCookieJar
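The cookie jar above keeps its cookies in a three-level mapping (domain, then path, then name), and `clear()` removes one level depending on how many arguments it receives. The following is a minimal sketch of that idea only, not the library's class: `NestedStore`, its `set()` and `names()` helpers, and the example domains are hypothetical names introduced for illustration, and locking and policy checks are left out.

```python
class NestedStore(object):
    """Hypothetical toy model of the domain -> path -> name mapping."""

    def __init__(self):
        self._cookies = {}

    def set(self, domain, path, name, value):
        # Create the intermediate dictionaries on demand.
        self._cookies.setdefault(domain, {}).setdefault(path, {})[name] = value

    def clear(self, domain=None, path=None, name=None):
        # Mirrors the branching of the clear() method shown above.
        if name is not None:
            if domain is None or path is None:
                raise ValueError(
                    "domain and path must be given to remove a cookie by name")
            del self._cookies[domain][path][name]
        elif path is not None:
            if domain is None:
                raise ValueError(
                    "domain must be given to remove cookies by path")
            del self._cookies[domain][path]
        elif domain is not None:
            del self._cookies[domain]
        else:
            self._cookies = {}

    def names(self):
        # Flatten the three levels into a sorted list of cookie names.
        return sorted(name
                      for paths in self._cookies.values()
                      for names in paths.values()
                      for name in names)


store = NestedStore()
store.set(".example.com", "/", "sid", "abc")
store.set(".example.com", "/app", "theme", "dark")
store.set(".example.org", "/", "lang", "en")

store.clear(".example.com", "/", "sid")   # three arguments: one cookie
after_name = store.names()
store.clear(".example.com")               # one argument: the whole domain
after_domain = store.names()
store.clear()                             # no arguments: everything
after_all = store.names()
```

As in the original method, asking the sketch to remove something that is not stored raises KeyError, because the deletion is a plain `del` on the nested dictionaries.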
"""Generic (shallow and deep) copying operations.

Interface summary:

        import copy

        x = copy.copy(y)        # make a shallow copy of y
        x = copy.deepcopy(y)    # make a deep copy of y

For module specific errors, copy.Error is raised.

The difference between shallow and deep copying is only relevant for
compound objects (objects that contain other objects, like lists or
class instances).

- A shallow copy constructs a new compound object and then (to the
  extent possible) inserts *the same objects* into it that the
  original contains.

- A deep copy constructs a new compound object and then, recursively,
  inserts *copies* into it of the objects found in the original.

Two problems often exist with deep copy operations that don't exist
with shallow copy operations:

 a) recursive objects (compound objects that, directly or indirectly,
    contain a reference to themselves) may cause a recursive loop

 b) because deep copy copies *everything* it may copy too much, e.g.
    administrative data structures that should be shared even between
    copies

Python's deep copy operation avoids these problems by:

 a) keeping a table of objects already copied during the current
    copying pass

 b) letting user-defined classes override the copying operation or the
    set of components copied

This version does not copy types like module, class, function, method,
nor stack trace, stack frame, nor file, socket, window, nor array, nor
any similar types.

Classes can use the same interfaces to control copying that they use
to control pickling: they can define methods called __getinitargs__(),
__getstate__() and __setstate__(). See the documentation for module
"pickle" for information on these methods.
"""

import types
import weakref
from copy_reg import dispatch_table

class Error(Exception):
    pass
error = Error # backward compatibility

try:
    from org.python.core import PyStringMap
except ImportError:
    PyStringMap = None

__all__ = ["Error", "copy", "deepcopy"]

def copy(x):
    """Shallow copy operation on arbitrary Python objects.

    See the module's __doc__ string for more info.
    """

    cls = type(x)

    copier = _copy_dispatch.get(cls)
    if copier:
        return copier(x)

    copier = getattr(cls, "__copy__", None)
    if copier:
        return copier(x)

    reductor = dispatch_table.get(cls)
    if reductor:
        rv = reductor(x)
    else:
        reductor = getattr(x, "__reduce_ex__", None)
        if reductor:
            rv = reductor(2)
        else:
            reductor = getattr(x, "__reduce__", None)
            if reductor:
                rv = reductor()
            else:
                raise Error("un(shallow)copyable object of type %s" % cls)

    return _reconstruct(x, rv, 0)


_copy_dispatch = d = {}

def _copy_immutable(x):
    return x
for t in (type(None), int, long, float, bool, str, tuple,
          frozenset, type, xrange, types.ClassType,
          types.BuiltinFunctionType, type(Ellipsis),
          types.FunctionType, weakref.ref):
    d[t] = _copy_immutable
for name in ("ComplexType", "UnicodeType", "CodeType"):
    t = getattr(types, name, None)
    if t is not None:
        d[t] = _copy_immutable

def _copy_with_constructor(x):
    return type(x)(x)
for t in (list, dict, set):
    d[t] = _copy_with_constructor

def _copy_with_copy_method(x):
    return x.copy()
if PyStringMap is not None:
    d[PyStringMap] = _copy_with_copy_method

def _copy_inst(x):
    if hasattr(x, '__copy__'):
        return x.__copy__()
    if hasattr(x, '__getinitargs__'):
        args = x.__getinitargs__()
        y = x.__class__(*args)
    else:
        y = _EmptyClass()
        y.__class__ = x.__class__
    if hasattr(x, '__getstate__'):
        state = x.__getstate__()
    else:
        state = x.__dict__
    if hasattr(y, '__setstate__'):
        y.__setstate__(state)
    else:
        y.__dict__.update(state)
    return y
d[types.InstanceType] = _copy_inst

del d

def deepcopy(x, memo=None, _nil=[]):
    """Deep copy operation on arbitrary Python objects.

    See the module's __doc__ string for more info.
    """

    if memo is None:
        memo = {}

    d = id(x)
    y = memo.get(d, _nil)
    if y is not _nil:
        return y

    cls = type(x)

    copier = _deepcopy_dispatch.get(cls)
    if copier:
        y = copier(x, memo)
    else:
        try:
            issc = issubclass(cls, type)
        except TypeError: # cls is not a class (old Boost; see SF #502085)
            issc = 0
        if issc:
            y = _deepcopy_atomic(x, memo)
        else:
            copier = getattr(x, "__deepcopy__", None)
            if copier:
                y = copier(memo)
            else:
                reductor = dispatch_table.get(cls)
                if reductor:
                    rv = reductor(x)
                else:
                    reductor = getattr(x, "__reduce_ex__", None)
                    if reductor:
                        rv = reductor(2)
                    else:
                        reductor = getattr(x, "__reduce__", None)
                        if reductor:
                            rv = reductor()
                        else:
                            raise Error(
                                "un(deep)copyable object of type %s" % cls)
                y = _reconstruct(x, rv, 1, memo)

    memo[d] = y
    _keep_alive(x, memo) # Make sure x lives at least as long as d
    return y

_deepcopy_dispatch = d = {}

def _deepcopy_atomic(x, memo):
    return x
d[type(None)] = _deepcopy_atomic
d[type(Ellipsis)] = _deepcopy_atomic
d[int] = _deepcopy_atomic
d[long] = _deepcopy_atomic
d[float] = _deepcopy_atomic
d[bool] = _deepcopy_atomic
try:
    d[complex] = _deepcopy_atomic
except NameError:
    pass
d[str] = _deepcopy_atomic
try:
    d[unicode] = _deepcopy_atomic
except NameError:
    pass
try:
    d[types.CodeType] = _deepcopy_atomic
except AttributeError:
    pass
d[type] = _deepcopy_atomic
d[xrange] = _deepcopy_atomic
d[types.ClassType] = _deepcopy_atomic
d[types.BuiltinFunctionType] = _deepcopy_atomic
d[types.FunctionType] = _deepcopy_atomic
d[weakref.ref] = _deepcopy_atomic

def _deepcopy_list(x, memo):
    y = []
    memo[id(x)] = y
    for a in x:
        y.append(deepcopy(a, memo))
    return y
d[list] = _deepcopy_list

def _deepcopy_tuple(x, memo):
    y = []
    for a in x:
        y.append(deepcopy(a, memo))
    d = id(x)
    try:
        return memo[d]
    except KeyError:
        pass
    for i in range(len(x)):
        if x[i] is not y[i]:
            y = tuple(y)
            break
    else:
        y = x
    memo[d] = y
    return y
d[tuple] = _deepcopy_tuple

def _deepcopy_dict(x, memo):
    y = {}
    memo[id(x)] = y
    for key, value in x.iteritems():
        y[deepcopy(key, memo)] = deepcopy(value, memo)
    return y
d[dict] = _deepcopy_dict
if PyStringMap is not None:
    d[PyStringMap] = _deepcopy_dict

def _deepcopy_method(x, memo): # Copy instance methods
    return type(x)(x.im_func, deepcopy(x.im_self, memo), x.im_class)
_deepcopy_dispatch[types.MethodType] = _deepcopy_method

def _keep_alive(x, memo):
    """Keeps a reference to the object x in the memo.

    Because we remember objects by their id, we have
    to assure that possibly temporary objects are kept
    alive by referencing them.
    We store a reference at the id of the memo, which should
    normally not be used unless someone tries to deepcopy
    the memo itself...
    """
    try:
        memo[id(memo)].append(x)
    except KeyError:
        # aha, this is the first one :-)
        memo[id(memo)]=[x]

def _deepcopy_inst(x, memo):
    if hasattr(x, '__deepcopy__'):
        return x.__deepcopy__(memo)
    if hasattr(x, '__getinitargs__'):
        args = x.__getinitargs__()
        args = deepcopy(args, memo)
        y = x.__class__(*args)
    else:
        y = _EmptyClass()
        y.__class__ = x.__class__
    memo[id(x)] = y
    if hasattr(x, '__getstate__'):
        state = x.__getstate__()
    else:
        state = x.__dict__
    state = deepcopy(state, memo)
    if hasattr(y, '__setstate__'):
        y.__setstate__(state)
    else:
        y.__dict__.update(state)
    return y
d[types.InstanceType] = _deepcopy_inst

def _reconstruct(x, info, deep, memo=None):
    if isinstance(info, str):
        return x
    assert isinstance(info, tuple)
    if memo is None:
        memo = {}
    n = len(info)
    assert n in (2, 3, 4, 5)
    callable, args = info[:2]
    if n > 2:
        state = info[2]
    else:
        state = None
    if n > 3:
        listiter = info[3]
    else:
        listiter = None
    if n > 4:
        dictiter = info[4]
    else:
        dictiter = None
    if deep:
        args = deepcopy(args, memo)
    y = callable(*args)
    memo[id(x)] = y

    if state is not None:
        if deep:
            state = deepcopy(state, memo)
        if hasattr(y, '__setstate__'):
            y.__setstate__(state)
        else:
            if isinstance(state, tuple) and len(state) == 2:
                state, slotstate = state
            else:
                slotstate = None
            if state is not None:
                y.__dict__.update(state)
            if slotstate is not None:
                for key, value in slotstate.iteritems():
                    setattr(y, key, value)

    if listiter is not None:
        for item in listiter:
            if deep:
                item = deepcopy(item, memo)
            y.append(item)
    if dictiter is not None:
        for key, value in dictiter:
            if deep:
                key = deepcopy(key, memo)
                value = deepcopy(value, memo)
            y[key] = value
    return y

del d

del types

# Helper for instance creation without calling __init__
class _EmptyClass:
    pass

def _test():
    l = [None, 1, 2L, 3.14, 'xyzzy', (1, 2L), [3.14, 'abc'],
         {'abc': 'ABC'}, (), [], {}]
    l1 = copy(l)
    print l1==l
    l1 = map(copy, l)
    print l1==l
    l1 = deepcopy(l)
    print l1==l
    class C:
        def __init__(self, arg=None):
            self.a = 1
            self.arg = arg
            if __name__ == '__main__':
                import sys
                file = sys.argv[0]
            else:
                file = __file__
            self.fp = open(file)
            self.fp.close()
        def __getstate__(self):
            return {'a': self.a, 'arg': self.arg}
        def __setstate__(self, state):
            for key, value in state.iteritems():
                setattr(self, key, value)
        def __deepcopy__(self, memo=None):
            new = self.__class__(deepcopy(self.arg, memo))
            new.a = self.a
            return new
    c = C('argument sketch')
    l.append(c)
    l2 = copy(l)
    print l == l2
    print l
    print l2
    l2 = deepcopy(l)
    print l == l2
    print l
    print l2
    l.append({l[1]: l, 'xyz': l[2]})
    l3 = copy(l)
    import repr
    print map(repr.repr, l)
    print map(repr.repr, l1)
    print map(repr.repr, l2)
    print map(repr.repr, l3)
    l3 = deepcopy(l)
    import repr
    print map(repr.repr, l)
    print map(repr.repr, l1)
    print map(repr.repr, l2)
    print map(repr.repr, l3)
    class odict(dict):
        def __init__(self, d = {}):
            self.a = 99
            dict.__init__(self, d)
        def __setitem__(self, k, i):
            dict.__setitem__(self, k, i)
            self.a
    o = odict({"A" : "B"})
    x = deepcopy(o)
    print(o, x)

if __name__ == '__main__':
    _test()
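The copy module's docstring distinguishes shallow from deep copies and names two safeguards: a memo table for recursive objects, and class-level hooks that override what gets copied. The sketch below illustrates those three points through the public `copy.copy` and `copy.deepcopy` calls. The `Node` class and the `registry` dictionary are hypothetical names made up for this illustration, and the code is written so that it should run on Python 2.7 as well as Python 3.

```python
import copy


class Node(object):
    """Hypothetical class with a custom deep-copy hook."""

    def __init__(self, label, shared):
        self.label = label
        self.shared = shared      # data that copies should keep sharing
        self.children = []

    def __deepcopy__(self, memo):
        # Share the 'shared' object instead of copying it.
        new = Node(self.label, self.shared)
        # Register the copy before recursing so cycles resolve to it.
        memo[id(self)] = new
        new.children = copy.deepcopy(self.children, memo)
        return new


# Shallow versus deep: only the deep copy duplicates the inner list.
inner = [1, 2]
outer = [inner, "text"]
shallow = copy.copy(outer)
deep = copy.deepcopy(outer)

# Recursive object: the memo table stops the recursion.
loop = []
loop.append(loop)
loop_copy = copy.deepcopy(loop)

# Custom hook: the clone keeps sharing the registry, and the cycle
# through 'children' points at the clone rather than the original.
registry = {"name": "shared"}
root = Node("root", registry)
root.children.append(root)
clone = copy.deepcopy(root)
```

Here `shallow[0]` is the same list object as `inner`, while `deep[0]` is an equal but distinct list. Registering the new object in `memo` before copying the children is what lets a self-referencing structure be copied without looping.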
"""
csv.py - read/write/investigate CSV files
"""

import re
from functools import reduce
from _csv import Error, __version__, writer, reader, register_dialect, \
                 unregister_dialect, get_dialect, list_dialects, \
                 field_size_limit, \
                 QUOTE_MINIMAL, QUOTE_ALL, QUOTE_NONNUMERIC, QUOTE_NONE, \
                 __doc__
from _csv import Dialect as _Dialect

try:
    from cStringIO import StringIO
except ImportError:
    from StringIO import StringIO

__all__ = [ "QUOTE_MINIMAL", "QUOTE_ALL", "QUOTE_NONNUMERIC", "QUOTE_NONE",
            "Error", "Dialect", "__doc__", "excel", "excel_tab",
            "field_size_limit", "reader", "writer",
            "register_dialect", "get_dialect", "list_dialects", "Sniffer",
            "unregister_dialect", "__version__", "DictReader", "DictWriter" ]

class Dialect:
    """Describe an Excel dialect.

    This must be subclassed (see csv.excel). Valid attributes are:
    delimiter, quotechar, escapechar, doublequote, skipinitialspace,
    lineterminator, quoting.

    """
    _name = ""
    _valid = False
    # placeholders
    delimiter = None
    quotechar = None
    escapechar = None
    doublequote = None
    skipinitialspace = None
    lineterminator = None
    quoting = None

    def __init__(self):
        if self.__class__ != Dialect:
            self._valid = True
        self._validate()

    def _validate(self):
        try:
            _Dialect(self)
        except TypeError, e:
            # We do this for compatibility with py2.3
            raise Error(str(e))

class excel(Dialect):
    """Describe the usual properties of Excel-generated CSV files."""
    delimiter = ','
    quotechar = '"'
    doublequote = True
    skipinitialspace = False
    lineterminator = '\r\n'
    quoting = QUOTE_MINIMAL
register_dialect("excel", excel)

class excel_tab(excel):
    """Describe the usual properties of Excel-generated TAB-delimited files."""
    delimiter = '\t'
register_dialect("excel-tab", excel_tab)
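The `excel_tab` class above shows the pattern for a new dialect: subclass an existing one, override the attributes that differ, and register it under a name. As a rough sketch of the same pattern from user code, the example below defines a pipe-delimited dialect. The class name `pipes` and the registered name `"pipes"` are assumptions invented for this illustration; the `csv` calls used are the public ones listed in `__all__`.

```python
import csv


class pipes(csv.excel):
    """Hypothetical dialect: like excel, but fields are separated by '|'."""
    delimiter = '|'


# Register under an illustrative name so it can be selected by string.
csv.register_dialect("pipes", pipes)

# csv.reader accepts any iterable of lines, so a plain list is enough here.
lines = ["name|qty", "bolt|12", "nut|7"]
rows = list(csv.reader(lines, dialect="pipes"))

registered = "pipes" in csv.list_dialects()

# Remove the illustrative dialect again to leave the registry unchanged.
csv.unregister_dialect("pipes")
```

Subclassing `csv.excel` rather than `csv.Dialect` means the remaining attributes (quotechar, lineterminator and so on) are inherited, so only the delimiter has to be stated.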


class DictReader:
    def __init__(self, f, fieldnames=None, restkey=None, restval=None,
                 dialect="excel", *args, **kwds):
        self._fieldnames = fieldnames   # list of keys for the dict
        self.restkey = restkey          # key to catch long rows
        self.restval = restval          # default value for short rows
        self.reader = reader(f, dialect, *args, **kwds)
        self.dialect = dialect
        self.line_num = 0

    def __iter__(self):
        return self

    @property
    def fieldnames(self):
        if self._fieldnames is None:
            try:
                self._fieldnames = self.reader.next()
            except StopIteration:
                pass
        self.line_num = self.reader.line_num
        return self._fieldnames

    # Issue 20004: Because DictReader is a classic class, this setter is
    # ignored. At this point in 2.7's lifecycle, it is too late to change the
    # base class for fear of breaking working code. If you want to change
    # fieldnames without overwriting the getter, set _fieldnames directly.
    @fieldnames.setter
    def fieldnames(self, value):
        self._fieldnames = value

    def next(self):
        if self.line_num == 0:
            # Used only for its side effect.
            self.fieldnames
        row = self.reader.next()
        self.line_num = self.reader.line_num

        # unlike the basic reader, we prefer not to return blanks,
        # because we will typically wind up with a dict full of None
        # values
        while row == []:
            row = self.reader.next()
        d = dict(zip(self.fieldnames, row))
        lf = len(self.fieldnames)
        lr = len(row)
        if lf < lr:
            d[self.restkey] = row[lf:]
        elif lf > lr:
            for key in self.fieldnames[lr:]:
                d[key] = self.restval
        return d


class DictWriter:
    def __init__(self, f, fieldnames, restval="", extrasaction="raise",
                 dialect="excel", *args, **kwds):
        self.fieldnames = fieldnames    # list of keys for the dict
        self.restval = restval          # for writing short dicts
        if extrasaction.lower() not in ("raise", "ignore"):
            raise ValueError, \
                  ("extrasaction (%s) must be 'raise' or 'ignore'" %
                   extrasaction)
        self.extrasaction = extrasaction
        self.writer = writer(f, dialect, *args, **kwds)

    def writeheader(self):
        header = dict(zip(self.fieldnames, self.fieldnames))
        self.writerow(header)

    def _dict_to_list(self, rowdict):
        if self.extrasaction == "raise":
            wrong_fields = [k for k in rowdict if k not in self.fieldnames]
            if wrong_fields:
                raise ValueError("dict contains fields not in fieldnames: "
                                 + ", ".join([repr(x) for x in wrong_fields]))
        return [rowdict.get(key, self.restval) for key in self.fieldnames]

    def writerow(self, rowdict):
        return self.writer.writerow(self._dict_to_list(rowdict))

    def writerows(self, rowdicts):
        rows = []
        for rowdict in rowdicts:
            rows.append(self._dict_to_list(rowdict))
        return self.writer.writerows(rows)
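The two classes above handle rows that do not match the field list: `DictReader` collects surplus values under `restkey` and pads short rows with `restval`, while `DictWriter` pads missing keys with `restval` and either raises or ignores unknown keys depending on `extrasaction`. The sketch below exercises both through the public `csv` API. The `Sink` class and the sample data are hypothetical, introduced only so the example needs no real file and should run on Python 2.7 or Python 3.

```python
import csv


class Sink(object):
    """Hypothetical write-only file stand-in that records what is written."""

    def __init__(self):
        self.chunks = []

    def write(self, text):
        self.chunks.append(text)


# Reading: the second line is too long, the third is too short.
lines = ["a,b", "1,2,3", "4"]
reader = csv.DictReader(lines, restkey="extra", restval="n/a")
records = [dict(row) for row in reader]

# Writing with extrasaction="ignore": the unknown key 'c' is dropped and
# the missing key 'b' is filled with restval.
sink = Sink()
writer = csv.DictWriter(sink, ["a", "b"], restval="-", extrasaction="ignore")
writer.writeheader()
writer.writerow({"a": 1, "c": 3})
written = "".join(sink.chunks)

# Writing with the default extrasaction="raise": the same row is rejected.
strict = csv.DictWriter(Sink(), ["a", "b"])
try:
    strict.writerow({"a": 1, "c": 3})
    rejected = False
except ValueError:
    rejected = True
```

The surplus value from the long row ends up as a list under the `"extra"` key, which follows from the `d[self.restkey] = row[lf:]` assignment in `next()`.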

# Guard Sniffer's type checking against builds that exclude complex()
try:
    complex
except NameError:
    complex = float

class Sniffer:
    '''
    "Sniffs" the format of a CSV file (i.e. delimiter, quotechar)
    Returns a Dialect object.
    '''
    def __init__(self):
        # in case there is more than one possible delimiter
        self.preferred = [',', '\t', ';', ' ', ':']


    def sniff(self, sample, delimiters=None):
        """
        Returns a dialect (or None) corresponding to the sample
        """

        quotechar, doublequote, delimiter, skipinitialspace = \
                   self._guess_quote_and_delimiter(sample, delimiters)
        if not delimiter:
            delimiter, skipinitialspace = self._guess_delimiter(sample,
                                                                delimiters)

        if not delimiter:
            raise Error, "Could not determine delimiter"

        class dialect(Dialect):
            _name = "sniffed"
            lineterminator = '\r\n'
            quoting = QUOTE_MINIMAL
            # escapechar = ''

        dialect.doublequote = doublequote
        dialect.delimiter = delimiter
        # _csv.reader won't accept a quotechar of ''
        dialect.quotechar = quotechar or '"'
        dialect.skipinitialspace = skipinitialspace

        return dialect

    def _guess_quote_and_delimiter(self, data, delimiters):
        """
        Looks for text enclosed between two identical quotes
        (the probable quotechar) which are preceded and followed
        by the same character (the probable delimiter).
        For example:
                         ,'some text',
        The quote with the most wins, same with the delimiter.
        If there is no quotechar the delimiter can't be determined
        this way.
        """

        matches = []
        for restr in ('(?P<delim>[^\w\n"\'])(?P<space> ?)(?P<quote>["\']).*?(?P=quote)(?P=delim)', # ,".*?",
                      '(?:^|\n)(?P<quote>["\']).*?(?P=quote)(?P<delim>[^\w\n"\'])(?P<space> ?)', # ".*?",
                      '(?P<delim>[^\w\n"\'])(?P<space> ?)(?P<quote>["\']).*?(?P=quote)(?:$|\n)', # ,".*?"
                      '(?:^|\n)(?P<quote>["\']).*?(?P=quote)(?:$|\n)'): # ".*?" (no delim, no space)
            regexp = re.compile(restr, re.DOTALL | re.MULTILINE)
            matches = regexp.findall(data)
            if matches:
                break

        if not matches:
            # (quotechar, doublequote, delimiter, skipinitialspace)
            return ('', False, None, 0)
        quotes = {}
        delims = {}
        spaces = 0
        for m in matches:
            n = regexp.groupindex['quote'] - 1
            key = m[n]
            if key:
                quotes[key] = quotes.get(key, 0) + 1
            try:
                n = regexp.groupindex['delim'] - 1
                key = m[n]
            except KeyError:
                continue
            if key and (delimiters is None or key in delimiters):
                delims[key] = delims.get(key, 0) + 1
            try:
                n = regexp.groupindex['space'] - 1
            except KeyError:
                continue
            if m[n]:
                spaces += 1

        quotechar = reduce(lambda a, b, quotes = quotes:
                           (quotes[a] > quotes[b]) and a or b, quotes.keys())

        if delims:
            delim = reduce(lambda a, b, delims = delims:
                           (delims[a] > delims[b]) and a or b, delims.keys())
            skipinitialspace = delims[delim] == spaces
            if delim == '\n': # most likely a file with a single column
                delim = ''
        else:
            # there is *no* delimiter, it's a single column of quoted data
            delim = ''
            skipinitialspace = 0

        # if we see an extra quote between delimiters, we've got a
        # double quoted format
        dq_regexp = re.compile(
            r"((%(delim)s)|^)\W*%(quote)s[^%(delim)s\n]*%(quote)s[^%(delim)s\n]*%(quote)s\W*((%(delim)s)|$)" % \
            {'delim':re.escape(delim), 'quote':quotechar}, re.MULTILINE)

        if dq_regexp.search(data):
            doublequote = True
        else:
            doublequote = False

        return (quotechar, doublequote, delim, skipinitialspace)
    def _guess_delimiter(self, data, delimiters):
        """
        The delimiter /should/ occur the same number of times on
        each row. However, due to malformed data, it may not. We don't want
        an all or nothing approach, so we allow for small variations in this
        number.
          1) build a table of the frequency of each character on every line.
          2) build a table of frequencies of this frequency (meta-frequency?),
             e.g.  'x occurred 5 times in 10 rows, 6 times in 1000 rows,
             7 times in 2 rows'
          3) use the mode of the meta-frequency to determine the /expected/
             frequency for that character
          4) find out how often the character actually meets that goal
          5) the character that best meets its goal is the delimiter
        For performance reasons, the data is evaluated in chunks, so it can
        try and evaluate the smallest portion of the data possible, evaluating
        additional chunks as necessary.
        """
        data = filter(None, data.split('\n'))
        ascii = [chr(c) for c in range(127)] # 7-bit ASCII
        # build frequency tables
        chunkLength = min(10, len(data))
        iteration = 0
        charFrequency = {}
        modes = {}
        delims = {}
        start, end = 0, min(chunkLength, len(data))
        while start < len(data):
            iteration += 1
            for line in data[start:end]:
                for char in ascii:
                    metaFrequency = charFrequency.get(char, {})
                    # must count even if frequency is 0
                    freq = line.count(char)
                    # value is the mode
                    metaFrequency[freq] = metaFrequency.get(freq, 0) + 1
                    charFrequency[char] = metaFrequency
            for char in charFrequency.keys():
                items = charFrequency[char].items()
                if len(items) == 1 and items[0][0] == 0:
                    continue
                # get the mode of the frequencies
                if len(items) > 1:
                    modes[char] = reduce(lambda a, b: a[1] > b[1] and a or b,
                                         items)
                    # adjust the mode - subtract the sum of all
                    # other frequencies
                    items.remove(modes[char])
                    modes[char] = (modes[char][0], modes[char][1]
                                   - reduce(lambda a, b: (0, a[1] + b[1]),
                                            items)[1])
                else:
                    modes[char] = items[0]
            # build a list of possible delimiters
            modeList = modes.items()
            total = float(chunkLength * iteration)
            # (rows of consistent data) / (number of rows) = 100%
            consistency = 1.0
            # minimum consistency threshold
            threshold = 0.9
            while len(delims) == 0 and consistency >= threshold:
                for k, v in modeList:
                    if v[0] > 0 and v[1] > 0:
                        if ((v[1]/total) >= consistency and
                            (delimiters is None or k in delimiters)):
                            delims[k] = v
                consistency -= 0.01
            if len(delims) == 1:
                delim = delims.keys()[0]
                skipinitialspace = (data[0].count(delim) ==
                                    data[0].count("%c " % delim))
                return (delim, skipinitialspace)
            # analyze another chunkLength lines
            start = end
            end += chunkLength
        if not delims:
            return ('', 0)
        # if there's more than one, fall back to a 'preferred' list
        if len(delims) > 1:
            for d in self.preferred:
                if d in delims.keys():
                    skipinitialspace = (data[0].count(d) ==
                                        data[0].count("%c " % d))
                    return (d, skipinitialspace)
        # nothing else indicates a preference, pick the character that
        # dominates(?)
        items = [(v,k) for (k,v) in delims.items()]
        items.sort()
        delim = items[-1][1]
        skipinitialspace = (data[0].count(delim) ==
                            data[0].count("%c " % delim))
        return (delim, skipinitialspace)
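# Illustrative sketch (not part of the module): the meta-frequency idea from
# the _guess_delimiter docstring, reduced to its core. The function name, the
# candidate list and the plain "share of rows matching the mode" score are
# assumptions made for this example; the real method also subtracts the
# off-mode rows, works in chunks and falls back to a preferred list.

```python
from collections import Counter


def guess_delimiter_sketch(sample, candidates=",;\t|", threshold=0.9):
    """Return the candidate whose per-row count is most consistent, or ''."""
    rows = [line for line in sample.split("\n") if line]
    if not rows:
        return ""
    best, best_score = "", 0.0
    for char in candidates:
        # Meta-frequency: how many rows contain `char` exactly n times.
        meta = Counter(row.count(char) for row in rows)
        expected, hits = meta.most_common(1)[0]
        if expected == 0:
            continue  # the character is usually absent, so not a delimiter
        # Share of rows that meet the expected (modal) count.
        score = hits / float(len(rows))
        if score >= threshold and score > best_score:
            best, best_score = char, score
    return best
```

# With three rows that each contain two semicolons, the sketch picks ';'.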
    def has_header(self, sample):
        # Creates a dictionary of types of data in each column. If any
        # column is of a single type (say, integers), *except* for the first
        # row, then the first row is presumed to be labels. If the type
        # can't be determined, it is assumed to be a string in which case
        # the length of the string is the determining factor: if all of the
        # rows except for the first are the same length, it's a header.
        # Finally, a 'vote' is taken at the end for each column, adding or
        # subtracting from the likelihood of the first row being a header.
        rdr = reader(StringIO(sample), self.sniff(sample))
        header = rdr.next() # assume first row is header
        columns = len(header)
        columnTypes = {}
        for i in range(columns): columnTypes[i] = None
        checked = 0
        for row in rdr:
            # arbitrary number of rows to check, to keep it sane
            if checked > 20:
                break
            checked += 1
            if len(row) != columns:
                continue # skip rows that have irregular number of columns
            for col in columnTypes.keys():
                for thisType in [int, long, float, complex]:
                    try:
                        thisType(row[col])
                        break
                    except (ValueError, OverflowError):
                        pass
                else:
                    # fallback to length of string
                    thisType = len(row[col])
                # treat longs as ints
                if thisType == long:
                    thisType = int
                if thisType != columnTypes[col]:
                    if columnTypes[col] is None: # add new column type
                        columnTypes[col] = thisType
                    else:
                        # type is inconsistent, remove column from
                        # consideration
                        del columnTypes[col]
        # finally, compare results against first row and "vote"
        # on whether it's a header
        hasHeader = 0
        for col, colType in columnTypes.items():
            if type(colType) == type(0): # it's a length
                if len(header[col]) != colType:
                    hasHeader += 1
                else:
                    hasHeader -= 1
            else: # attempt typecast
                try:
                    colType(header[col])
                except (ValueError, TypeError):
                    hasHeader += 1
                else:
                    hasHeader -= 1
        return hasHeader > 0
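# Illustrative usage (not part of the module): how the sniffing and the
# header vote above are normally reached through the public csv.Sniffer
# class. The sample text is made up for this example. Its second column is
# numeric in every data row but not in the first row, which is the situation
# the vote treats as evidence of a header.

```python
import csv

# Hypothetical sample: one label row followed by three data rows.
sample = "name,age\nalice,30\nbob,25\ncarol,41\n"

sniffer = csv.Sniffer()
dialect = sniffer.sniff(sample)            # guesses delimiter and quoting
delimiter = dialect.delimiter
header_found = sniffer.has_header(sample)  # runs the per-column vote
```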
"""
dyld emulation
"""
import os
from framework import framework_info
from dylib import dylib_info
from itertools import *
__all__ = [
'dyld_find', 'framework_find',
'framework_info', 'dylib_info',
]
# These are the defaults as per man dyld(1)
#
DEFAULT_FRAMEWORK_FALLBACK = [
os.path.expanduser("~/Library/Frameworks"),
"/Library/Frameworks",
"/Network/Library/Frameworks",
"/System/Library/Frameworks",
]
DEFAULT_LIBRARY_FALLBACK = [
os.path.expanduser("~/lib"),
"/usr/local/lib",
"/lib",
"/usr/lib",
]
def ensure_utf8(s):
    """Not all of PyObjC and Python understand unicode paths very well yet"""
    if isinstance(s, unicode):
        return s.encode('utf8')
    return s
def dyld_env(env, var):
    if env is None:
        env = os.environ
    rval = env.get(var)
    if rval is None:
        return []
    return rval.split(':')
def dyld_image_suffix(env=None):
    if env is None:
        env = os.environ
    return env.get('DYLD_IMAGE_SUFFIX')
def dyld_framework_path(env=None):
    return dyld_env(env, 'DYLD_FRAMEWORK_PATH')
def dyld_library_path(env=None):
    return dyld_env(env, 'DYLD_LIBRARY_PATH')
def dyld_fallback_framework_path(env=None):
    return dyld_env(env, 'DYLD_FALLBACK_FRAMEWORK_PATH')
def dyld_fallback_library_path(env=None):
    return dyld_env(env, 'DYLD_FALLBACK_LIBRARY_PATH')
def dyld_image_suffix_search(iterator, env=None):
    """For a potential path iterator, add DYLD_IMAGE_SUFFIX semantics"""
    suffix = dyld_image_suffix(env)
    if suffix is None:
        return iterator
    def _inject(iterator=iterator, suffix=suffix):
        for path in iterator:
            if path.endswith('.dylib'):
                yield path[:-len('.dylib')] + suffix + '.dylib'
            else:
                yield path + suffix
            yield path
    return _inject()
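# Illustrative sketch (not part of the module): the DYLD_IMAGE_SUFFIX
# expansion done by _inject above, as a standalone generator. The function
# name and the example paths are hypothetical. For each candidate path, the
# suffixed variant is tried first and the original path second.

```python
def with_image_suffix(paths, suffix):
    """Yield each path twice: first with the suffix applied, then as given."""
    for path in paths:
        if path.endswith(".dylib"):
            # Insert the suffix before the extension: libfoo_debug.dylib
            yield path[:-len(".dylib")] + suffix + ".dylib"
        else:
            # Framework binaries have no extension: Foo_debug
            yield path + suffix
        yield path


candidates = list(with_image_suffix(
    ["/usr/lib/libfoo.dylib", "/Library/Frameworks/Foo.framework/Foo"],
    "_debug",
))
```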
def dyld_override_search(name, env=None):
    # If DYLD_FRAMEWORK_PATH is set and this dylib_name is a
    # framework name, use the first file that exists in the framework
    # path if any. If there is none go on to search the DYLD_LIBRARY_PATH
    # if any.
    framework = framework_info(name)
    if framework is not None:
        for path in dyld_framework_path(env):
            yield os.path.join(path, framework['name'])
    # If DYLD_LIBRARY_PATH is set then use the first file that exists
    # in the path. If none use the original name.
    for path in dyld_library_path(env):
        yield os.path.join(path, os.path.basename(name))
def dyld_executable_path_search(name, executable_path=None):
    # If we haven't done any searching and found a library and the
    # dylib_name starts with "@executable_path/" then construct the
    # library name.
    if name.startswith('@executable_path/') and executable_path is not None:
        yield os.path.join(executable_path, name[len('@executable_path/'):])
def dyld_default_search(name, env=None):
    yield name
    framework = framework_info(name)
    if framework is not None:
        fallback_framework_path = dyld_fallback_framework_path(env)
        for path in fallback_framework_path:
            yield os.path.join(path, framework['name'])
    fallback_library_path = dyld_fallback_library_path(env)
    for path in fallback_library_path:
        yield os.path.join(path, os.path.basename(name))
    if framework is not None and not fallback_framework_path:
        for path in DEFAULT_FRAMEWORK_FALLBACK:
            yield os.path.join(path, framework['name'])
    if not fallback_library_path:
        for path in DEFAULT_LIBRARY_FALLBACK:
            yield os.path.join(path, os.path.basename(name))
def dyld_find(name, executable_path=None, env=None):
    """
    Find a library or framework using dyld semantics
    """
    name = ensure_utf8(name)
    executable_path = ensure_utf8(executable_path)
    for path in dyld_image_suffix_search(chain(
                dyld_override_search(name, env),
                dyld_executable_path_search(name, executable_path),
                dyld_default_search(name, env),
            ), env):
        if os.path.isfile(path):
            return path
    raise ValueError("dylib %s could not be found" % (name,))
def framework_find(fn, executable_path=None, env=None):
"""
Find a framework using dyld semantics in a very loose manner.
Will take input such as:
Python
Python.framework
Python.framework/Versions/Current
"""
try:
return dyld_find(fn, executable_path=executable_path, env=env)
except ValueError, e:
pass
fmwk_index = fn.rfind('.framework')
if fmwk_index == -1:
fmwk_index = len(fn)
fn += '.framework'
fn = os.path.join(fn, os.path.basename(fn[:fmwk_index]))
try:
return dyld_find(fn, executable_path=executable_path, env=env)
except ValueError:
raise e
def test_dyld_find():
env = {}
assert dyld_find('libSystem.dylib') == '/usr/lib/libSystem.dylib'
assert dyld_find('System.framework/System') == '/System/Library/Frameworks/System.framework/System'
if __name__ == '__main__':
test_dyld_find()
"""
Generic dylib path manipulation
"""
import re
__all__ = ['dylib_info']
DYLIB_RE = re.compile(r"""(?x)
(?P<location>^.*)(?:^|/)
(?P<name>
(?P<shortname>\w+?)
(?:\.(?P<version>[^._]+))?
(?:_(?P<suffix>[^._]+))?
\.dylib$
)
""")
def dylib_info(filename):
    """
    A dylib name can take one of the following four forms:
        Location/Name.SomeVersion_Suffix.dylib
        Location/Name.SomeVersion.dylib
        Location/Name_Suffix.dylib
        Location/Name.dylib
    returns None if not found or a mapping equivalent to:
        dict(
            location='Location',
            name='Name.SomeVersion_Suffix.dylib',
            shortname='Name',
            version='SomeVersion',
            suffix='Suffix',
        )
    Note that SomeVersion and Suffix are optional and may be None
    if not present.
    """
    is_dylib = DYLIB_RE.match(filename)
    if not is_dylib:
        return None
    return is_dylib.groupdict()
def test_dylib_info():
def d(location=None, name=None, shortname=None, version=None, suffix=None):
return dict(
location=location,
name=name,
shortname=shortname,
version=version,
suffix=suffix
)
assert dylib_info('completely/invalid') is None
assert dylib_info('completely/invalide_debug') is None
assert dylib_info('P/Foo.dylib') == d('P', 'Foo.dylib', 'Foo')
assert dylib_info('P/Foo_debug.dylib') == d('P', 'Foo_debug.dylib', 'Foo', suffix='debug')
assert dylib_info('P/Foo.A.dylib') == d('P', 'Foo.A.dylib', 'Foo', 'A')
assert dylib_info('P/Foo_debug.A.dylib') == d('P', 'Foo_debug.A.dylib', 'Foo_debug', 'A')
assert dylib_info('P/Foo.A_debug.dylib') == d('P', 'Foo.A_debug.dylib', 'Foo', 'A', 'debug')
if __name__ == '__main__':
test_dylib_info()
"""
Generic framework path manipulation
"""
import re
__all__ = ['framework_info']
STRICT_FRAMEWORK_RE = re.compile(r"""(?x)
(?P<location>^.*)(?:^|/)
(?P<name>
(?P<shortname>\w+).framework/
(?:Versions/(?P<version>[^/]+)/)?
(?P=shortname)
(?:_(?P<suffix>[^_]+))?
)$
""")
def framework_info(filename):
"""
A framework name can take one of the following four forms:
Location/Name.framework/Versions/SomeVersion/Name_Suffix
Location/Name.framework/Versions/SomeVersion/Name
Location/Name.framework/Name_Suffix
Location/Name.framework/Name
returns None if not found, or a mapping equivalent to:
dict(
location='Location',
name='Name.framework/Versions/SomeVersion/Name_Suffix',
shortname='Name',
version='SomeVersion',
suffix='Suffix',
)
Note that SomeVersion and Suffix are optional and may be None
if not present
"""
is_framework = STRICT_FRAMEWORK_RE.match(filename)
if not is_framework:
return None
return is_framework.groupdict()
def test_framework_info():
def d(location=None, name=None, shortname=None, version=None, suffix=None):
return dict(
location=location,
name=name,
shortname=shortname,
version=version,
suffix=suffix
)
assert framework_info('completely/invalid') is None
assert framework_info('completely/invalid/_debug') is None
assert framework_info('P/F.framework') is None
assert framework_info('P/F.framework/_debug') is None
assert framework_info('P/F.framework/F') == d('P', 'F.framework/F', 'F')
assert framework_info('P/F.framework/F_debug') == d('P', 'F.framework/F_debug', 'F', suffix='debug')
assert framework_info('P/F.framework/Versions') is None
assert framework_info('P/F.framework/Versions/A') is None
assert framework_info('P/F.framework/Versions/A/F') == d('P', 'F.framework/Versions/A/F', 'F', 'A')
assert framework_info('P/F.framework/Versions/A/F_debug') == d('P', 'F.framework/Versions/A/F_debug', 'F', 'A', 'debug')
if __name__ == '__main__':
test_framework_info()
"""
Enough Mach-O to make your head spin.
See the relevant header files in /usr/include/mach-o
And also Apple's documentation.
"""
__version__ = '1.0'
import os
import subprocess
import sys
# find_library(name) returns the pathname of a library, or None.
if os.name == "nt":
    def _get_build_version():
        """Return the version of MSVC that was used to build Python.
        For Python 2.3 and up, the version number is included in
        sys.version.  For earlier versions, assume the compiler is MSVC 6.
        """
        # This function was copied from Lib/distutils/msvccompiler.py
        prefix = "MSC v."
        i = sys.version.find(prefix)
        if i == -1:
            return 6
        i = i + len(prefix)
        s, rest = sys.version[i:].split(" ", 1)
        majorVersion = int(s[:-2]) - 6
        minorVersion = int(s[2:3]) / 10.0
        # I don't think paths are affected by minor version in version 6
        if majorVersion == 6:
            minorVersion = 0
        if majorVersion >= 6:
            return majorVersion + minorVersion
        # else we don't know what version of the compiler this is
        return None
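# Illustrative sketch (not part of the module): the same "MSC v.NNNN"
# arithmetic as _get_build_version, but taking the version string as an
# argument so it can be tried on any platform. The function name and the
# sample strings are made up for this example.

```python
def msvc_version_from(version_string):
    """Map a sys.version-style string to an MSVC version number."""
    prefix = "MSC v."
    i = version_string.find(prefix)
    if i == -1:
        return 6  # no marker: assume MSVC 6, as the original does
    # The token after the prefix is the compiler number, e.g. "1500".
    s, _rest = version_string[i + len(prefix):].split(" ", 1)
    major = int(s[:-2]) - 6       # "15" -> 9
    minor = int(s[2:3]) / 10.0    # third digit -> tenths
    if major == 6:
        minor = 0
    if major >= 6:
        return major + minor
    return None  # unknown compiler version


vs2008 = msvc_version_from("2.7.11 (default) [MSC v.1500 32 bit (Intel)]")
vs2003 = msvc_version_from("2.4.4 (#71) [MSC v.1310 32 bit (Intel)]")
```

# find_msvcrt then turns a result such as 9.0 into the name 'msvcr90.dll'.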
def find_msvcrt():
"""Return the name of the VC runtime dll"""
version = _get_build_version()
if version is None:
# better be safe than sorry
return None
if version <= 6:
clibname = 'msvcrt'
else:
clibname = 'msvcr%d' % (version * 10)
# If Python was built in debug mode
import imp
if imp.get_suffixes()[0][0] == '_d.pyd':
clibname += 'd'
return clibname+'.dll'
def find_library(name):
if name in ('c', 'm'):
return find_msvcrt()
# See MSDN for the REAL search order.
for directory in os.environ['PATH'].split(os.pathsep):
fname = os.path.join(directory, name)
if os.path.isfile(fname):
return fname
if fname.lower().endswith(".dll"):
continue
fname = fname + ".dll"
if os.path.isfile(fname):
return fname
return None
if os.name == "ce":
# search path according to MSDN:
# - absolute path specified by filename
# - The .exe launch directory
# - the Windows directory
# - ROM dll files (where are they?)
# - OEM specified search path: HKLM\Loader\SystemPath
def find_library(name):
return name
if os.name == "posix" and sys.platform == "darwin":
from ctypes.macholib.dyld import dyld_find as _dyld_find
def find_library(name):
possible = ['lib%s.dylib' % name,
'%s.dylib' % name,
'%s.framework/%s' % (name, name)]
for name in possible:
try:
return _dyld_find(name)
except ValueError:
continue
return None
elif os.name == "posix":
# Andreas Degert's find functions, using gcc, /sbin/ldconfig, objdump
import re, tempfile, errno
def _findLib_gcc(name):
# Run GCC's linker with the -t (aka --trace) option and examine the
# library name it prints out. The GCC command will fail because we
# haven't supplied a proper program with main(), but that does not
# matter.
expr = r'[^\(\)\s]*lib%s\.[^\(\)\s]*' % re.escape(name)
cmd = 'if type gcc >/dev/null 2>&1; then CC=gcc; elif type cc >/dev/null 2>&1; then CC=cc;else exit; fi;' \
'LANG=C LC_ALL=C $CC -Wl,-t -o "$2" 2>&1 -l"$1"'
temp = tempfile.NamedTemporaryFile()
try:
proc = subprocess.Popen((cmd, '_findLib_gcc', name, temp.name),
shell=True,
stdout=subprocess.PIPE)
[trace, _] = proc.communicate()
finally:
try:
temp.close()
except OSError, e:
# ENOENT is raised if the file was already removed, which is
# the normal behaviour of GCC if linking fails
if e.errno != errno.ENOENT:
raise
res = re.search(expr, trace)
if not res:
return None
return res.group(0)
if sys.platform == "sunos5":
# use /usr/ccs/bin/dump on solaris
def _get_soname(f):
if not f:
return None
null = open(os.devnull, "wb")
try:
with null:
proc = subprocess.Popen(("/usr/ccs/bin/dump", "-Lpv", f),
stdout=subprocess.PIPE,
stderr=null)
except OSError: # E.g. command not found
return None
[data, _] = proc.communicate()
res = re.search(br'\[.*\]\sSONAME\s+([^\s]+)', data)
if not res:
return None
return res.group(1)
else:
def _get_soname(f):
# assuming GNU binutils / ELF
if not f:
return None
cmd = 'if ! type objdump >/dev/null 2>&1; then exit; fi;' \
'objdump -p -j .dynamic 2>/dev/null "$1"'
proc = subprocess.Popen((cmd, '_get_soname', f), shell=True,
stdout=subprocess.PIPE)
[dump, _] = proc.communicate()
res = re.search(br'\sSONAME\s+([^\s]+)', dump)
if not res:
return None
return res.group(1)
if (sys.platform.startswith("freebsd")
or sys.platform.startswith("openbsd")
or sys.platform.startswith("dragonfly")):
        def _num_version(libname):
            # "libxyz.so.MAJOR.MINOR" => [ MAJOR, MINOR ]
            parts = libname.split(b".")
            nums = []
            try:
                while parts:
                    nums.insert(0, int(parts.pop()))
            except ValueError:
                pass
            return nums or [sys.maxint]
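# Illustrative sketch (not part of the module): the trailing-number parse
# used as a sort key by the BSD find_library. Two assumptions for this
# example: it takes text rather than byte strings, and it uses sys.maxsize
# in place of sys.maxint so that it also runs on Python 3.

```python
import sys


def num_version(libname):
    """'libxyz.so.MAJOR.MINOR' -> [MAJOR, MINOR]; unversioned sorts last."""
    parts = libname.split(".")
    nums = []
    try:
        # Consume numeric components from the right until a non-number.
        while parts:
            nums.insert(0, int(parts.pop()))
    except ValueError:
        pass
    return nums or [sys.maxsize]


names = ["libz.so.1.9", "libz.so.1.10", "libz.so.1.2"]
newest = sorted(names, key=num_version)[-1]
```

# Comparing lists of integers ranks 1.10 above 1.9, which a plain string
# sort would get wrong.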
def find_library(name):
ename = re.escape(name)
expr = r':-l%s\.\S+ => \S*/(lib%s\.\S+)' % (ename, ename)
null = open(os.devnull, 'wb')
try:
with null:
proc = subprocess.Popen(('/sbin/ldconfig', '-r'),
stdout=subprocess.PIPE,
stderr=null)
except OSError: # E.g. command not found
data = b''
else:
[data, _] = proc.communicate()
res = re.findall(expr, data)
if not res:
return _get_soname(_findLib_gcc(name))
res.sort(key=_num_version)
return res[-1]
elif sys.platform == "sunos5":
def _findLib_crle(name, is64):
if not os.path.exists('/usr/bin/crle'):
return None
env = dict(os.environ)
env['LC_ALL'] = 'C'
if is64:
args = ('/usr/bin/crle', '-64')
else:
args = ('/usr/bin/crle',)
paths = None
null = open(os.devnull, 'wb')
try:
with null:
proc = subprocess.Popen(args,
stdout=subprocess.PIPE,
stderr=null,
env=env)
except OSError: # E.g. bad executable
return None
try:
for line in proc.stdout:
line = line.strip()
if line.startswith(b'Default Library Path (ELF):'):
paths = line.split()[4]
finally:
proc.stdout.close()
proc.wait()
if not paths:
return None
for dir in paths.split(":"):
libfile = os.path.join(dir, "lib%s.so" % name)
if os.path.exists(libfile):
return libfile
return None
def find_library(name, is64 = False):
return _get_soname(_findLib_crle(name, is64) or _findLib_gcc(name))
else:
def _findSoname_ldconfig(name):
import struct
if struct.calcsize('l') == 4:
machine = os.uname()[4] + '-32'
else:
machine = os.uname()[4] + '-64'
mach_map = {
'x86_64-64': 'libc6,x86-64',
'ppc64-64': 'libc6,64bit',
'sparc64-64': 'libc6,64bit',
's390x-64': 'libc6,64bit',
'ia64-64': 'libc6,IA-64',
}
abi_type = mach_map.get(machine, 'libc6')
# XXX assuming GLIBC's ldconfig (with option -p)
expr = r'\s+(lib%s\.[^\s]+)\s+\(%s' % (re.escape(name), abi_type)
env = dict(os.environ)
env['LC_ALL'] = 'C'
env['LANG'] = 'C'
null = open(os.devnull, 'wb')
try:
with null:
p = subprocess.Popen(['/sbin/ldconfig', '-p'],
stderr=null,
stdout=subprocess.PIPE,
env=env)
except OSError: # E.g. command not found
return None
[data, _] = p.communicate()
res = re.search(expr, data)
if not res:
return None
return res.group(1)
def find_library(name):
return _findSoname_ldconfig(name) or _get_soname(_findLib_gcc(name))
################################################################
# test code
def test():
from ctypes import cdll
if os.name == "nt":
print cdll.msvcrt
print cdll.load("msvcrt")
print find_library("msvcrt")
if os.name == "posix":
# find and load_version
print find_library("m")
print find_library("c")
print find_library("bz2")
# getattr
## print cdll.m
## print cdll.bz2
# load
if sys.platform == "darwin":
print cdll.LoadLibrary("libm.dylib")
print cdll.LoadLibrary("libcrypto.dylib")
print cdll.LoadLibrary("libSystem.dylib")
print cdll.LoadLibrary("System.framework/System")
else:
print cdll.LoadLibrary("libm.so")
print cdll.LoadLibrary("libcrypt.so")
print find_library("crypt")
if __name__ == "__main__":
test()
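# Illustrative usage (not part of the module): calling the public
# ctypes.util.find_library entry point that the platform branches above
# define. The helper name and the library names are made up for this
# example, and the result for a real library depends on the platform and on
# what is installed.

```python
import ctypes.util


def describe_library(name):
    """Return a one-line report of what find_library resolves `name` to."""
    path = ctypes.util.find_library(name)
    if path is None:
        return "%s: not found" % name
    return "%s: %s" % (name, path)


missing = describe_library("surely_missing_library_name_xyz")
```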
# The most useful windows datatypes
from ctypes import *
BYTE = c_byte
WORD = c_ushort
DWORD = c_ulong
WCHAR = c_wchar
UINT = c_uint
INT = c_int
DOUBLE = c_double
FLOAT = c_float
BOOLEAN = BYTE
BOOL = c_long
from ctypes import _SimpleCData
class VARIANT_BOOL(_SimpleCData):
_type_ = "v"
def __repr__(self):
return "%s(%r)" % (self.__class__.__name__, self.value)
ULONG = c_ulong
LONG = c_long
USHORT = c_ushort
SHORT = c_short
# in the windows header files, these are structures.
_LARGE_INTEGER = LARGE_INTEGER = c_longlong
_ULARGE_INTEGER = ULARGE_INTEGER = c_ulonglong
LPCOLESTR = LPOLESTR = OLESTR = c_wchar_p
LPCWSTR = LPWSTR = c_wchar_p
LPCSTR = LPSTR = c_char_p
LPCVOID = LPVOID = c_void_p
# WPARAM is defined as UINT_PTR (unsigned type)
# LPARAM is defined as LONG_PTR (signed type)
if sizeof(c_long) == sizeof(c_void_p):
WPARAM = c_ulong
LPARAM = c_long
elif sizeof(c_longlong) == sizeof(c_void_p):
WPARAM = c_ulonglong
LPARAM = c_longlong
ATOM = WORD
LANGID = WORD
COLORREF = DWORD
LGRPID = DWORD
LCTYPE = DWORD
LCID = DWORD
################################################################
# HANDLE types
HANDLE = c_void_p # in the header files: void *
HACCEL = HANDLE
HBITMAP = HANDLE
HBRUSH = HANDLE
HCOLORSPACE = HANDLE
HDC = HANDLE
HDESK = HANDLE
HDWP = HANDLE
HENHMETAFILE = HANDLE
HFONT = HANDLE
HGDIOBJ = HANDLE
HGLOBAL = HANDLE
HHOOK = HANDLE
HICON = HANDLE
HINSTANCE = HANDLE
HKEY = HANDLE
HKL = HANDLE
HLOCAL = HANDLE
HMENU = HANDLE
HMETAFILE = HANDLE
HMODULE = HANDLE
HMONITOR = HANDLE
HPALETTE = HANDLE
HPEN = HANDLE
HRGN = HANDLE
HRSRC = HANDLE
HSTR = HANDLE
HTASK = HANDLE
HWINSTA = HANDLE
HWND = HANDLE
SC_HANDLE = HANDLE
SERVICE_STATUS_HANDLE = HANDLE
################################################################
# Some important structure definitions
class RECT(Structure):
_fields_ = [("left", c_long),
("top", c_long),
("right", c_long),
("bottom", c_long)]
tagRECT = _RECTL = RECTL = RECT
class _SMALL_RECT(Structure):
_fields_ = [('Left', c_short),
('Top', c_short),
('Right', c_short),
('Bottom', c_short)]
SMALL_RECT = _SMALL_RECT
class _COORD(Structure):
_fields_ = [('X', c_short),
('Y', c_short)]
class POINT(Structure):
_fields_ = [("x", c_long),
("y", c_long)]
tagPOINT = _POINTL = POINTL = POINT
class SIZE(Structure):
_fields_ = [("cx", c_long),
("cy", c_long)]
tagSIZE = SIZEL = SIZE
def RGB(red, green, blue):
return red + (green << 8) + (blue << 16)
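# Illustrative sketch (not part of the module): the COLORREF layout produced
# by RGB above keeps red in the low byte, then green, then blue. The inverse
# helper split_colorref is hypothetical and added only for this example.

```python
def rgb(red, green, blue):
    """Pack three 0-255 channels into a COLORREF-style integer (0x00BBGGRR)."""
    return red + (green << 8) + (blue << 16)


def split_colorref(value):
    """Unpack a COLORREF-style integer back into (red, green, blue)."""
    return (value & 0xFF, (value >> 8) & 0xFF, (value >> 16) & 0xFF)


packed = rgb(0x12, 0x34, 0x56)
```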
class FILETIME(Structure):
_fields_ = [("dwLowDateTime", DWORD),
("dwHighDateTime", DWORD)]
_FILETIME = FILETIME
class MSG(Structure):
_fields_ = [("hWnd", HWND),
("message", c_uint),
("wParam", WPARAM),
("lParam", LPARAM),
("time", DWORD),
("pt", POINT)]
tagMSG = MSG
MAX_PATH = 260
class WIN32_FIND_DATAA(Structure):
_fields_ = [("dwFileAttributes", DWORD),
("ftCreationTime", FILETIME),
("ftLastAccessTime", FILETIME),
("ftLastWriteTime", FILETIME),
("nFileSizeHigh", DWORD),
("nFileSizeLow", DWORD),
("dwReserved0", DWORD),
("dwReserved1", DWORD),
("cFileName", c_char * MAX_PATH),
("cAlternateFileName", c_char * 14)]
class WIN32_FIND_DATAW(Structure):
_fields_ = [("dwFileAttributes", DWORD),
("ftCreationTime", FILETIME),
("ftLastAccessTime", FILETIME),
("ftLastWriteTime", FILETIME),
("nFileSizeHigh", DWORD),
("nFileSizeLow", DWORD),
("dwReserved0", DWORD),
("dwReserved1", DWORD),
("cFileName", c_wchar * MAX_PATH),
("cAlternateFileName", c_wchar * 14)]
__all__ = ['ATOM', 'BOOL', 'BOOLEAN', 'BYTE', 'COLORREF', 'DOUBLE', 'DWORD',
'FILETIME', 'FLOAT', 'HACCEL', 'HANDLE', 'HBITMAP', 'HBRUSH',
'HCOLORSPACE', 'HDC', 'HDESK', 'HDWP', 'HENHMETAFILE', 'HFONT',
'HGDIOBJ', 'HGLOBAL', 'HHOOK', 'HICON', 'HINSTANCE', 'HKEY',
'HKL', 'HLOCAL', 'HMENU', 'HMETAFILE', 'HMODULE', 'HMONITOR',
'HPALETTE', 'HPEN', 'HRGN', 'HRSRC', 'HSTR', 'HTASK', 'HWINSTA',
'HWND', 'INT', 'LANGID', 'LARGE_INTEGER', 'LCID', 'LCTYPE',
'LGRPID', 'LONG', 'LPARAM', 'LPCOLESTR', 'LPCSTR', 'LPCVOID',
'LPCWSTR', 'LPOLESTR', 'LPSTR', 'LPVOID', 'LPWSTR', 'MAX_PATH',
'MSG', 'OLESTR', 'POINT', 'POINTL', 'RECT', 'RECTL', 'RGB',
'SC_HANDLE', 'SERVICE_STATUS_HANDLE', 'SHORT', 'SIZE', 'SIZEL',
'SMALL_RECT', 'UINT', 'ULARGE_INTEGER', 'ULONG', 'USHORT',
'VARIANT_BOOL', 'WCHAR', 'WIN32_FIND_DATAA', 'WIN32_FIND_DATAW',
'WORD', 'WPARAM', '_COORD', '_FILETIME', '_LARGE_INTEGER',
'_POINTL', '_RECTL', '_SMALL_RECT', '_ULARGE_INTEGER', 'tagMSG',
'tagPOINT', 'tagRECT', 'tagSIZE']
import sys
from ctypes import *
_array_type = type(Array)
def _other_endian(typ):
    """Return the type with the 'other' byte order.  Simple types like
    c_int and so on already have __ctype_be__ and __ctype_le__
    attributes which contain the types, for more complicated types
    arrays and structures are supported.
    """
    # check _OTHER_ENDIAN attribute (present if typ is primitive type)
    if hasattr(typ, _OTHER_ENDIAN):
        return getattr(typ, _OTHER_ENDIAN)
    # if typ is array
    if isinstance(typ, _array_type):
        return _other_endian(typ._type_) * typ._length_
    # if typ is structure
    if issubclass(typ, Structure):
        return typ
    raise TypeError("This type does not support other endian: %s" % typ)
class _swapped_meta(type(Structure)):
def __setattr__(self, attrname, value):
if attrname == "_fields_":
fields = []
for desc in value:
name = desc[0]
typ = desc[1]
rest = desc[2:]
fields.append((name, _other_endian(typ)) + rest)
value = fields
super(_swapped_meta, self).__setattr__(attrname, value)
################################################################
# Note: The Structure metaclass checks for the *presence* (not the
# value!) of a _swapped_bytes_ attribute to determine the bit order in
# structures containing bit fields.
if sys.byteorder == "little":
_OTHER_ENDIAN = "__ctype_be__"
LittleEndianStructure = Structure
class BigEndianStructure(Structure):
"""Structure with big endian byte order"""
__metaclass__ = _swapped_meta
_swappedbytes_ = None
elif sys.byteorder == "big":
_OTHER_ENDIAN = "__ctype_le__"
BigEndianStructure = Structure
class LittleEndianStructure(Structure):
"""Structure with little endian byte order"""
__metaclass__ = _swapped_meta
_swappedbytes_ = None
else:
raise RuntimeError("Invalid byteorder")
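# Illustrative usage (not part of the module): what the swapped-field
# metaclass above means in practice. The structure and field names are made
# up for this example. The same field values produce opposite byte layouts
# depending on which base class is used.

```python
import ctypes


class BigHeader(ctypes.BigEndianStructure):
    _fields_ = [("magic", ctypes.c_uint16),
                ("version", ctypes.c_uint16)]


class LittleHeader(ctypes.LittleEndianStructure):
    _fields_ = [("magic", ctypes.c_uint16),
                ("version", ctypes.c_uint16)]


# Copy the in-memory representation of each structure into a bytes object.
big_bytes = bytes(BigHeader(0x1234, 0x0002))
little_bytes = bytes(LittleHeader(0x1234, 0x0002))
```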
"""create and manipulate C data types in Python"""
import os as _os, sys as _sys
__version__ = "1.1.0"
from _ctypes import Union, Structure, Array
from _ctypes import _Pointer
from _ctypes import CFuncPtr as _CFuncPtr
from _ctypes import __version__ as _ctypes_version
from _ctypes import RTLD_LOCAL, RTLD_GLOBAL
from _ctypes import ArgumentError
from struct import calcsize as _calcsize
if __version__ != _ctypes_version:
raise Exception("Version number mismatch", __version__, _ctypes_version)
if _os.name in ("nt", "ce"):
from _ctypes import FormatError
DEFAULT_MODE = RTLD_LOCAL
if _os.name == "posix" and _sys.platform == "darwin":
# On OS X 10.3, we use RTLD_GLOBAL as default mode
# because RTLD_LOCAL does not work at least on some
# libraries. OS X 10.3 is Darwin 7, so we check for
# that.
if int(_os.uname()[2].split('.')[0]) < 8:
DEFAULT_MODE = RTLD_GLOBAL
from _ctypes import FUNCFLAG_CDECL as _FUNCFLAG_CDECL, \
FUNCFLAG_PYTHONAPI as _FUNCFLAG_PYTHONAPI, \
FUNCFLAG_USE_ERRNO as _FUNCFLAG_USE_ERRNO, \
FUNCFLAG_USE_LASTERROR as _FUNCFLAG_USE_LASTERROR
"""
WINOLEAPI -> HRESULT
WINOLEAPI_(type)
STDMETHODCALLTYPE
STDMETHOD(name)
STDMETHOD_(type, name)
STDAPICALLTYPE
"""
def create_string_buffer(init, size=None):
    """create_string_buffer(aString) -> character array
    create_string_buffer(anInteger) -> character array
    create_string_buffer(aString, anInteger) -> character array
    """
    if isinstance(init, (str, unicode)):
        if size is None:
            size = len(init)+1
        buftype = c_char * size
        buf = buftype()
        buf.value = init
        return buf
    elif isinstance(init, (int, long)):
        buftype = c_char * init
        buf = buftype()
        return buf
    raise TypeError(init)
def c_buffer(init, size=None):
## "deprecated, use create_string_buffer instead"
## import warnings
## warnings.warn("c_buffer is deprecated, use create_string_buffer instead",
## DeprecationWarning, stacklevel=2)
return create_string_buffer(init, size)
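# Illustrative usage (not part of the module): the three call forms listed
# in the create_string_buffer docstring. This example passes bytes literals
# so that it also runs on Python 3, where a text string is rejected; the
# Python 2 code above accepts str directly.

```python
import ctypes

# From a string: the buffer is sized to hold the text plus a trailing NUL.
buf = ctypes.create_string_buffer(b"hello")

# From a string and an explicit size: the remainder is zero-filled.
fixed = ctypes.create_string_buffer(b"hi", 8)

# From an integer: a zero-filled buffer of that many bytes.
zeroed = ctypes.create_string_buffer(4)
```

# .value stops at the first NUL byte, while .raw exposes the whole buffer.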
_c_functype_cache = {}
def CFUNCTYPE(restype, *argtypes, **kw):
    """CFUNCTYPE(restype, *argtypes,
                 use_errno=False, use_last_error=False) -> function prototype.
    restype: the result type
    argtypes: a sequence specifying the argument types
    The function prototype can be called in different ways to create a
    callable object:
    prototype(integer address) -> foreign function
    prototype(callable) -> create and return a C callable function from callable
    prototype(integer index, method name[, paramflags]) -> foreign function calling a COM method
    prototype((ordinal number, dll object)[, paramflags]) -> foreign function exported by ordinal
    prototype((function name, dll object)[, paramflags]) -> foreign function exported by name
    """
    flags = _FUNCFLAG_CDECL
    if kw.pop("use_errno", False):
        flags |= _FUNCFLAG_USE_ERRNO
    if kw.pop("use_last_error", False):
        flags |= _FUNCFLAG_USE_LASTERROR
    if kw:
        raise ValueError("unexpected keyword argument(s) %s" % kw.keys())
    try:
        return _c_functype_cache[(restype, argtypes, flags)]
    except KeyError:
        class CFunctionType(_CFuncPtr):
            _argtypes_ = argtypes
            _restype_ = restype
            _flags_ = flags
        _c_functype_cache[(restype, argtypes, flags)] = CFunctionType
        return CFunctionType
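# Illustrative usage (not part of the module): the "prototype(callable)"
# form from the CFUNCTYPE docstring, which wraps a Python function as a C
# callable. The names ADD_PROTO, py_add and c_add are made up for this
# example.

```python
import ctypes

# Prototype for: int (*)(int, int), using the cdecl calling convention.
ADD_PROTO = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_int)


def py_add(a, b):
    return a + b


# Keep a reference to c_add for as long as C code might call it.
c_add = ADD_PROTO(py_add)
result = c_add(2, 3)

# Keywords other than use_errno / use_last_error are rejected.
try:
    ctypes.CFUNCTYPE(ctypes.c_int, bogus=True)
    rejected = False
except ValueError:
    rejected = True
```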
if _os.name in ("nt", "ce"):
from _ctypes import LoadLibrary as _dlopen
from _ctypes import FUNCFLAG_STDCALL as _FUNCFLAG_STDCALL
if _os.name == "ce":
# 'ce' doesn't have the stdcall calling convention
_FUNCFLAG_STDCALL = _FUNCFLAG_CDECL
_win_functype_cache = {}
def WINFUNCTYPE(restype, *argtypes, **kw):
# docstring set later (very similar to CFUNCTYPE.__doc__)
flags = _FUNCFLAG_STDCALL
if kw.pop("use_errno", False):
flags |= _FUNCFLAG_USE_ERRNO
if kw.pop("use_last_error", False):
flags |= _FUNCFLAG_USE_LASTERROR
if kw:
raise ValueError("unexpected keyword argument(s) %s" % kw.keys())
try:
return _win_functype_cache[(restype, argtypes, flags)]
except KeyError:
class WinFunctionType(_CFuncPtr):
_argtypes_ = argtypes
_restype_ = restype
_flags_ = flags
_win_functype_cache[(restype, argtypes, flags)] = WinFunctionType
return WinFunctionType
if WINFUNCTYPE.__doc__:
WINFUNCTYPE.__doc__ = CFUNCTYPE.__doc__.replace("CFUNCTYPE", "WINFUNCTYPE")
elif _os.name == "posix":
from _ctypes import dlopen as _dlopen
from _ctypes import sizeof, byref, addressof, alignment, resize
from _ctypes import get_errno, set_errno
from _ctypes import _SimpleCData
def _check_size(typ, typecode=None):
    # Check sizeof(ctypes_type) against struct.calcsize.  This
    # should protect somewhat against a misconfigured libffi.
    from struct import calcsize
    if typecode is None:
        # Most _type_ codes are the same as used in struct
        typecode = typ._type_
    actual, required = sizeof(typ), calcsize(typecode)
    if actual != required:
        raise SystemError("sizeof(%s) wrong: %d instead of %d" % \
                          (typ, actual, required))
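# Illustrative sketch (not part of the module): the same consistency check
# as _check_size, written against the public ctypes and struct APIs. The
# function name check_size and its return value are assumptions made for
# this example; the original returns nothing.

```python
import ctypes
import struct


def check_size(ctype, typecode=None):
    """Raise SystemError if ctypes and struct disagree on a type's size."""
    if typecode is None:
        # Most ctypes _type_ codes match the struct format characters.
        typecode = ctype._type_
    actual = ctypes.sizeof(ctype)
    required = struct.calcsize(typecode)
    if actual != required:
        raise SystemError("sizeof(%s) wrong: %d instead of %d"
                          % (ctype, actual, required))
    return actual


short_size = check_size(ctypes.c_short)
# Pointer-like types need an explicit struct code, as in the module above.
pointer_size = check_size(ctypes.c_char_p, "P")
```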
class py_object(_SimpleCData):
_type_ = "O"
def __repr__(self):
try:
return super(py_object, self).__repr__()
except ValueError:
return "%s(<NULL>)" % type(self).__name__
_check_size(py_object, "P")
class c_short(_SimpleCData):
_type_ = "h"
_check_size(c_short)
class c_ushort(_SimpleCData):
_type_ = "H"
_check_size(c_ushort)
class c_long(_SimpleCData):
_type_ = "l"
_check_size(c_long)
class c_ulong(_SimpleCData):
_type_ = "L"
_check_size(c_ulong)
if _calcsize("i") == _calcsize("l"):
# if int and long have the same size, make c_int an alias for c_long
c_int = c_long
c_uint = c_ulong
else:
class c_int(_SimpleCData):
_type_ = "i"
_check_size(c_int)
class c_uint(_SimpleCData):
_type_ = "I"
_check_size(c_uint)
class c_float(_SimpleCData):
_type_ = "f"
_check_size(c_float)
class c_double(_SimpleCData):
_type_ = "d"
_check_size(c_double)
class c_longdouble(_SimpleCData):
_type_ = "g"
if sizeof(c_longdouble) == sizeof(c_double):
c_longdouble = c_double
if _calcsize("l") == _calcsize("q"):
# if long and long long have the same size, make c_longlong an alias for c_long
c_longlong = c_long
c_ulonglong = c_ulong
else:
class c_longlong(_SimpleCData):
_type_ = "q"
_check_size(c_longlong)
class c_ulonglong(_SimpleCData):
_type_ = "Q"
## def from_param(cls, val):
## return ('d', float(val), val)
## from_param = classmethod(from_param)
_check_size(c_ulonglong)
class c_ubyte(_SimpleCData):
_type_ = "B"
c_ubyte.__ctype_le__ = c_ubyte.__ctype_be__ = c_ubyte
# backward compatibility:
##c_uchar = c_ubyte
_check_size(c_ubyte)
class c_byte(_SimpleCData):
_type_ = "b"
c_byte.__ctype_le__ = c_byte.__ctype_be__ = c_byte
_check_size(c_byte)
class c_char(_SimpleCData):
_type_ = "c"
c_char.__ctype_le__ = c_char.__ctype_be__ = c_char
_check_size(c_char)
class c_char_p(_SimpleCData):
_type_ = "z"
if _os.name == "nt":
def __repr__(self):
if not windll.kernel32.IsBadStringPtrA(self, -1):
return "%s(%r)" % (self.__class__.__name__, self.value)
return "%s(%s)" % (self.__class__.__name__, cast(self, c_void_p).value)
else:
def __repr__(self):
return "%s(%s)" % (self.__class__.__name__, cast(self, c_void_p).value)
_check_size(c_char_p, "P")
class c_void_p(_SimpleCData):
_type_ = "P"
c_voidp = c_void_p # backwards compatibility (to a bug)
_check_size(c_void_p)
class c_bool(_SimpleCData):
_type_ = "?"
from _ctypes import POINTER, pointer, _pointer_type_cache
def _reset_cache():
_pointer_type_cache.clear()
_c_functype_cache.clear()
if _os.name in ("nt", "ce"):
_win_functype_cache.clear()
# _SimpleCData.c_wchar_p_from_param
POINTER(c_wchar).from_param = c_wchar_p.from_param
# _SimpleCData.c_char_p_from_param
POINTER(c_char).from_param = c_char_p.from_param
_pointer_type_cache[None] = c_void_p
# XXX for whatever reasons, creating the first instance of a callback
# function is needed for the unittests on Win64 to succeed. This MAY
# be a compiler bug, since the problem occurs only when _ctypes is
# compiled with the MS SDK compiler. Or an uninitialized variable?
CFUNCTYPE(c_int)(lambda: None)
try:
    from _ctypes import set_conversion_mode
except ImportError:
    pass
else:
    if _os.name in ("nt", "ce"):
        set_conversion_mode("mbcs", "ignore")
    else:
        set_conversion_mode("ascii", "strict")
class c_wchar_p(_SimpleCData):
_type_ = "Z"
class c_wchar(_SimpleCData):
_type_ = "u"
def create_unicode_buffer(init, size=None):
"""create_unicode_buffer(aString) -> character array
create_unicode_buffer(anInteger) -> character array
create_unicode_buffer(aString, anInteger) -> character array
"""
if isinstance(init, (str, unicode)):
if size is None:
size = len(init)+1
buftype = c_wchar * size
buf = buftype()
buf.value = init
return buf
elif isinstance(init, (int, long)):
buftype = c_wchar * init
buf = buftype()
return buf
raise TypeError(init)
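# Illustrative sketch, not part of the module: the three call forms listed in
# the create_unicode_buffer docstring, run against the standard ctypes package.
# The variable names are made up for this example.

```python
import ctypes

# From a string: the buffer holds the text plus a terminating NUL.
buf = ctypes.create_unicode_buffer(u"abc")
text_len = len(buf)        # number of c_wchar slots, including the NUL
text_value = buf.value     # contents up to the first NUL

# From an integer: a zero-filled buffer with that many slots.
empty = ctypes.create_unicode_buffer(8)

# From a string plus an explicit size: the unused slots stay zero-filled.
padded = ctypes.create_unicode_buffer(u"abc", 10)
```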
# XXX Deprecated
def SetPointerType(pointer, cls):
if _pointer_type_cache.get(cls, None) is not None:
raise RuntimeError("This type already exists in the cache")
if id(pointer) not in _pointer_type_cache:
raise RuntimeError("What's this???")
pointer.set_type(cls)
_pointer_type_cache[cls] = pointer
del _pointer_type_cache[id(pointer)]
# XXX Deprecated
def ARRAY(typ, len):
return typ * len
################################################################
class CDLL(object):
"""An instance of this class represents a loaded dll/shared
library, exporting functions using the standard C calling
convention (named 'cdecl' on Windows).
The exported functions can be accessed as attributes, or by
indexing with the function name. Examples:
<obj>.qsort -> callable object
<obj>['qsort'] -> callable object
Calling the functions releases the Python GIL during the call and
reacquires it afterwards.
"""
_func_flags_ = _FUNCFLAG_CDECL
_func_restype_ = c_int
# default values for repr
_name = '<uninitialized>'
_handle = 0
_FuncPtr = None
def __init__(self, name, mode=DEFAULT_MODE, handle=None,
use_errno=False,
use_last_error=False):
self._name = name
flags = self._func_flags_
if use_errno:
flags |= _FUNCFLAG_USE_ERRNO
if use_last_error:
flags |= _FUNCFLAG_USE_LASTERROR
class _FuncPtr(_CFuncPtr):
_flags_ = flags
_restype_ = self._func_restype_
self._FuncPtr = _FuncPtr
if handle is None:
self._handle = _dlopen(self._name, mode)
else:
self._handle = handle
def __repr__(self):
return "<%s '%s', handle %x at %x>" % \
(self.__class__.__name__, self._name,
(self._handle & (_sys.maxint*2 + 1)),
id(self) & (_sys.maxint*2 + 1))
def __getattr__(self, name):
if name.startswith('__') and name.endswith('__'):
raise AttributeError(name)
func = self.__getitem__(name)
setattr(self, name, func)
return func
def __getitem__(self, name_or_ordinal):
func = self._FuncPtr((name_or_ordinal, self))
if not isinstance(name_or_ordinal, (int, long)):
func.__name__ = name_or_ordinal
return func
class PyDLL(CDLL):
"""This class represents the Python library itself. It allows
accessing Python API functions. The GIL is not released, and
Python exceptions are handled correctly.
"""
_func_flags_ = _FUNCFLAG_CDECL | _FUNCFLAG_PYTHONAPI
if _os.name in ("nt", "ce"):
class WinDLL(CDLL):
"""This class represents a dll exporting functions using the
Windows stdcall calling convention.
"""
_func_flags_ = _FUNCFLAG_STDCALL
# XXX Hm, what about HRESULT as normal parameter?
# Mustn't it derive from c_long then?
from _ctypes import _check_HRESULT, _SimpleCData
class HRESULT(_SimpleCData):
_type_ = "l"
# _check_retval_ is called with the function's result when it
# is used as restype. It checks for the FAILED bit, and
# raises a WindowsError if it is set.
#
# The _check_retval_ method is implemented in C, so that the
# method definition itself is not included in the traceback
# when it raises an error - that is what we want (and Python
# doesn't have a way to raise an exception in the caller's
# frame).
_check_retval_ = _check_HRESULT
class OleDLL(CDLL):
"""This class represents a dll exporting functions using the
Windows stdcall calling convention, and returning HRESULT.
HRESULT error values are automatically raised as WindowsError
exceptions.
"""
_func_flags_ = _FUNCFLAG_STDCALL
_func_restype_ = HRESULT
class LibraryLoader(object):
def __init__(self, dlltype):
self._dlltype = dlltype
def __getattr__(self, name):
if name[0] == '_':
raise AttributeError(name)
dll = self._dlltype(name)
setattr(self, name, dll)
return dll
def __getitem__(self, name):
return getattr(self, name)
def LoadLibrary(self, name):
return self._dlltype(name)
cdll = LibraryLoader(CDLL)
pydll = LibraryLoader(PyDLL)
if _os.name in ("nt", "ce"):
pythonapi = PyDLL("python dll", None, _sys.dllhandle)
elif _sys.platform == "cygwin":
pythonapi = PyDLL("libpython%d.%d.dll" % _sys.version_info[:2])
elif _sys.platform == "cli": # Need to determine how to do this
pythonapi = None
else:
pythonapi = PyDLL(None)
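# Illustrative sketch, not part of the module: attribute access versus index
# access on a loaded library, as described in the CDLL docstring. It assumes
# a CPython interpreter, where pythonapi is a PyDLL instance (under IronPython
# the code above sets pythonapi to None, so this would not run there).
# Py_GetVersion is a CPython C API function returning a C string.

```python
import ctypes
import sys

# Attribute access goes through CDLL.__getattr__, which caches the function
# object on the library instance with setattr().
get_version = ctypes.pythonapi.Py_GetVersion
cached = ctypes.pythonapi.Py_GetVersion is get_version

# Index access goes through CDLL.__getitem__ and builds a new function
# object on every lookup.
fresh = ctypes.pythonapi["Py_GetVersion"] is not get_version

# The default restype is c_int; Py_GetVersion returns a char pointer.
get_version.restype = ctypes.c_char_p
version = get_version().decode("ascii")
```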
if _os.name in ("nt", "ce"):
windll = LibraryLoader(WinDLL)
oledll = LibraryLoader(OleDLL)
if _os.name == "nt":
GetLastError = windll.kernel32.GetLastError
else:
GetLastError = windll.coredll.GetLastError
from _ctypes import get_last_error, set_last_error
def WinError(code=None, descr=None):
if code is None:
code = GetLastError()
if descr is None:
descr = FormatError(code).strip()
return WindowsError(code, descr)
if sizeof(c_uint) == sizeof(c_void_p):
    c_size_t = c_uint
    c_ssize_t = c_int
elif sizeof(c_ulong) == sizeof(c_void_p):
    c_size_t = c_ulong
    c_ssize_t = c_long
elif sizeof(c_ulonglong) == sizeof(c_void_p):
    c_size_t = c_ulonglong
    c_ssize_t = c_longlong
# functions
from _ctypes import _memmove_addr, _memset_addr, _string_at_addr, _cast_addr
## void *memmove(void *, const void *, size_t);
memmove = CFUNCTYPE(c_void_p, c_void_p, c_void_p, c_size_t)(_memmove_addr)
## void *memset(void *, int, size_t)
memset = CFUNCTYPE(c_void_p, c_void_p, c_int, c_size_t)(_memset_addr)
def PYFUNCTYPE(restype, *argtypes):
class CFunctionType(_CFuncPtr):
_argtypes_ = argtypes
_restype_ = restype
_flags_ = _FUNCFLAG_CDECL | _FUNCFLAG_PYTHONAPI
return CFunctionType
_cast = PYFUNCTYPE(py_object, c_void_p, py_object, py_object)(_cast_addr)
def cast(obj, typ):
return _cast(obj, obj, typ)
_string_at = PYFUNCTYPE(py_object, c_void_p, c_int)(_string_at_addr)
def string_at(ptr, size=-1):
"""string_at(addr[, size]) -> string
Return the string at addr."""
return _string_at(ptr, size)
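# Illustrative sketch, not part of the module: memmove, memset and string_at
# applied to small byte buffers. create_string_buffer is the standard ctypes
# helper for byte buffers; the variable names are made up for this example.

```python
import ctypes

src = ctypes.create_string_buffer(b"hello")   # 6 bytes: text plus NUL
dst = ctypes.create_string_buffer(6)          # 6 zero bytes

ctypes.memmove(dst, src, 6)                   # copy all 6 bytes
copied = dst.raw

ctypes.memset(dst, ord("x"), 3)               # overwrite the first 3 bytes
patched = dst.value

prefix = ctypes.string_at(dst, 2)             # read exactly 2 bytes
whole = ctypes.string_at(dst)                 # size=-1: read up to the NUL
```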
try:
from _ctypes import _wstring_at_addr
except ImportError:
pass
else:
_wstring_at = PYFUNCTYPE(py_object, c_void_p, c_int)(_wstring_at_addr)
def wstring_at(ptr, size=-1):
"""wstring_at(addr[, size]) -> string
Return the string at addr."""
return _wstring_at(ptr, size)
if _os.name in ("nt", "ce"): # COM stuff
def DllGetClassObject(rclsid, riid, ppv):
try:
ccom = __import__("comtypes.server.inprocserver", globals(), locals(), ['*'])
except ImportError:
return -2147221231 # CLASS_E_CLASSNOTAVAILABLE
else:
return ccom.DllGetClassObject(rclsid, riid, ppv)
def DllCanUnloadNow():
try:
ccom = __import__("comtypes.server.inprocserver", globals(), locals(), ['*'])
except ImportError:
return 0 # S_OK
return ccom.DllCanUnloadNow()
from ctypes._endian import BigEndianStructure, LittleEndianStructure
# Fill in specifically-sized types
c_int8 = c_byte
c_uint8 = c_ubyte
for kind in [c_short, c_int, c_long, c_longlong]:
    if sizeof(kind) == 2: c_int16 = kind
    elif sizeof(kind) == 4: c_int32 = kind
    elif sizeof(kind) == 8: c_int64 = kind
for kind in [c_ushort, c_uint, c_ulong, c_ulonglong]:
    if sizeof(kind) == 2: c_uint16 = kind
    elif sizeof(kind) == 4: c_uint32 = kind
    elif sizeof(kind) == 8: c_uint64 = kind
del kind
_reset_cache()
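# Illustrative sketch, not part of the module: checking the fixed-width
# aliases filled in above, and that c_size_t / c_ssize_t track the pointer
# width. The dictionary and variable names are made up for this example.

```python
import ctypes

sizes = dict(
    (name, ctypes.sizeof(getattr(ctypes, name)))
    for name in ("c_int8", "c_int16", "c_int32", "c_int64",
                 "c_uint8", "c_uint16", "c_uint32", "c_uint64")
)

# c_size_t and c_ssize_t are chosen to match sizeof(c_void_p).
pointer_width = ctypes.sizeof(ctypes.c_void_p)
size_t_matches = ctypes.sizeof(ctypes.c_size_t) == pointer_width
ssize_t_matches = ctypes.sizeof(ctypes.c_ssize_t) == pointer_width
```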
# Copyright (c) 2004 Python Software Foundation.
# All rights reserved.
# Written by Eric Price <eprice at tjhsst.edu>
# and Facundo Batista <facundo at taniquetil.com.ar>
# and Raymond Hettinger <python at rcn.com>
# and Aahz <aahz at pobox.com>
# and Tim Peters
# This module is currently Py2.3 compatible and should be kept that way
# unless a major compelling advantage arises. IOW, 2.3 compatibility is
# strongly preferred, but not guaranteed.
# Also, this module should be kept in sync with the latest updates of
# the IBM specification as it evolves. Those updates will be treated
# as bug fixes (deviation from the spec is a compatibility, usability
# bug) and will be backported. At this point the spec is stabilizing
# and the updates are becoming fewer, smaller, and less significant.
"""
This is a Py2.3 implementation of decimal floating point arithmetic based on
the General Decimal Arithmetic Specification:
http://speleotrove.com/decimal/decarith.html
and IEEE standard 854-1987:
http://en.wikipedia.org/wiki/IEEE_854-1987
Decimal floating point has finite precision with arbitrarily large bounds.
The purpose of this module is to support arithmetic using familiar
"schoolhouse" rules and to avoid some of the tricky representation
issues associated with binary floating point. The package is especially
useful for financial applications or for contexts where users have
expectations that are at odds with binary floating point (for instance,
in binary floating point, 1.00 % 0.1 gives 0.09999999999999995 instead
of the expected Decimal('0.00') returned by decimal floating point).
Here are some examples of using the decimal module:
>>> from decimal import *
>>> setcontext(ExtendedContext)
>>> Decimal(0)
Decimal('0')
>>> Decimal('1')
Decimal('1')
>>> Decimal('-.0123')
Decimal('-0.0123')
>>> Decimal(123456)
Decimal('123456')
>>> Decimal('123.45e12345678901234567890')
Decimal('1.2345E+12345678901234567892')
>>> Decimal('1.33') + Decimal('1.27')
Decimal('2.60')
>>> Decimal('12.34') + Decimal('3.87') - Decimal('18.41')
Decimal('-2.20')
>>> dig = Decimal(1)
>>> print dig / Decimal(3)
0.333333333
>>> getcontext().prec = 18
>>> print dig / Decimal(3)
0.333333333333333333
>>> print dig.sqrt()
1
>>> print Decimal(3).sqrt()
1.73205080756887729
>>> print Decimal(3) ** 123
4.85192780976896427E+58
>>> inf = Decimal(1) / Decimal(0)
>>> print inf
Infinity
>>> neginf = Decimal(-1) / Decimal(0)
>>> print neginf
-Infinity
>>> print neginf + inf
NaN
>>> print neginf * inf
-Infinity
>>> print dig / 0
Infinity
>>> getcontext().traps[DivisionByZero] = 1
>>> print dig / 0
Traceback (most recent call last):
...
...
...
DivisionByZero: x / 0
>>> c = Context()
>>> c.traps[InvalidOperation] = 0
>>> print c.flags[InvalidOperation]
0
>>> c.divide(Decimal(0), Decimal(0))
Decimal('NaN')
>>> c.traps[InvalidOperation] = 1
>>> print c.flags[InvalidOperation]
1
>>> c.flags[InvalidOperation] = 0
>>> print c.flags[InvalidOperation]
0
>>> print c.divide(Decimal(0), Decimal(0))
Traceback (most recent call last):
...
...
...
InvalidOperation: 0 / 0
>>> print c.flags[InvalidOperation]
1
>>> c.flags[InvalidOperation] = 0
>>> c.traps[InvalidOperation] = 0
>>> print c.divide(Decimal(0), Decimal(0))
NaN
>>> print c.flags[InvalidOperation]
1
>>>
"""
__all__ = [
# Two major classes
'Decimal', 'Context',
# Contexts
'DefaultContext', 'BasicContext', 'ExtendedContext',
# Exceptions
'DecimalException', 'Clamped', 'InvalidOperation', 'DivisionByZero',
'Inexact', 'Rounded', 'Subnormal', 'Overflow', 'Underflow',
# Constants for use in setting up contexts
'ROUND_DOWN', 'ROUND_HALF_UP', 'ROUND_HALF_EVEN', 'ROUND_CEILING',
'ROUND_FLOOR', 'ROUND_UP', 'ROUND_HALF_DOWN', 'ROUND_05UP',
# Functions for manipulating contexts
'setcontext', 'getcontext', 'localcontext'
]
__version__ = '1.70' # Highest version of the spec this complies with
import math as _math
import numbers as _numbers
try:
    from collections import namedtuple as _namedtuple
    DecimalTuple = _namedtuple('DecimalTuple', 'sign digits exponent')
except ImportError:
    DecimalTuple = lambda *args: args
# Rounding
ROUND_DOWN = 'ROUND_DOWN'
ROUND_HALF_UP = 'ROUND_HALF_UP'
ROUND_HALF_EVEN = 'ROUND_HALF_EVEN'
ROUND_CEILING = 'ROUND_CEILING'
ROUND_FLOOR = 'ROUND_FLOOR'
ROUND_UP = 'ROUND_UP'
ROUND_HALF_DOWN = 'ROUND_HALF_DOWN'
ROUND_05UP = 'ROUND_05UP'
# Errors
class DecimalException(ArithmeticError):
"""Base exception class.
Used exceptions derive from this.
If an exception derives from another exception besides this (such as
Underflow (Inexact, Rounded, Subnormal)), that indicates that it is only
called if the others are present. This isn't actually used for
anything, though.
handle -- Called when context._raise_error is called and the
trap_enabler is not set. First argument is self, second is the
context. More arguments can be given, those being after
the explanation in _raise_error (For example,
context._raise_error(NewError, '(-x)!', self._sign) would
call NewError().handle(context, self._sign).)
To define a new exception, it should be sufficient to have it derive
from DecimalException.
"""
def handle(self, context, *args):
pass
class Clamped(DecimalException):
"""Exponent of a 0 changed to fit bounds.
This occurs and signals clamped if the exponent of a result has been
altered in order to fit the constraints of a specific concrete
representation. This may occur when the exponent of a zero result would
be outside the bounds of a representation, or when a large normal
number would have an encoded exponent that cannot be represented. In
this latter case, the exponent is reduced to fit and the corresponding
number of zero digits are appended to the coefficient ("fold-down").
"""
class InvalidOperation(DecimalException):
"""An invalid operation was performed.
Various bad things cause this:
Something creates a signaling NaN
-INF + INF
0 * (+-)INF
(+-)INF / (+-)INF
x % 0
(+-)INF % x
x._rescale( non-integer )
sqrt(-x) , x > 0
0 ** 0
x ** (non-integer)
x ** (+-)INF
An operand is invalid
The result of the operation after these is a quiet positive NaN,
except when the cause is a signaling NaN, in which case the result is
also a quiet NaN, but with the original sign and optional
diagnostic information.
"""
def handle(self, context, *args):
if args:
ans = _dec_from_triple(args[0]._sign, args[0]._int, 'n', True)
return ans._fix_nan(context)
return _NaN
class ConversionSyntax(InvalidOperation):
"""Trying to convert badly formed string.
This occurs and signals invalid-operation if a string is being
converted to a number and it does not conform to the numeric string
syntax. The result is [0,qNaN].
"""
def handle(self, context, *args):
return _NaN
class DivisionByZero(DecimalException, ZeroDivisionError):
"""Division by 0.
This occurs and signals division-by-zero if division of a finite number
by zero was attempted (during a divide-integer or divide operation, or a
power operation with negative right-hand operand), and the dividend was
not zero.
The result of the operation is [sign,inf], where sign is the exclusive
or of the signs of the operands for divide, or is 1 for an odd power of
-0, for power.
"""
def handle(self, context, sign, *args):
return _SignedInfinity[sign]
class DivisionImpossible(InvalidOperation):
"""Cannot perform the division adequately.
This occurs and signals invalid-operation if the integer result of a
divide-integer or remainder operation had too many digits (would be
longer than precision). The result is [0,qNaN].
"""
def handle(self, context, *args):
return _NaN
class DivisionUndefined(InvalidOperation, ZeroDivisionError):
"""Undefined result of division.
This occurs and signals invalid-operation if division by zero was
attempted (during a divide-integer, divide, or remainder operation), and
the dividend is also zero. The result is [0,qNaN].
"""
def handle(self, context, *args):
return _NaN
class Inexact(DecimalException):
"""Had to round, losing information.
This occurs and signals inexact whenever the result of an operation is
not exact (that is, it needed to be rounded and any discarded digits
were non-zero), or if an overflow or underflow condition occurs. The
result in all cases is unchanged.
The inexact signal may be tested (or trapped) to determine if a given
operation (or sequence of operations) was inexact.
"""
class InvalidContext(InvalidOperation):
"""Invalid context. Unknown rounding, for example.
This occurs and signals invalid-operation if an invalid context was
detected during an operation. This can occur if contexts are not checked
on creation and either the precision exceeds the capability of the
underlying concrete representation or an unknown or unsupported rounding
was specified. These aspects of the context need only be checked when
the values are required to be used. The result is [0,qNaN].
"""
def handle(self, context, *args):
return _NaN
class Rounded(DecimalException):
"""Number got rounded (not necessarily changed during rounding).
This occurs and signals rounded whenever the result of an operation is
rounded (that is, some zero or non-zero digits were discarded from the
coefficient), or if an overflow or underflow condition occurs. The
result in all cases is unchanged.
The rounded signal may be tested (or trapped) to determine if a given
operation (or sequence of operations) caused a loss of precision.
"""
class Subnormal(DecimalException):
"""Exponent < Emin before rounding.
This occurs and signals subnormal whenever the result of a conversion or
operation is subnormal (that is, its adjusted exponent is less than
Emin, before any rounding). The result in all cases is unchanged.
The subnormal signal may be tested (or trapped) to determine if a given
conversion or operation (or sequence of operations) yielded a subnormal
result.
"""
class Overflow(Inexact, Rounded):
"""Numerical overflow.
This occurs and signals overflow if the adjusted exponent of a result
(from a conversion or from an operation that is not an attempt to divide
by zero), after rounding, would be greater than the largest value that
can be handled by the implementation (the value Emax).
The result depends on the rounding mode:
For round-half-up and round-half-even (and for round-half-down and
round-up, if implemented), the result of the operation is [sign,inf],
where sign is the sign of the intermediate result. For round-down, the
result is the largest finite number that can be represented in the
current precision, with the sign of the intermediate result. For
round-ceiling, the result is the same as for round-down if the sign of
the intermediate result is 1, or is [0,inf] otherwise. For round-floor,
the result is the same as for round-down if the sign of the intermediate
result is 0, or is [1,inf] otherwise. In all cases, Inexact and Rounded
will also be raised.
"""
def handle(self, context, sign, *args):
if context.rounding in (ROUND_HALF_UP, ROUND_HALF_EVEN,
ROUND_HALF_DOWN, ROUND_UP):
return _SignedInfinity[sign]
if sign == 0:
if context.rounding == ROUND_CEILING:
return _SignedInfinity[sign]
return _dec_from_triple(sign, '9'*context.prec,
context.Emax-context.prec+1)
if sign == 1:
if context.rounding == ROUND_FLOOR:
return _SignedInfinity[sign]
return _dec_from_triple(sign, '9'*context.prec,
context.Emax-context.prec+1)
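# Illustrative sketch, not part of the module: how the untrapped Overflow
# result depends on the rounding mode, using the public decimal API. The tiny
# context (prec=3, Emax=5) and the helper name overflowed_product are
# assumptions chosen to keep the numbers readable.

```python
from decimal import (Context, Decimal, Overflow,
                     ROUND_CEILING, ROUND_DOWN, ROUND_FLOOR, ROUND_HALF_EVEN)

def overflowed_product(rounding):
    """Multiply 9E5 by 10 in a tiny context (prec=3, Emax=5), untrapped."""
    ctx = Context(prec=3, Emax=5, rounding=rounding, traps=[])
    result = ctx.multiply(Decimal("9E5"), Decimal(10))   # exact value is 9E6
    return result, bool(ctx.flags[Overflow])

half_even, flagged = overflowed_product(ROUND_HALF_EVEN)  # infinity
down, _ = overflowed_product(ROUND_DOWN)                  # largest finite: 999E3
ceiling, _ = overflowed_product(ROUND_CEILING)            # positive sign: infinity
floor, _ = overflowed_product(ROUND_FLOOR)                # positive sign: 999E3
```

# The largest finite value is '9' * prec with exponent Emax - prec + 1, which
# is 999E3 here, matching the _dec_from_triple calls in Overflow.handle.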
class Underflow(Inexact, Rounded, Subnormal):
"""Numerical underflow with result rounded to 0.
This occurs and signals underflow if a result is inexact and the
adjusted exponent of the result would be smaller (more negative) than
the smallest value that can be handled by the implementation (the value
Emin). That is, the result is both inexact and subnormal.
The result after an underflow will be a subnormal number rounded, if
necessary, so that its exponent is not less than Etiny. This may result
in 0 with the sign of the intermediate result and an exponent of Etiny.
In all cases, Inexact, Rounded, and Subnormal will also be raised.
"""
# List of public traps and flags
_signals = [Clamped, DivisionByZero, Inexact, Overflow, Rounded,
Underflow, InvalidOperation, Subnormal]
# Map conditions (per the spec) to signals
_condition_map = {ConversionSyntax:InvalidOperation,
DivisionImpossible:InvalidOperation,
DivisionUndefined:InvalidOperation,
InvalidContext:InvalidOperation}
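# Illustrative sketch, not part of the module: the effect of the condition map
# as seen through the public decimal API. A malformed string is a
# ConversionSyntax condition, but it is reported through the InvalidOperation
# signal, so that is the trap and flag to use. Variable names are made up.

```python
from decimal import Decimal, InvalidOperation, localcontext

# Trap enabled: the conversion raises InvalidOperation.
with localcontext() as ctx:
    ctx.traps[InvalidOperation] = True
    try:
        Decimal("not a number")
        trapped = False
    except InvalidOperation:
        trapped = True

# Trap disabled: the conversion returns a quiet NaN and sets the flag.
with localcontext() as ctx:
    ctx.traps[InvalidOperation] = False
    ctx.clear_flags()
    result = Decimal("not a number")
    flag_set = bool(ctx.flags[InvalidOperation])
```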
##### Context Functions ##################################################
# The getcontext() and setcontext() functions manage access to a thread-local
# current context. Py2.4 offers direct support for thread locals. If that
# is not available, use threading.currentThread() which is slower but will
# work for older Pythons. If threads are not part of the build, create a
# mock threading object with threading.local() returning the module namespace.
try:
    import threading
except ImportError:
    # Python was compiled without threads; create a mock object instead
    import sys
    class MockThreading(object):
        def local(self, sys=sys):
            return sys.modules[__name__]
    threading = MockThreading()
    del sys, MockThreading
try:
threading.local
except AttributeError:
# To fix reloading, force it to create a new context
# Old contexts have different exceptions in their dicts, making problems.
if hasattr(threading.currentThread(), '__decimal_context__'):
del threading.currentThread().__decimal_context__
def setcontext(context):
"""Set this thread's context to context."""
if context in (DefaultContext, BasicContext, ExtendedContext):
context = context.copy()
context.clear_flags()
threading.currentThread().__decimal_context__ = context
def getcontext():
"""Returns this thread's context.
If this thread does not yet have a context, returns
a new context and sets this thread's context.
New contexts are copies of DefaultContext.
"""
try:
return threading.currentThread().__decimal_context__
except AttributeError:
context = Context()
threading.currentThread().__decimal_context__ = context
return context
else:
local = threading.local()
if hasattr(local, '__decimal_context__'):
del local.__decimal_context__
def getcontext(_local=local):
"""Returns this thread's context.
If this thread does not yet have a context, returns
a new context and sets this thread's context.
New contexts are copies of DefaultContext.
"""
try:
return _local.__decimal_context__
except AttributeError:
context = Context()
_local.__decimal_context__ = context
return context
def setcontext(context, _local=local):
"""Set this thread's context to context."""
if context in (DefaultContext, BasicContext, ExtendedContext):
context = context.copy()
context.clear_flags()
_local.__decimal_context__ = context
del threading, local # Don't contaminate the namespace
def localcontext(ctx=None):
"""Return a context manager for a copy of the supplied context.
Uses a copy of the current context if no context is specified.
The returned context manager creates a local decimal context
in a with statement:
def sin(x):
with localcontext() as ctx:
ctx.prec += 2
# Rest of sin calculation algorithm
# uses a precision 2 greater than normal
return +s # Convert result to normal precision
def sin(x):
with localcontext(ExtendedContext):
# Rest of sin calculation algorithm
# uses the Extended Context from the
# General Decimal Arithmetic Specification
return +s # Convert result to normal context
>>> setcontext(DefaultContext)
>>> print getcontext().prec
28
>>> with localcontext():
... ctx = getcontext()
... ctx.prec += 2
... print ctx.prec
...
30
>>> with localcontext(ExtendedContext):
... print getcontext().prec
...
9
>>> print getcontext().prec
28
"""
if ctx is None: ctx = getcontext()
return _ContextManager(ctx)
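# Illustrative sketch, not part of the module: a runnable variant of the
# "sin" outline in the localcontext docstring. The function one_seventh is
# hypothetical; it computes with two guard digits and then rounds back to the
# caller's precision with unary plus.

```python
from decimal import Decimal, getcontext, localcontext

def one_seventh():
    """Compute 1/7 with two guard digits, then round to the caller's precision."""
    with localcontext() as ctx:
        ctx.prec += 2                  # affects only the with-block
        s = Decimal(1) / Decimal(7)
    return +s                          # unary plus rounds to the current context

outer_prec = getcontext().prec
value = one_seventh()
digits = len(value.as_tuple().digits)  # 1/7 never terminates, so digits == prec
prec_after = getcontext().prec         # unchanged by the with-block
```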
##### Decimal class #######################################################
class Decimal(object):
"""Floating point class for decimal arithmetic."""
__slots__ = ('_exp','_int','_sign', '_is_special')
# Generally, the value of the Decimal instance is given by
# (-1)**_sign * _int * 10**_exp
# Special values are signified by _is_special == True
# We're immutable, so use __new__ not __init__
def __new__(cls, value="0", context=None):
"""Create a decimal point instance.
>>> Decimal('3.14') # string input
Decimal('3.14')
>>> Decimal((0, (3, 1, 4), -2)) # tuple (sign, digit_tuple, exponent)
Decimal('3.14')
>>> Decimal(314) # int or long
Decimal('314')
>>> Decimal(Decimal(314)) # another decimal instance
Decimal('314')
>>> Decimal(' 3.14 \\n') # leading and trailing whitespace okay
Decimal('3.14')
"""
# Note that the coefficient, self._int, is actually stored as
# a string rather than as a tuple of digits. This speeds up
# the "digits to integer" and "integer to digits" conversions
# that are used in almost every arithmetic operation on
# Decimals. This is an internal detail: the as_tuple function
# and the Decimal constructor still deal with tuples of
# digits.
self = object.__new__(cls)
import sys
if sys.platform == 'cli':
import System
if isinstance(value, System.Decimal):
value = str(value)
# From a string
# REs insist on real strings, so we can too.
if isinstance(value, basestring):
m = _parser(value.strip())
if m is None:
if context is None:
context = getcontext()
return context._raise_error(ConversionSyntax,
"Invalid literal for Decimal: %r" % value)
if m.group('sign') == "-":
self._sign = 1
else:
self._sign = 0
intpart = m.group('int')
if intpart is not None:
# finite number
fracpart = m.group('frac') or ''
exp = int(m.group('exp') or '0')
self._int = str(int(intpart+fracpart))
self._exp = exp - len(fracpart)
self._is_special = False
else:
diag = m.group('diag')
if diag is not None:
# NaN
self._int = str(int(diag or '0')).lstrip('0')
if m.group('signal'):
self._exp = 'N'
else:
self._exp = 'n'
else:
# infinity
self._int = '0'
self._exp = 'F'
self._is_special = True
return self
# From an integer
if isinstance(value, (int,long)):
if value >= 0:
self._sign = 0
else:
self._sign = 1
self._exp = 0
self._int = str(abs(value))
self._is_special = False
return self
# From another decimal
if isinstance(value, Decimal):
self._exp = value._exp
self._sign = value._sign
self._int = value._int
self._is_special = value._is_special
return self
# From an internal working value
if isinstance(value, _WorkRep):
self._sign = value.sign
self._int = str(value.int)
self._exp = int(value.exp)
self._is_special = False
return self
# tuple/list conversion (possibly from as_tuple())
if isinstance(value, (list,tuple)):
if len(value) != 3:
raise ValueError('Invalid tuple size in creation of Decimal '
'from list or tuple. The list or tuple '
'should have exactly three elements.')
# process sign. The isinstance test rejects floats
if not (isinstance(value[0], (int, long)) and value[0] in (0,1)):
raise ValueError("Invalid sign. The first value in the tuple "
"should be an integer; either 0 for a "
"positive number or 1 for a negative number.")
self._sign = value[0]
if value[2] == 'F':
# infinity: value[1] is ignored
self._int = '0'
self._exp = value[2]
self._is_special = True
else:
# process and validate the digits in value[1]
digits = []
for digit in value[1]:
if isinstance(digit, (int, long)) and 0 <= digit <= 9:
# skip leading zeros
if digits or digit != 0:
digits.append(digit)
else:
raise ValueError("The second value in the tuple must "
"be composed of integers in the range "
"0 through 9.")
if value[2] in ('n', 'N'):
# NaN: digits form the diagnostic
self._int = ''.join(map(str, digits))
self._exp = value[2]
self._is_special = True
elif isinstance(value[2], (int, long)):
# finite number: digits give the coefficient
self._int = ''.join(map(str, digits or [0]))
self._exp = value[2]
self._is_special = False
else:
raise ValueError("The third value in the tuple must "
"be an integer, or one of the "
"strings 'F', 'n', 'N'.")
return self
if isinstance(value, float):
value = Decimal.from_float(value)
self._exp = value._exp
self._sign = value._sign
self._int = value._int
self._is_special = value._is_special
return self
raise TypeError("Cannot convert %r to Decimal" % value)
# @classmethod, but @decorator is not valid Python 2.3 syntax, so
# don't use it (see notes on Py2.3 compatibility at top of file)
def from_float(cls, f):
"""Converts a float to a decimal number, exactly.
Note that Decimal.from_float(0.1) is not the same as Decimal('0.1').
Since 0.1 is not exactly representable in binary floating point, the
value is stored as the nearest representable value which is
0x1.999999999999ap-4. The exact equivalent of the value in decimal
is 0.1000000000000000055511151231257827021181583404541015625.
>>> Decimal.from_float(0.1)
Decimal('0.1000000000000000055511151231257827021181583404541015625')
>>> Decimal.from_float(float('nan'))
Decimal('NaN')
>>> Decimal.from_float(float('inf'))
Decimal('Infinity')
>>> Decimal.from_float(-float('inf'))
Decimal('-Infinity')
>>> Decimal.from_float(-0.0)
Decimal('-0')
"""
if isinstance(f, (int, long)): # handle integer inputs
return cls(f)
if _math.isinf(f) or _math.isnan(f): # raises TypeError if not a float
return cls(repr(f))
if _math.copysign(1.0, f) == 1.0:
sign = 0
else:
sign = 1
n, d = abs(f).as_integer_ratio()
k = d.bit_length() - 1
result = _dec_from_triple(sign, str(n*5**k), -k)
if cls is Decimal:
return result
else:
return cls(result)
from_float = classmethod(from_float)
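# Illustrative sketch, not part of the module: the arithmetic behind
# from_float, worked for 0.375. The helper name
# exact_coefficient_and_exponent is made up for this example.

```python
from decimal import Decimal

def exact_coefficient_and_exponent(f):
    """Return (coefficient, exponent) with abs(f) == coefficient * 10**exponent."""
    n, d = abs(f).as_integer_ratio()   # d is a power of two for any finite float
    k = d.bit_length() - 1             # d == 2**k
    # n / 2**k == n * 5**k / 10**k, so the decimal expansion terminates.
    return n * 5 ** k, -k

coefficient, exponent = exact_coefficient_and_exponent(0.375)   # 0.375 == 3/8
rebuilt = Decimal(coefficient).scaleb(exponent)
```

# For 0.375 the ratio is 3/8, so k is 3 and the coefficient is 3 * 5**3 = 375
# with exponent -3.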
def _isnan(self):
"""Returns whether the number is not actually one.
0 if a number
1 if NaN
2 if sNaN
"""
if self._is_special:
exp = self._exp
if exp == 'n':
return 1
elif exp == 'N':
return 2
return 0
def _isinfinity(self):
"""Returns whether the number is infinite
0 if finite or not a number
1 if +INF
-1 if -INF
"""
if self._exp == 'F':
if self._sign:
return -1
return 1
return 0
def _check_nans(self, other=None, context=None):
"""Check self and other for NaNs before an operation.
If self or other is an sNaN, signal InvalidOperation.
If self or other is a quiet NaN, return that NaN.
Otherwise return 0.
Done before operations.
"""
self_is_nan = self._isnan()
if other is None:
other_is_nan = False
else:
other_is_nan = other._isnan()
if self_is_nan or other_is_nan:
if context is None:
context = getcontext()
if self_is_nan == 2:
return context._raise_error(InvalidOperation, 'sNaN',
self)
if other_is_nan == 2:
return context._raise_error(InvalidOperation, 'sNaN',
other)
if self_is_nan:
return self._fix_nan(context)
return other._fix_nan(context)
return 0
def _compare_check_nans(self, other, context):
"""Version of _check_nans used for the signaling comparisons
compare_signal, __le__, __lt__, __ge__, __gt__.
Signal InvalidOperation if either self or other is a (quiet
or signaling) NaN. Signaling NaNs take precedence over quiet
NaNs.
Return 0 if neither operand is a NaN.
"""
if context is None:
context = getcontext()
if self._is_special or other._is_special:
if self.is_snan():
return context._raise_error(InvalidOperation,
'comparison involving sNaN',
self)
elif other.is_snan():
return context._raise_error(InvalidOperation,
'comparison involving sNaN',
other)
elif self.is_qnan():
return context._raise_error(InvalidOperation,
'comparison involving NaN',
self)
elif other.is_qnan():
return context._raise_error(InvalidOperation,
'comparison involving NaN',
other)
return 0
def __nonzero__(self):
"""Return True if self is nonzero; otherwise return False.
NaNs and infinities are considered nonzero.
"""
return self._is_special or self._int != '0'
def _cmp(self, other):
"""Compare the two non-NaN decimal instances self and other.
Returns -1 if self < other, 0 if self == other and 1
if self > other. This routine is for internal use only."""
if self._is_special or other._is_special:
self_inf = self._isinfinity()
other_inf = other._isinfinity()
if self_inf == other_inf:
return 0
elif self_inf < other_inf:
return -1
else:
return 1
# check for zeros; Decimal('0') == Decimal('-0')
if not self:
if not other:
return 0
else:
return -((-1)**other._sign)
if not other:
return (-1)**self._sign
# If the signs differ, the negative operand is the smaller one
if other._sign < self._sign:
return -1
if self._sign < other._sign:
return 1
self_adjusted = self.adjusted()
other_adjusted = other.adjusted()
if self_adjusted == other_adjusted:
self_padded = self._int + '0'*(self._exp - other._exp)
other_padded = other._int + '0'*(other._exp - self._exp)
if self_padded == other_padded:
return 0
elif self_padded < other_padded:
return -(-1)**self._sign
else:
return (-1)**self._sign
elif self_adjusted > other_adjusted:
return (-1)**self._sign
else: # self_adjusted < other_adjusted
return -((-1)**self._sign)
    # Note: The Decimal standard doesn't cover rich comparisons for
    # Decimals. In particular, the specification is silent on the
    # subject of what should happen for a comparison involving a NaN.
    # We take the following approach:
    #
    #   == comparisons involving a quiet NaN always return False
    #   != comparisons involving a quiet NaN always return True
    #   == or != comparisons involving a signaling NaN signal
    #      InvalidOperation, and return False or True as above if the
    #      InvalidOperation is not trapped.
    #   <, >, <= and >= comparisons involving a (quiet or signaling)
    #      NaN signal InvalidOperation, and return False if the
    #      InvalidOperation is not trapped.
    #
    # This behavior is designed to conform as closely as possible to
    # that specified by IEEE 754.
def __eq__(self, other, context=None):
other = _convert_other(other, allow_float=True)
if other is NotImplemented:
return other
if self._check_nans(other, context):
return False
return self._cmp(other) == 0
def __ne__(self, other, context=None):
other = _convert_other(other, allow_float=True)
if other is NotImplemented:
return other
if self._check_nans(other, context):
return True
return self._cmp(other) != 0
def __lt__(self, other, context=None):
other = _convert_other(other, allow_float=True)
if other is NotImplemented:
return other
ans = self._compare_check_nans(other, context)
if ans:
return False
return self._cmp(other) < 0
def __le__(self, other, context=None):
other = _convert_other(other, allow_float=True)
if other is NotImplemented:
return other
ans = self._compare_check_nans(other, context)
if ans:
return False
return self._cmp(other) <= 0
def __gt__(self, other, context=None):
other = _convert_other(other, allow_float=True)
if other is NotImplemented:
return other
ans = self._compare_check_nans(other, context)
if ans:
return False
return self._cmp(other) > 0
def __ge__(self, other, context=None):
other = _convert_other(other, allow_float=True)
if other is NotImplemented:
return other
ans = self._compare_check_nans(other, context)
if ans:
return False
return self._cmp(other) >= 0
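# The NaN rules described in the note above can be observed from user code.
# This is only a usage sketch: it assumes the standard-library `decimal`
# module of a current Python 3 interpreter rather than this Python 2.7
# source, and the variable names are illustrative.

```python
from decimal import Decimal, InvalidOperation, localcontext

nan = Decimal('NaN')
one = Decimal('1')

# == and != with a quiet NaN never signal.
eq_result = (nan == one)
ne_result = (nan != one)

# Ordering comparisons with a NaN signal InvalidOperation, which the
# default context traps (so an exception is raised).
try:
    nan < one
    ordering_raised = False
except InvalidOperation:
    ordering_raised = True

# With the trap disabled, the comparison returns False and only sets the flag.
with localcontext() as ctx:
    ctx.traps[InvalidOperation] = False
    untrapped_lt = (nan < one)
    flag_set = bool(ctx.flags[InvalidOperation])
```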
def compare(self, other, context=None):
"""Compares one to another.
-1 => a < b
0 => a = b
1 => a > b
NaN => one is NaN
Like __cmp__, but returns Decimal instances.
"""
other = _convert_other(other, raiseit=True)
# Compare(NaN, NaN) = NaN
if (self._is_special or other and other._is_special):
ans = self._check_nans(other, context)
if ans:
return ans
return Decimal(self._cmp(other))
    def __hash__(self):
        """x.__hash__() <==> hash(x)"""
        # Decimal integers must hash the same as the ints
        #
        # The hash of a nonspecial noninteger Decimal must depend only
        # on the value of that Decimal, and not on its representation.
        # For example: hash(Decimal('100E-1')) == hash(Decimal('10')).
        # Equality comparisons involving signaling nans can raise an
        # exception; since equality checks are implicitly and
        # unpredictably used when checking set and dict membership, we
        # prevent signaling nans from being used as set elements or
        # dict keys by making __hash__ raise an exception.
        if self._is_special:
            if self.is_snan():
                raise TypeError('Cannot hash a signaling NaN value.')
            elif self.is_nan():
                # 0 to match hash(float('nan'))
                return 0
            else:
                # values chosen to match hash(float('inf')) and
                # hash(float('-inf')).
                if self._sign:
                    return -271828
                else:
                    return 314159
        # In Python 2.7, we're allowing comparisons (but not
        # arithmetic operations) between floats and Decimals; so if
        # a Decimal instance is exactly representable as a float then
        # its hash should match that of the float.
        self_as_float = float(self)
        if Decimal.from_float(self_as_float) == self:
            return hash(self_as_float)
        if self._isinteger():
            op = _WorkRep(self.to_integral_value())
            # to make computation feasible for Decimals with large
            # exponent, we use the fact that hash(n) == hash(m) for
            # any two nonzero integers n and m such that (i) n and m
            # have the same sign, and (ii) n is congruent to m modulo
            # 2**64-1. So we can replace hash((-1)**s*c*10**e) with
            # hash((-1)**s*c*pow(10, e, 2**64-1)).
            return hash((-1)**op.sign*op.int*pow(10, op.exp, 2**64-1))
        # The value of a nonzero nonspecial Decimal instance is
        # faithfully represented by the triple consisting of its sign,
        # its adjusted exponent, and its coefficient with trailing
        # zeros removed.
        return hash((self._sign,
                     self._exp+len(self._int),
                     self._int.rstrip('0')))
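# The invariants that __hash__ maintains (equal values hash equally,
# regardless of representation or numeric type) can be checked as below.
# A sketch only: it assumes the Python 3 standard-library `decimal` module,
# whose hashing algorithm differs from the code above but, as far as the
# documented behaviour goes, keeps the same invariants.

```python
from decimal import Decimal

# Same value, different representations and types: the hashes agree.
same_value = hash(Decimal('100E-1')) == hash(Decimal('10')) == hash(10)
matches_float = hash(Decimal('1.5')) == hash(1.5)

# Consequently a dict keyed by one representation finds the other.
lookup = {Decimal('10'): 'ten'}[Decimal('10.0')]

# Signaling NaNs are deliberately unhashable.
try:
    hash(Decimal('sNaN'))
    snan_unhashable = False
except TypeError:
    snan_unhashable = True
```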
def as_tuple(self):
"""Represents the number as a (sign, digits, exponent) named tuple,
showing the internals exactly as they are.
"""
return DecimalTuple(self._sign, tuple(map(int, self._int)), self._exp)
def __repr__(self):
"""Represents the number as an instance of Decimal."""
# Invariant: eval(repr(d)) == d
return "Decimal('%s')" % str(self)
    def __str__(self, eng=False, context=None):
        """Return string representation of the number in scientific notation.
        Captures all of the information in the underlying representation.
        """
        sign = ['', '-'][self._sign]
        if self._is_special:
            if self._exp == 'F':
                return sign + 'Infinity'
            elif self._exp == 'n':
                return sign + 'NaN' + self._int
            else: # self._exp == 'N'
                return sign + 'sNaN' + self._int
        # number of digits of self._int to left of decimal point
        leftdigits = self._exp + len(self._int)
        # dotplace is number of digits of self._int to the left of the
        # decimal point in the mantissa of the output string (that is,
        # after adjusting the exponent)
        if self._exp <= 0 and leftdigits > -6:
            # no exponent required
            dotplace = leftdigits
        elif not eng:
            # usual scientific notation: 1 digit on left of the point
            dotplace = 1
        elif self._int == '0':
            # engineering notation, zero
            dotplace = (leftdigits + 1) % 3 - 1
        else:
            # engineering notation, nonzero
            dotplace = (leftdigits - 1) % 3 + 1
        if dotplace <= 0:
            intpart = '0'
            fracpart = '.' + '0'*(-dotplace) + self._int
        elif dotplace >= len(self._int):
            intpart = self._int+'0'*(dotplace-len(self._int))
            fracpart = ''
        else:
            intpart = self._int[:dotplace]
            fracpart = '.' + self._int[dotplace:]
        if leftdigits == dotplace:
            exp = ''
        else:
            if context is None:
                context = getcontext()
            exp = ['e', 'E'][context.capitals] + "%+d" % (leftdigits-dotplace)
        return sign + intpart + fracpart + exp
def to_eng_string(self, context=None):
"""Convert to a string, using engineering notation if an exponent is needed.
Engineering notation has an exponent which is a multiple of 3. This
can leave up to 3 digits to the left of the decimal place and may
require the addition of either one or two trailing zeros.
"""
return self.__str__(eng=True, context=context)
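# The difference between scientific and engineering notation is easiest to
# see side by side. A small sketch, assuming the Python 3 standard-library
# `decimal` module with its default context (capitals=1); the sample values
# are arbitrary.

```python
from decimal import Decimal

values = ['123E+1', '1234E+2', '1E-7', '0.000001']

# str() uses scientific notation: one digit before the point.
sci = [str(Decimal(v)) for v in values]

# to_eng_string() picks an exponent that is a multiple of 3, which can
# leave up to three digits before the point or add trailing zeros.
eng = [Decimal(v).to_eng_string() for v in values]

for v, s, e in zip(values, sci, eng):
    print('%-10s sci=%-10s eng=%s' % (v, s, e))
```

# The last value needs no exponent at all, because its exponent is <= 0
# and more than -6 digits lie to the left of the point.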
def __neg__(self, context=None):
"""Returns a copy with the sign switched.
Rounds, if it has reason.
"""
if self._is_special:
ans = self._check_nans(context=context)
if ans:
return ans
if context is None:
context = getcontext()
if not self and context.rounding != ROUND_FLOOR:
# -Decimal('0') is Decimal('0'), not Decimal('-0'), except
# in ROUND_FLOOR rounding mode.
ans = self.copy_abs()
else:
ans = self.copy_negate()
return ans._fix(context)
def __pos__(self, context=None):
"""Returns a copy, unless it is a sNaN.
Rounds the number (if more than precision digits)
"""
if self._is_special:
ans = self._check_nans(context=context)
if ans:
return ans
if context is None:
context = getcontext()
if not self and context.rounding != ROUND_FLOOR:
# + (-0) = 0, except in ROUND_FLOOR rounding mode.
ans = self.copy_abs()
else:
ans = Decimal(self)
return ans._fix(context)
def __abs__(self, round=True, context=None):
"""Returns the absolute value of self.
If the keyword argument 'round' is false, do not round. The
expression self.__abs__(round=False) is equivalent to
self.copy_abs().
"""
if not round:
return self.copy_abs()
if self._is_special:
ans = self._check_nans(context=context)
if ans:
return ans
if self._sign:
ans = self.__neg__(context=context)
else:
ans = self.__pos__(context=context)
return ans
def __add__(self, other, context=None):
"""Returns self + other.
-INF + INF (or the reverse) cause InvalidOperation errors.
"""
other = _convert_other(other)
if other is NotImplemented:
return other
if context is None:
context = getcontext()
if self._is_special or other._is_special:
ans = self._check_nans(other, context)
if ans:
return ans
if self._isinfinity():
# If both INF, same sign => same as both, opposite => error.
if self._sign != other._sign and other._isinfinity():
return context._raise_error(InvalidOperation, '-INF + INF')
return Decimal(self)
if other._isinfinity():
return Decimal(other) # Can't both be infinity here
exp = min(self._exp, other._exp)
negativezero = 0
if context.rounding == ROUND_FLOOR and self._sign != other._sign:
# If the answer is 0, the sign should be negative, in this case.
negativezero = 1
if not self and not other:
sign = min(self._sign, other._sign)
if negativezero:
sign = 1
ans = _dec_from_triple(sign, '0', exp)
ans = ans._fix(context)
return ans
if not self:
exp = max(exp, other._exp - context.prec-1)
ans = other._rescale(exp, context.rounding)
ans = ans._fix(context)
return ans
if not other:
exp = max(exp, self._exp - context.prec-1)
ans = self._rescale(exp, context.rounding)
ans = ans._fix(context)
return ans
op1 = _WorkRep(self)
op2 = _WorkRep(other)
op1, op2 = _normalize(op1, op2, context.prec)
result = _WorkRep()
if op1.sign != op2.sign:
# Equal and opposite
if op1.int == op2.int:
ans = _dec_from_triple(negativezero, '0', exp)
ans = ans._fix(context)
return ans
if op1.int < op2.int:
op1, op2 = op2, op1
# OK, now abs(op1) > abs(op2)
if op1.sign == 1:
result.sign = 1
op1.sign, op2.sign = op2.sign, op1.sign
else:
result.sign = 0
# So we know the sign, and op1 > 0.
elif op1.sign == 1:
result.sign = 1
op1.sign, op2.sign = (0, 0)
else:
result.sign = 0
# Now, op1 > abs(op2) > 0
if op2.sign == 0:
result.int = op1.int + op2.int
else:
result.int = op1.int - op2.int
result.exp = op1.exp
ans = Decimal(result)
ans = ans._fix(context)
return ans
__radd__ = __add__
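# A few of the special cases handled by __add__ (sign of a zero sum,
# choice of exponent, infinities) can be reproduced as follows. This is a
# usage sketch against the Python 3 standard-library `decimal` module, not
# against this file; names are illustrative.

```python
from decimal import Decimal, InvalidOperation, ROUND_FLOOR, localcontext

# A zero sum of opposite-signed operands is +0 ...
default_zero = str(Decimal('1') + Decimal('-1'))

# ... except under ROUND_FLOOR, where it is -0.
with localcontext() as ctx:
    ctx.rounding = ROUND_FLOOR
    floor_zero = str(Decimal('1') + Decimal('-1'))

# The result exponent is the smaller of the two operand exponents.
exponent_kept = str(Decimal('1.20') + Decimal('0.3'))

# Infinity absorbs finite values; opposite infinities are invalid.
inf_plus_finite = Decimal('Infinity') + Decimal('1')
try:
    Decimal('Infinity') + Decimal('-Infinity')
    inf_minus_inf_raised = False
except InvalidOperation:
    inf_minus_inf_raised = True
```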
def __sub__(self, other, context=None):
"""Return self - other"""
other = _convert_other(other)
if other is NotImplemented:
return other
if self._is_special or other._is_special:
ans = self._check_nans(other, context=context)
if ans:
return ans
# self - other is computed as self + other.copy_negate()
return self.__add__(other.copy_negate(), context=context)
def __rsub__(self, other, context=None):
"""Return other - self"""
other = _convert_other(other)
if other is NotImplemented:
return other
return other.__sub__(self, context=context)
def __mul__(self, other, context=None):
"""Return self * other.
(+-) INF * 0 (or its reverse) raise InvalidOperation.
"""
other = _convert_other(other)
if other is NotImplemented:
return other
if context is None:
context = getcontext()
resultsign = self._sign ^ other._sign
if self._is_special or other._is_special:
ans = self._check_nans(other, context)
if ans:
return ans
if self._isinfinity():
if not other:
return context._raise_error(InvalidOperation, '(+-)INF * 0')
return _SignedInfinity[resultsign]
if other._isinfinity():
if not self:
return context._raise_error(InvalidOperation, '0 * (+-)INF')
return _SignedInfinity[resultsign]
resultexp = self._exp + other._exp
# Special case for multiplying by zero
if not self or not other:
ans = _dec_from_triple(resultsign, '0', resultexp)
# Fixing in case the exponent is out of bounds
ans = ans._fix(context)
return ans
# Special case for multiplying by power of 10
if self._int == '1':
ans = _dec_from_triple(resultsign, other._int, resultexp)
ans = ans._fix(context)
return ans
if other._int == '1':
ans = _dec_from_triple(resultsign, self._int, resultexp)
ans = ans._fix(context)
return ans
op1 = _WorkRep(self)
op2 = _WorkRep(other)
ans = _dec_from_triple(resultsign, str(op1.int * op2.int), resultexp)
ans = ans._fix(context)
return ans
__rmul__ = __mul__
def __truediv__(self, other, context=None):
"""Return self / other."""
other = _convert_other(other)
if other is NotImplemented:
return NotImplemented
if context is None:
context = getcontext()
sign = self._sign ^ other._sign
if self._is_special or other._is_special:
ans = self._check_nans(other, context)
if ans:
return ans
if self._isinfinity() and other._isinfinity():
return context._raise_error(InvalidOperation, '(+-)INF/(+-)INF')
if self._isinfinity():
return _SignedInfinity[sign]
if other._isinfinity():
context._raise_error(Clamped, 'Division by infinity')
return _dec_from_triple(sign, '0', context.Etiny())
# Special cases for zeroes
if not other:
if not self:
return context._raise_error(DivisionUndefined, '0 / 0')
return context._raise_error(DivisionByZero, 'x / 0', sign)
if not self:
exp = self._exp - other._exp
coeff = 0
else:
# OK, so neither = 0, INF or NaN
shift = len(other._int) - len(self._int) + context.prec + 1
exp = self._exp - other._exp - shift
op1 = _WorkRep(self)
op2 = _WorkRep(other)
if shift >= 0:
coeff, remainder = divmod(op1.int * 10**shift, op2.int)
else:
coeff, remainder = divmod(op1.int, op2.int * 10**-shift)
if remainder:
# result is not exact; adjust to ensure correct rounding
if coeff % 5 == 0:
coeff += 1
else:
# result is exact; get as close to ideal exponent as possible
ideal_exp = self._exp - other._exp
while exp < ideal_exp and coeff % 10 == 0:
coeff //= 10
exp += 1
ans = _dec_from_triple(sign, str(coeff), exp)
return ans._fix(context)
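# The "ideal exponent" logic for exact quotients, and the rounding of
# inexact ones, can be seen from the outside. A sketch assuming the
# Python 3 standard-library `decimal` module with default traps enabled.

```python
from decimal import (Decimal, DivisionByZero, InvalidOperation,
                     localcontext)

# Exact quotients keep the ideal exponent (dividend exp - divisor exp).
exact_a = str(Decimal('2.40') / Decimal('2'))
exact_b = str(Decimal('6') / Decimal('3'))

# Inexact quotients are rounded to the context precision.
with localcontext() as ctx:
    ctx.prec = 5
    inexact = str(Decimal(1) / Decimal(3))

# x / 0 signals DivisionByZero; 0 / 0 signals an InvalidOperation
# condition (DivisionUndefined is a subclass of InvalidOperation).
try:
    Decimal(1) / Decimal(0)
    div_by_zero_raised = False
except DivisionByZero:
    div_by_zero_raised = True

try:
    Decimal(0) / Decimal(0)
    zero_by_zero_raised = False
except InvalidOperation:
    zero_by_zero_raised = True
```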
def _divide(self, other, context):
"""Return (self // other, self % other), to context.prec precision.
Assumes that neither self nor other is a NaN, that self is not
infinite and that other is nonzero.
"""
sign = self._sign ^ other._sign
if other._isinfinity():
ideal_exp = self._exp
else:
ideal_exp = min(self._exp, other._exp)
expdiff = self.adjusted() - other.adjusted()
if not self or other._isinfinity() or expdiff <= -2:
return (_dec_from_triple(sign, '0', 0),
self._rescale(ideal_exp, context.rounding))
if expdiff <= context.prec:
op1 = _WorkRep(self)
op2 = _WorkRep(other)
if op1.exp >= op2.exp:
op1.int *= 10**(op1.exp - op2.exp)
else:
op2.int *= 10**(op2.exp - op1.exp)
q, r = divmod(op1.int, op2.int)
if q < 10**context.prec:
return (_dec_from_triple(sign, str(q), 0),
_dec_from_triple(self._sign, str(r), ideal_exp))
# Here the quotient is too large to be representable
ans = context._raise_error(DivisionImpossible,
'quotient too large in //, % or divmod')
return ans, ans
def __rtruediv__(self, other, context=None):
"""Swaps self/other and returns __truediv__."""
other = _convert_other(other)
if other is NotImplemented:
return other
return other.__truediv__(self, context=context)
__div__ = __truediv__
__rdiv__ = __rtruediv__
def __divmod__(self, other, context=None):
"""
Return (self // other, self % other)
"""
other = _convert_other(other)
if other is NotImplemented:
return other
if context is None:
context = getcontext()
ans = self._check_nans(other, context)
if ans:
return (ans, ans)
sign = self._sign ^ other._sign
if self._isinfinity():
if other._isinfinity():
ans = context._raise_error(InvalidOperation, 'divmod(INF, INF)')
return ans, ans
else:
return (_SignedInfinity[sign],
context._raise_error(InvalidOperation, 'INF % x'))
if not other:
if not self:
ans = context._raise_error(DivisionUndefined, 'divmod(0, 0)')
return ans, ans
else:
return (context._raise_error(DivisionByZero, 'x // 0', sign),
context._raise_error(InvalidOperation, 'x % 0'))
quotient, remainder = self._divide(other, context)
remainder = remainder._fix(context)
return quotient, remainder
def __rdivmod__(self, other, context=None):
"""Swaps self/other and returns __divmod__."""
other = _convert_other(other)
if other is NotImplemented:
return other
return other.__divmod__(self, context=context)
def __mod__(self, other, context=None):
"""
self % other
"""
other = _convert_other(other)
if other is NotImplemented:
return other
if context is None:
context = getcontext()
ans = self._check_nans(other, context)
if ans:
return ans
if self._isinfinity():
return context._raise_error(InvalidOperation, 'INF % x')
elif not other:
if self:
return context._raise_error(InvalidOperation, 'x % 0')
else:
return context._raise_error(DivisionUndefined, '0 % 0')
remainder = self._divide(other, context)[1]
remainder = remainder._fix(context)
return remainder
def __rmod__(self, other, context=None):
"""Swaps self/other and returns __mod__."""
other = _convert_other(other)
if other is NotImplemented:
return other
return other.__mod__(self, context=context)
def remainder_near(self, other, context=None):
"""
Remainder nearest to 0: abs(remainder_near) <= abs(other)/2
"""
if context is None:
context = getcontext()
other = _convert_other(other, raiseit=True)
ans = self._check_nans(other, context)
if ans:
return ans
# self == +/-infinity -> InvalidOperation
if self._isinfinity():
return context._raise_error(InvalidOperation,
'remainder_near(infinity, x)')
# other == 0 -> either InvalidOperation or DivisionUndefined
if not other:
if self:
return context._raise_error(InvalidOperation,
'remainder_near(x, 0)')
else:
return context._raise_error(DivisionUndefined,
'remainder_near(0, 0)')
# other = +/-infinity -> remainder = self
if other._isinfinity():
ans = Decimal(self)
return ans._fix(context)
# self = 0 -> remainder = self, with ideal exponent
ideal_exponent = min(self._exp, other._exp)
if not self:
ans = _dec_from_triple(self._sign, '0', ideal_exponent)
return ans._fix(context)
# catch most cases of large or small quotient
expdiff = self.adjusted() - other.adjusted()
if expdiff >= context.prec + 1:
# expdiff >= prec+1 => abs(self/other) > 10**prec
return context._raise_error(DivisionImpossible)
if expdiff <= -2:
# expdiff <= -2 => abs(self/other) < 0.1
ans = self._rescale(ideal_exponent, context.rounding)
return ans._fix(context)
# adjust both arguments to have the same exponent, then divide
op1 = _WorkRep(self)
op2 = _WorkRep(other)
if op1.exp >= op2.exp:
op1.int *= 10**(op1.exp - op2.exp)
else:
op2.int *= 10**(op2.exp - op1.exp)
q, r = divmod(op1.int, op2.int)
# remainder is r*10**ideal_exponent; other is +/-op2.int *
# 10**ideal_exponent. Apply correction to ensure that
# abs(remainder) <= abs(other)/2
if 2*r + (q&1) > op2.int:
r -= op2.int
q += 1
if q >= 10**context.prec:
return context._raise_error(DivisionImpossible)
# result has same sign as self unless r is negative
sign = self._sign
if r < 0:
sign = 1-sign
r = -r
ans = _dec_from_triple(sign, str(r), ideal_exponent)
return ans._fix(context)
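# Decimal's %, divmod and remainder_near differ from Python's integer
# operators: the quotient is truncated toward zero, and remainder_near
# rounds the quotient to the nearest integer (ties to even). A sketch
# assuming the Python 3 standard-library `decimal` module.

```python
from decimal import Decimal

# The remainder takes the sign of the dividend (quotient truncates).
dec_divmod = divmod(Decimal(-7), Decimal(3))
int_divmod = divmod(-7, 3)   # Python ints floor instead

# remainder_near picks the remainder closest to zero.
near_10_6 = Decimal(10).remainder_near(Decimal(6))

# Ties are resolved by choosing the even quotient.
near_25_10 = Decimal(25).remainder_near(Decimal(10))   # quotient 2
near_35_10 = Decimal(35).remainder_near(Decimal(10))   # quotient 4
```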
def __floordiv__(self, other, context=None):
"""self // other"""
other = _convert_other(other)
if other is NotImplemented:
return other
if context is None:
context = getcontext()
ans = self._check_nans(other, context)
if ans:
return ans
if self._isinfinity():
if other._isinfinity():
return context._raise_error(InvalidOperation, 'INF // INF')
else:
return _SignedInfinity[self._sign ^ other._sign]
if not other:
if self:
return context._raise_error(DivisionByZero, 'x // 0',
self._sign ^ other._sign)
else:
return context._raise_error(DivisionUndefined, '0 // 0')
return self._divide(other, context)[0]
def __rfloordiv__(self, other, context=None):
"""Swaps self/other and returns __floordiv__."""
other = _convert_other(other)
if other is NotImplemented:
return other
return other.__floordiv__(self, context=context)
def __float__(self):
"""Float representation."""
if self._isnan():
if self.is_snan():
raise ValueError("Cannot convert signaling NaN to float")
s = "-nan" if self._sign else "nan"
else:
s = str(self)
return float(s)
def __int__(self):
"""Converts self to an int, truncating if necessary."""
if self._is_special:
if self._isnan():
raise ValueError("Cannot convert NaN to integer")
elif self._isinfinity():
raise OverflowError("Cannot convert infinity to integer")
s = (-1)**self._sign
if self._exp >= 0:
return s*int(self._int)*10**self._exp
else:
return s*int(self._int[:self._exp] or '0')
__trunc__ = __int__
def real(self):
return self
real = property(real)
def imag(self):
return Decimal(0)
imag = property(imag)
def conjugate(self):
return self
def __complex__(self):
return complex(float(self))
def __long__(self):
"""Converts to a long.
Equivalent to long(int(self))
"""
return long(self.__int__())
def _fix_nan(self, context):
"""Decapitate the payload of a NaN to fit the context"""
payload = self._int
# maximum length of payload is precision if _clamp=0,
# precision-1 if _clamp=1.
max_payload_len = context.prec - context._clamp
if len(payload) > max_payload_len:
payload = payload[len(payload)-max_payload_len:].lstrip('0')
return _dec_from_triple(self._sign, payload, self._exp, True)
return Decimal(self)
def _fix(self, context):
"""Round if it is necessary to keep self within prec precision.
Rounds and fixes the exponent. Does not raise on a sNaN.
Arguments:
self - Decimal instance
context - context used.
"""
if self._is_special:
if self._isnan():
# decapitate payload if necessary
return self._fix_nan(context)
else:
# self is +/-Infinity; return unaltered
return Decimal(self)
# if self is zero then exponent should be between Etiny and
# Emax if _clamp==0, and between Etiny and Etop if _clamp==1.
Etiny = context.Etiny()
Etop = context.Etop()
if not self:
exp_max = [context.Emax, Etop][context._clamp]
new_exp = min(max(self._exp, Etiny), exp_max)
if new_exp != self._exp:
context._raise_error(Clamped)
return _dec_from_triple(self._sign, '0', new_exp)
else:
return Decimal(self)
# exp_min is the smallest allowable exponent of the result,
# equal to max(self.adjusted()-context.prec+1, Etiny)
exp_min = len(self._int) + self._exp - context.prec
if exp_min > Etop:
# overflow: exp_min > Etop iff self.adjusted() > Emax
ans = context._raise_error(Overflow, 'above Emax', self._sign)
context._raise_error(Inexact)
context._raise_error(Rounded)
return ans
self_is_subnormal = exp_min < Etiny
if self_is_subnormal:
exp_min = Etiny
# round if self has too many digits
if self._exp < exp_min:
digits = len(self._int) + self._exp - exp_min
if digits < 0:
self = _dec_from_triple(self._sign, '1', exp_min-1)
digits = 0
rounding_method = self._pick_rounding_function[context.rounding]
changed = rounding_method(self, digits)
coeff = self._int[:digits] or '0'
if changed > 0:
coeff = str(int(coeff)+1)
if len(coeff) > context.prec:
coeff = coeff[:-1]
exp_min += 1
# check whether the rounding pushed the exponent out of range
if exp_min > Etop:
ans = context._raise_error(Overflow, 'above Emax', self._sign)
else:
ans = _dec_from_triple(self._sign, coeff, exp_min)
# raise the appropriate signals, taking care to respect
# the precedence described in the specification
if changed and self_is_subnormal:
context._raise_error(Underflow)
if self_is_subnormal:
context._raise_error(Subnormal)
if changed:
context._raise_error(Inexact)
context._raise_error(Rounded)
if not ans:
# raise Clamped on underflow to 0
context._raise_error(Clamped)
return ans
if self_is_subnormal:
context._raise_error(Subnormal)
# fold down if _clamp == 1 and self has too few digits
if context._clamp == 1 and self._exp > Etop:
context._raise_error(Clamped)
self_padded = self._int + '0'*(self._exp - Etop)
return _dec_from_triple(self._sign, self_padded, Etop)
# here self was representable to begin with; return unchanged
return Decimal(self)
    # for each of the rounding functions below:
    #   self is a finite, nonzero Decimal
    #   prec is an integer satisfying 0 <= prec < len(self._int)
    #
    # each function returns either -1, 0, or 1, as follows:
    #   1 indicates that self should be rounded up (away from zero)
    #   0 indicates that self should be truncated, and that all the
    #     digits to be truncated are zeros (so the value is unchanged)
    #  -1 indicates that there are nonzero digits to be truncated
    def _round_down(self, prec):
        """Also known as round-towards-0, truncate."""
        if _all_zeros(self._int, prec):
            return 0
        else:
            return -1
    def _round_up(self, prec):
        """Rounds away from 0."""
        return -self._round_down(prec)
    def _round_half_up(self, prec):
        """Rounds 5 up (away from 0)"""
        if self._int[prec] in '56789':
            return 1
        elif _all_zeros(self._int, prec):
            return 0
        else:
            return -1
    def _round_half_down(self, prec):
        """Round 5 down"""
        if _exact_half(self._int, prec):
            return -1
        else:
            return self._round_half_up(prec)
    def _round_half_even(self, prec):
        """Round 5 to even, rest to nearest."""
        if _exact_half(self._int, prec) and \
                (prec == 0 or self._int[prec-1] in '02468'):
            return -1
        else:
            return self._round_half_up(prec)
    def _round_ceiling(self, prec):
        """Rounds up (not away from 0 if negative.)"""
        if self._sign:
            return self._round_down(prec)
        else:
            return -self._round_down(prec)
    def _round_floor(self, prec):
        """Rounds down (not towards 0 if negative)"""
        if not self._sign:
            return self._round_down(prec)
        else:
            return -self._round_down(prec)
    def _round_05up(self, prec):
        """Round down unless digit prec-1 is 0 or 5."""
        if prec and self._int[prec-1] not in '05':
            return self._round_down(prec)
        else:
            return -self._round_down(prec)
    _pick_rounding_function = dict(
        ROUND_DOWN = _round_down,
        ROUND_UP = _round_up,
        ROUND_HALF_UP = _round_half_up,
        ROUND_HALF_DOWN = _round_half_down,
        ROUND_HALF_EVEN = _round_half_even,
        ROUND_CEILING = _round_ceiling,
        ROUND_FLOOR = _round_floor,
        ROUND_05UP = _round_05up,
    )
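# The eight rounding modes in the table above are easiest to compare by
# rounding the same values to an integer. A sketch using the public
# quantize() method of the Python 3 standard-library `decimal` module;
# the helper name round_to_int is hypothetical.

```python
from decimal import (Decimal, ROUND_DOWN, ROUND_UP, ROUND_HALF_UP,
                     ROUND_HALF_DOWN, ROUND_HALF_EVEN, ROUND_CEILING,
                     ROUND_FLOOR, ROUND_05UP)

MODES = [ROUND_DOWN, ROUND_UP, ROUND_HALF_UP, ROUND_HALF_DOWN,
         ROUND_HALF_EVEN, ROUND_CEILING, ROUND_FLOOR, ROUND_05UP]

def round_to_int(text, mode):
    """Round the decimal literal `text` to zero decimal places."""
    return str(Decimal(text).quantize(Decimal('1'), rounding=mode))

positive = dict((mode, round_to_int('2.5', mode)) for mode in MODES)
negative = dict((mode, round_to_int('-2.5', mode)) for mode in MODES)

# ROUND_05UP rounds away from zero only when the truncated result
# would end in 0 or 5.
up_after_5 = round_to_int('5.1', ROUND_05UP)

for mode in MODES:
    print('%-16s %3s %3s' % (mode, positive[mode], negative[mode]))
```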
def fma(self, other, third, context=None):
"""Fused multiply-add.
Returns self*other+third with no rounding of the intermediate
product self*other.
self and other are multiplied together, with no rounding of
the result. The third operand is then added to the result,
and a single final rounding is performed.
"""
other = _convert_other(other, raiseit=True)
# compute product; raise InvalidOperation if either operand is
# a signaling NaN or if the product is zero times infinity.
if self._is_special or other._is_special:
if context is None:
context = getcontext()
if self._exp == 'N':
return context._raise_error(InvalidOperation, 'sNaN', self)
if other._exp == 'N':
return context._raise_error(InvalidOperation, 'sNaN', other)
if self._exp == 'n':
product = self
elif other._exp == 'n':
product = other
elif self._exp == 'F':
if not other:
return context._raise_error(InvalidOperation,
'INF * 0 in fma')
product = _SignedInfinity[self._sign ^ other._sign]
elif other._exp == 'F':
if not self:
return context._raise_error(InvalidOperation,
'0 * INF in fma')
product = _SignedInfinity[self._sign ^ other._sign]
else:
product = _dec_from_triple(self._sign ^ other._sign,
str(int(self._int) * int(other._int)),
self._exp + other._exp)
third = _convert_other(third, raiseit=True)
return product.__add__(third, context)
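# The effect of the single final rounding in fma() shows up when the
# intermediate product does not fit the context precision. A sketch
# assuming the Python 3 standard-library `decimal` module; the precision
# of 2 is chosen only to make the difference visible.

```python
from decimal import Decimal, localcontext

with localcontext() as ctx:
    ctx.prec = 2
    a = Decimal('1.1')
    c = Decimal('-1.2')

    # Fused: 1.1 * 1.1 = 1.21 is kept exactly, then -1.2 is added.
    fused = a.fma(a, c)

    # Separate: the product is first rounded to 1.2, losing the 0.01.
    separate = a * a + c

print(fused, separate)
```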
def _power_modulo(self, other, modulo, context=None):
"""Three argument version of __pow__"""
# if can't convert other and modulo to Decimal, raise
# TypeError; there's no point returning NotImplemented (no
# equivalent of __rpow__ for three argument pow)
other = _convert_other(other, raiseit=True)
modulo = _convert_other(modulo, raiseit=True)
if context is None:
context = getcontext()
# deal with NaNs: if there are any sNaNs then first one wins,
# (i.e. behaviour for NaNs is identical to that of fma)
self_is_nan = self._isnan()
other_is_nan = other._isnan()
modulo_is_nan = modulo._isnan()
if self_is_nan or other_is_nan or modulo_is_nan:
if self_is_nan == 2:
return context._raise_error(InvalidOperation, 'sNaN',
self)
if other_is_nan == 2:
return context._raise_error(InvalidOperation, 'sNaN',
other)
if modulo_is_nan == 2:
return context._raise_error(InvalidOperation, 'sNaN',
modulo)
if self_is_nan:
return self._fix_nan(context)
if other_is_nan:
return other._fix_nan(context)
return modulo._fix_nan(context)
# check inputs: we apply same restrictions as Python's pow()
if not (self._isinteger() and
other._isinteger() and
modulo._isinteger()):
return context._raise_error(InvalidOperation,
'pow() 3rd argument not allowed '
'unless all arguments are integers')
if other < 0:
return context._raise_error(InvalidOperation,
'pow() 2nd argument cannot be '
'negative when 3rd argument specified')
if not modulo:
return context._raise_error(InvalidOperation,
'pow() 3rd argument cannot be 0')
# additional restriction for decimal: the modulus must be less
# than 10**prec in absolute value
if modulo.adjusted() >= context.prec:
return context._raise_error(InvalidOperation,
'insufficient precision: pow() 3rd '
'argument must not have more than '
'precision digits')
# define 0**0 == NaN, for consistency with two-argument pow
# (even though it hurts!)
if not other and not self:
return context._raise_error(InvalidOperation,
'at least one of pow() 1st argument '
'and 2nd argument must be nonzero; '
'0**0 is not defined')
# compute sign of result
if other._iseven():
sign = 0
else:
sign = self._sign
# convert modulo to a Python integer, and self and other to
# Decimal integers (i.e. force their exponents to be >= 0)
modulo = abs(int(modulo))
base = _WorkRep(self.to_integral_value())
exponent = _WorkRep(other.to_integral_value())
# compute result using integer pow()
base = (base.int % modulo * pow(10, base.exp, modulo)) % modulo
for i in xrange(exponent.exp):
base = pow(base, 10, modulo)
base = pow(base, exponent.int, modulo)
return _dec_from_triple(sign, str(base), 0)
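# From user code, _power_modulo is reached through three-argument pow().
# A sketch assuming the Python 3 standard-library `decimal` module with
# default traps; it also checks the restriction that all three arguments
# must be integral.

```python
from decimal import Decimal, InvalidOperation

# Three-argument pow on integral Decimals.
result = pow(Decimal(2), Decimal(10), Decimal(1000))

# For these small operands it agrees with computing the power exactly
# and then reducing it.
reference = (Decimal(2) ** Decimal(10)) % Decimal(1000)

# A non-integral argument is rejected with InvalidOperation.
try:
    pow(Decimal('2.5'), Decimal(2), Decimal(3))
    non_integer_rejected = False
except InvalidOperation:
    non_integer_rejected = True
```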
def _power_exact(self, other, p):
"""Attempt to compute self**other exactly.
Given Decimals self and other and an integer p, attempt to
compute an exact result for the power self**other, with p
digits of precision. Return None if self**other is not
exactly representable in p digits.
Assumes that elimination of special cases has already been
performed: self and other must both be nonspecial; self must
be positive and not numerically equal to 1; other must be
nonzero. For efficiency, other._exp should not be too large,
so that 10**abs(other._exp) is a feasible calculation."""
# In the comments below, we write x for the value of self and y for the
# value of other. Write x = xc*10**xe and abs(y) = yc*10**ye, with xc
# and yc positive integers not divisible by 10.
# The main purpose of this method is to identify the *failure*
# of x**y to be exactly representable with as little effort as
# possible. So we look for cheap and easy tests that
# eliminate the possibility of x**y being exact. Only if all
# these tests are passed do we go on to actually compute x**y.
# Here's the main idea. Express y as a rational number m/n, with m and
# n relatively prime and n>0. Then for x**y to be exactly
# representable (at *any* precision), xc must be the nth power of a
# positive integer and xe must be divisible by n. If y is negative
# then additionally xc must be a power of either 2 or 5, hence a power
# of 2**n or 5**n.
#
# There's a limit to how small |y| can be: if y=m/n as above
# then:
#
# (1) if xc != 1 then for the result to be representable we
# need xc**(1/n) >= 2, and hence also xc**|y| >= 2. So
# if |y| <= 1/nbits(xc) then xc < 2**nbits(xc) <=
# 2**(1/|y|), hence xc**|y| < 2 and the result is not
# representable.
#
# (2) if xe != 0, |xe|*(1/n) >= 1, so |xe|*|y| >= 1. Hence if
# |y| < 1/|xe| then the result is not representable.
#
# Note that since x is not equal to 1, at least one of (1) and
# (2) must apply. Now |y| < 1/nbits(xc) iff |yc|*nbits(xc) <
# 10**-ye iff len(str(|yc|*nbits(xc))) <= -ye.
#
# There's also a limit to how large y can be, at least if it's
# positive: the normalized result will have coefficient xc**y,
# so if it's representable then xc**y < 10**p, and y <
# p/log10(xc). Hence if y*log10(xc) >= p then the result is
# not exactly representable.
# if len(str(abs(yc*xe))) <= -ye then abs(yc*xe) < 10**-ye,
# so |y| < 1/|xe| and the result is not representable.
# Similarly, len(str(abs(yc)*xc_bits)) <= -ye implies |y|
# < 1/nbits(xc).
x = _WorkRep(self)
xc, xe = x.int, x.exp
while xc % 10 == 0:
xc //= 10
xe += 1
y = _WorkRep(other)
yc, ye = y.int, y.exp
while yc % 10 == 0:
yc //= 10
ye += 1
# case where xc == 1: result is 10**(xe*y), with xe*y
# required to be an integer
if xc == 1:
xe *= yc
# result is now 10**(xe * 10**ye); xe * 10**ye must be integral
while xe % 10 == 0:
xe //= 10
ye += 1
if ye < 0:
return None
exponent = xe * 10**ye
if y.sign == 1:
exponent = -exponent
# if other is a nonnegative integer, use ideal exponent
if other._isinteger() and other._sign == 0:
ideal_exponent = self._exp*int(other)
zeros = min(exponent-ideal_exponent, p-1)
else:
zeros = 0
return _dec_from_triple(0, '1' + '0'*zeros, exponent-zeros)
# case where y is negative: xc must be either a power
# of 2 or a power of 5.
if y.sign == 1:
last_digit = xc % 10
if last_digit in (2,4,6,8):
# quick test for power of 2
if xc & -xc != xc:
return None
# now xc is a power of 2; e is its exponent
e = _nbits(xc)-1
# We now have:
#
# x = 2**e * 10**xe, e > 0, and y < 0.
#
# The exact result is:
#
# x**y = 5**(-e*y) * 10**(e*y + xe*y)
#
# provided that both e*y and xe*y are integers. Note that if
# 5**(-e*y) >= 10**p, then the result can't be expressed
# exactly with p digits of precision.
#
# Using the above, we can guard against large values of ye.
# 93/65 is an upper bound for log(10)/log(5), so if
#
# ye >= len(str(93*p//65))
#
# then
#
# -e*y >= -y >= 10**ye > 93*p/65 > p*log(10)/log(5),
#
# so 5**(-e*y) >= 10**p, and the coefficient of the result
# can't be expressed in p digits.
# emax >= largest e such that 5**e < 10**p.
emax = p*93//65
if ye >= len(str(emax)):
return None
# Find -e*y and -xe*y; both must be integers
e = _decimal_lshift_exact(e * yc, ye)
xe = _decimal_lshift_exact(xe * yc, ye)
if e is None or xe is None:
return None
if e > emax:
return None
xc = 5**e
elif last_digit == 5:
# e >= log_5(xc) if xc is a power of 5; we have
# equality all the way up to xc=5**2658
e = _nbits(xc)*28//65
xc, remainder = divmod(5**e, xc)
if remainder:
return None
while xc % 5 == 0:
xc //= 5
e -= 1
# Guard against large values of ye, using the same logic as in
# the 'xc is a power of 2' branch. 10/3 is an upper bound for
# log(10)/log(2).
emax = p*10//3
if ye >= len(str(emax)):
return None
e = _decimal_lshift_exact(e * yc, ye)
xe = _decimal_lshift_exact(xe * yc, ye)
if e is None or xe is None:
return None
if e > emax:
return None
xc = 2**e
else:
return None
if xc >= 10**p:
return None
xe = -e-xe
return _dec_from_triple(0, str(xc), xe)
# now y is positive; find m and n such that y = m/n
if ye >= 0:
m, n = yc*10**ye, 1
else:
if xe != 0 and len(str(abs(yc*xe))) <= -ye:
return None
xc_bits = _nbits(xc)
if xc != 1 and len(str(abs(yc)*xc_bits)) <= -ye:
return None
m, n = yc, 10**(-ye)
while m % 2 == n % 2 == 0:
m //= 2
n //= 2
while m % 5 == n % 5 == 0:
m //= 5
n //= 5
# compute nth root of xc*10**xe
if n > 1:
# if 1 < xc < 2**n then xc isn't an nth power
if xc != 1 and xc_bits <= n:
return None
xe, rem = divmod(xe, n)
if rem != 0:
return None
# compute nth root of xc using Newton's method
a = 1L << -(-_nbits(xc)//n) # initial estimate
while True:
q, r = divmod(xc, a**(n-1))
if a <= q:
break
else:
a = (a*(n-1) + q)//n
if not (a == q and r == 0):
return None
xc = a
# now xc*10**xe is the nth root of the original xc*10**xe
# compute mth power of xc*10**xe
# if m > p*100//_log10_lb(xc) then m > p/log10(xc), hence xc**m >
# 10**p and the result is not representable.
if xc > 1 and m > p*100//_log10_lb(xc):
return None
xc = xc**m
xe *= m
if xc > 10**p:
return None
# by this point the result *is* exactly representable
# adjust the exponent to get as close as possible to the ideal
# exponent, if necessary
str_xc = str(xc)
if other._isinteger() and other._sign == 0:
ideal_exponent = self._exp*int(other)
zeros = min(xe-ideal_exponent, p-len(str_xc))
else:
zeros = 0
return _dec_from_triple(0, str_xc+'0'*zeros, xe-zeros)
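# Illustrative sketch, not part of the module: the integer Newton iteration
# used above to test whether xc is an exact nth power, pulled out as a
# standalone function. The name exact_nth_root is hypothetical, and
# int.bit_length() is assumed to be an acceptable stand-in for _nbits.

```python
def exact_nth_root(xc, n):
    """Return the integer a with a**n == xc, or None if xc is not an nth power.

    Assumes xc >= 1 and n >= 2.
    """
    # initial estimate: a power of 2 that is at least as large as the true root
    a = 1 << -(-xc.bit_length() // n)
    while True:
        q, r = divmod(xc, a ** (n - 1))
        if a <= q:
            break
        # Newton step for f(a) = a**n - xc, using integer arithmetic only
        a = (a * (n - 1) + q) // n
    # the root is exact only if a**(n-1) divides xc and the quotient equals a
    return a if (a == q and r == 0) else None


cube_root = exact_nth_root(27, 3)
not_a_cube = exact_nth_root(28, 3)
tenth_root = exact_nth_root(1024, 10)
```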
def __pow__(self, other, modulo=None, context=None):
"""Return self ** other [ % modulo].
With two arguments, compute self**other.
With three arguments, compute (self**other) % modulo. For the
three argument form, the following restrictions on the
arguments hold:
- all three arguments must be integral
- other must be nonnegative
- either self or other (or both) must be nonzero
- modulo must be nonzero and must have at most p digits,
where p is the context precision.
If any of these restrictions is violated the InvalidOperation
flag is raised.
The result of pow(self, other, modulo) is identical to the
result that would be obtained by computing (self**other) %
modulo with unbounded precision, but is computed more
efficiently. It is always exact.
"""
if modulo is not None:
return self._power_modulo(other, modulo, context)
other = _convert_other(other)
if other is NotImplemented:
return other
if context is None:
context = getcontext()
# either argument is a NaN => result is NaN
ans = self._check_nans(other, context)
if ans:
return ans
# 0**0 = NaN (!), x**0 = 1 for nonzero x (including +/-Infinity)
if not other:
if not self:
return context._raise_error(InvalidOperation, '0 ** 0')
else:
return _One
# result has sign 1 iff self._sign is 1 and other is an odd integer
result_sign = 0
if self._sign == 1:
if other._isinteger():
if not other._iseven():
result_sign = 1
else:
# -ve**noninteger = NaN
# (-0)**noninteger = 0**noninteger
if self:
return context._raise_error(InvalidOperation,
'x ** y with x negative and y not an integer')
# negate self, without doing any unwanted rounding
self = self.copy_negate()
# 0**(+ve or Inf)= 0; 0**(-ve or -Inf) = Infinity
if not self:
if other._sign == 0:
return _dec_from_triple(result_sign, '0', 0)
else:
return _SignedInfinity[result_sign]
# Inf**(+ve or Inf) = Inf; Inf**(-ve or -Inf) = 0
if self._isinfinity():
if other._sign == 0:
return _SignedInfinity[result_sign]
else:
return _dec_from_triple(result_sign, '0', 0)
# 1**other = 1, but the choice of exponent and the flags
# depend on the exponent of self, and on whether other is a
# positive integer, a negative integer, or neither
if self == _One:
if other._isinteger():
# exp = max(self._exp*max(int(other), 0),
# 1-context.prec) but evaluating int(other) directly
# is dangerous until we know other is small (other
# could be 1e999999999)
if other._sign == 1:
multiplier = 0
elif other > context.prec:
multiplier = context.prec
else:
multiplier = int(other)
exp = self._exp * multiplier
if exp < 1-context.prec:
exp = 1-context.prec
context._raise_error(Rounded)
else:
context._raise_error(Inexact)
context._raise_error(Rounded)
exp = 1-context.prec
return _dec_from_triple(result_sign, '1'+'0'*-exp, exp)
# compute adjusted exponent of self
self_adj = self.adjusted()
# self ** infinity is infinity if self > 1, 0 if self < 1
# self ** -infinity is infinity if self < 1, 0 if self > 1
if other._isinfinity():
if (other._sign == 0) == (self_adj < 0):
return _dec_from_triple(result_sign, '0', 0)
else:
return _SignedInfinity[result_sign]
# from here on, the result always goes through the call
# to _fix at the end of this function.
ans = None
exact = False
# crude test to catch cases of extreme overflow/underflow. If
# log10(self)*other >= 10**bound and bound >= len(str(Emax))
# then 10**bound >= 10**len(str(Emax)) >= Emax+1 and hence
# self**other >= 10**(Emax+1), so overflow occurs. The test
# for underflow is similar.
bound = self._log10_exp_bound() + other.adjusted()
if (self_adj >= 0) == (other._sign == 0):
# self > 1 and other +ve, or self < 1 and other -ve
# possibility of overflow
if bound >= len(str(context.Emax)):
ans = _dec_from_triple(result_sign, '1', context.Emax+1)
else:
# self > 1 and other -ve, or self < 1 and other +ve
# possibility of underflow to 0
Etiny = context.Etiny()
if bound >= len(str(-Etiny)):
ans = _dec_from_triple(result_sign, '1', Etiny-1)
# try for an exact result with precision +1
if ans is None:
ans = self._power_exact(other, context.prec + 1)
if ans is not None:
if result_sign == 1:
ans = _dec_from_triple(1, ans._int, ans._exp)
exact = True
# usual case: inexact result, x**y computed directly as exp(y*log(x))
if ans is None:
p = context.prec
x = _WorkRep(self)
xc, xe = x.int, x.exp
y = _WorkRep(other)
yc, ye = y.int, y.exp
if y.sign == 1:
yc = -yc
# compute correctly rounded result: start with precision +3,
# then increase precision until result is unambiguously roundable
extra = 3
while True:
coeff, exp = _dpower(xc, xe, yc, ye, p+extra)
if coeff % (5*10**(len(str(coeff))-p-1)):
break
extra += 3
ans = _dec_from_triple(result_sign, str(coeff), exp)
# unlike exp, ln and log10, the power function respects the
# rounding mode; no need to switch to ROUND_HALF_EVEN here
# There's a difficulty here when 'other' is not an integer and
# the result is exact. In this case, the specification
# requires that the Inexact flag be raised (in spite of
# exactness), but since the result is exact _fix won't do this
# for us. (Correspondingly, the Underflow signal should also
# be raised for subnormal results.) We can't directly raise
# these signals either before or after calling _fix, since
# that would violate the precedence for signals. So we wrap
# the ._fix call in a temporary context, and reraise
# afterwards.
if exact and not other._isinteger():
# pad with zeros up to length context.prec+1 if necessary; this
# ensures that the Rounded signal will be raised.
if len(ans._int) <= context.prec:
expdiff = context.prec + 1 - len(ans._int)
ans = _dec_from_triple(ans._sign, ans._int+'0'*expdiff,
ans._exp-expdiff)
# create a copy of the current context, with cleared flags/traps
newcontext = context.copy()
newcontext.clear_flags()
for exception in _signals:
newcontext.traps[exception] = 0
# round in the new context
ans = ans._fix(newcontext)
# raise Inexact, and if necessary, Underflow
newcontext._raise_error(Inexact)
if newcontext.flags[Subnormal]:
newcontext._raise_error(Underflow)
# propagate signals to the original context; _fix could
# have raised any of Overflow, Underflow, Subnormal,
# Inexact, Rounded, Clamped. Overflow needs the correct
# arguments. Note that the order of the exceptions is
# important here.
if newcontext.flags[Overflow]:
context._raise_error(Overflow, 'above Emax', ans._sign)
for exception in Underflow, Subnormal, Inexact, Rounded, Clamped:
if newcontext.flags[exception]:
context._raise_error(exception)
else:
ans = ans._fix(context)
return ans
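# Illustrative sketch, not part of the module: how the branches of __pow__
# above look from the public API. It assumes a local context with precision 9
# and with the InvalidOperation trap switched off, so that invalid cases
# return NaN instead of raising.

```python
from decimal import Decimal, Inexact, InvalidOperation, localcontext

with localcontext() as ctx:
    ctx.prec = 9
    ctx.traps[InvalidOperation] = False
    ctx.clear_flags()

    # integer exponent with an exactly representable result: no Inexact flag
    exact = Decimal(2) ** Decimal(10)
    exact_is_flagged = bool(ctx.flags[Inexact])

    # non-integer exponent: the result goes through rounding and sets Inexact
    inexact = Decimal(2) ** Decimal("0.5")
    inexact_is_flagged = bool(ctx.flags[Inexact])

    # 0 ** 0 and negative ** non-integer both signal InvalidOperation
    zero_pow_zero = Decimal(0) ** Decimal(0)
    negative_base = Decimal(-8) ** Decimal("0.5")
    invalid_is_flagged = bool(ctx.flags[InvalidOperation])

    # three-argument form: (3 ** 4) % 5, computed exactly
    modular = pow(Decimal(3), Decimal(4), Decimal(5))
```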
def __rpow__(self, other, context=None):
"""Swaps self/other and returns __pow__."""
other = _convert_other(other)
if other is NotImplemented:
return other
return other.__pow__(self, context=context)
def normalize(self, context=None):
"""Normalize: strip trailing 0s and change anything equal to 0 to 0e0."""
if context is None:
context = getcontext()
if self._is_special:
ans = self._check_nans(context=context)
if ans:
return ans
dup = self._fix(context)
if dup._isinfinity():
return dup
if not dup:
return _dec_from_triple(dup._sign, '0', 0)
exp_max = [context.Emax, context.Etop()][context._clamp]
end = len(dup._int)
exp = dup._exp
while dup._int[end-1] == '0' and exp < exp_max:
exp += 1
end -= 1
return _dec_from_triple(dup._sign, dup._int[:end], exp)
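# Illustrative sketch, not part of the module: the effect of normalize() as
# seen through the public API, assuming the default context.

```python
from decimal import Decimal

# trailing zeros move from the coefficient into the exponent
trimmed = Decimal("120.00").normalize()
as_text = str(trimmed)

# any zero collapses to a zero with exponent 0
zero = Decimal("0.000").normalize()
```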
def quantize(self, exp, rounding=None, context=None, watchexp=True):
"""Quantize self so its exponent is the same as that of exp.
Similar to self._rescale(exp._exp) but with error checking.
"""
exp = _convert_other(exp, raiseit=True)
if context is None:
context = getcontext()
if rounding is None:
rounding = context.rounding
if self._is_special or exp._is_special:
ans = self._check_nans(exp, context)
if ans:
return ans
if exp._isinfinity() or self._isinfinity():
if exp._isinfinity() and self._isinfinity():
return Decimal(self) # if both are inf, it is OK
return context._raise_error(InvalidOperation,
'quantize with one INF')
# if we're not watching exponents, do a simple rescale
if not watchexp:
ans = self._rescale(exp._exp, rounding)
# raise Inexact and Rounded where appropriate
if ans._exp > self._exp:
context._raise_error(Rounded)
if ans != self:
context._raise_error(Inexact)
return ans
# exp._exp should be between Etiny and Emax
if not (context.Etiny() <= exp._exp <= context.Emax):
return context._raise_error(InvalidOperation,
'target exponent out of bounds in quantize')
if not self:
ans = _dec_from_triple(self._sign, '0', exp._exp)
return ans._fix(context)
self_adjusted = self.adjusted()
if self_adjusted > context.Emax:
return context._raise_error(InvalidOperation,
'exponent of quantize result too large for current context')
if self_adjusted - exp._exp + 1 > context.prec:
return context._raise_error(InvalidOperation,
'quantize result has too many digits for current context')
ans = self._rescale(exp._exp, rounding)
if ans.adjusted() > context.Emax:
return context._raise_error(InvalidOperation,
'exponent of quantize result too large for current context')
if len(ans._int) > context.prec:
return context._raise_error(InvalidOperation,
'quantize result has too many digits for current context')
# raise appropriate flags
if ans and ans.adjusted() < context.Emin:
context._raise_error(Subnormal)
if ans._exp > self._exp:
if ans != self:
context._raise_error(Inexact)
context._raise_error(Rounded)
# call to fix takes care of any necessary folddown, and
# signals Clamped if necessary
ans = ans._fix(context)
return ans
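# Illustrative sketch, not part of the module: typical quantize() calls and
# the "too many digits" error path checked above. The precision of 3 is an
# assumption chosen only to trigger that error.

```python
from decimal import (Decimal, InvalidOperation, ROUND_DOWN, ROUND_HALF_UP,
                     localcontext)

# the exponent of the second operand selects the number of decimal places
rounded = Decimal("2.675").quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
truncated = Decimal("2.675").quantize(Decimal("0.01"), rounding=ROUND_DOWN)

with localcontext() as ctx:
    ctx.prec = 3
    ctx.traps[InvalidOperation] = True
    try:
        # the result would need 5 digits, but the context allows only 3
        Decimal("123.456").quantize(Decimal("0.01"))
        too_many_digits = False
    except InvalidOperation:
        too_many_digits = True
```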
def same_quantum(self, other):
"""Return True if self and other have the same exponent; otherwise
return False.
If either operand is a special value, the following rules are used:
* return True if both operands are infinities
* return True if both operands are NaNs
* otherwise, return False.
"""
other = _convert_other(other, raiseit=True)
if self._is_special or other._is_special:
return (self.is_nan() and other.is_nan() or
self.is_infinite() and other.is_infinite())
return self._exp == other._exp
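# Illustrative sketch, not part of the module: same_quantum() compares
# exponents only, with the special-value rules listed in the docstring.

```python
from decimal import Decimal

same = Decimal("1.20").same_quantum(Decimal("3.45"))
different = Decimal("1.2").same_quantum(Decimal("1.20"))
both_infinite = Decimal("Infinity").same_quantum(Decimal("-Infinity"))
infinite_and_finite = Decimal("Infinity").same_quantum(Decimal("1"))
```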
def _rescale(self, exp, rounding):
"""Rescale self so that the exponent is exp, either by padding with zeros
or by truncating digits, using the given rounding mode.
Specials are returned without change. This operation is
quiet: it raises no flags, and uses no information from the
context.
exp = exp to scale to (an integer)
rounding = rounding mode
"""
if self._is_special:
return Decimal(self)
if not self:
return _dec_from_triple(self._sign, '0', exp)
if self._exp >= exp:
# pad answer with zeros if necessary
return _dec_from_triple(self._sign,
self._int + '0'*(self._exp - exp), exp)
# too many digits; round and lose data. If self.adjusted() <
# exp-1, replace self by 10**(exp-1) before rounding
digits = len(self._int) + self._exp - exp
if digits < 0:
self = _dec_from_triple(self._sign, '1', exp-1)
digits = 0
this_function = self._pick_rounding_function[rounding]
changed = this_function(self, digits)
coeff = self._int[:digits] or '0'
if changed == 1:
coeff = str(int(coeff)+1)
return _dec_from_triple(self._sign, coeff, exp)
def _round(self, places, rounding):
"""Round a nonzero, nonspecial Decimal to a fixed number of
significant figures, using the given rounding mode.
Infinities, NaNs and zeros are returned unaltered.
This operation is quiet: it raises no flags, and uses no
information from the context.
"""
if places <= 0:
raise ValueError("argument should be at least 1 in _round")
if self._is_special or not self:
return Decimal(self)
ans = self._rescale(self.adjusted()+1-places, rounding)
# it can happen that the rescale alters the adjusted exponent;
# for example when rounding 99.97 to 3 significant figures.
# When this happens we end up with an extra 0 at the end of
# the number; a second rescale fixes this.
if ans.adjusted() != self.adjusted():
ans = ans._rescale(ans.adjusted()+1-places, rounding)
return ans
def to_integral_exact(self, rounding=None, context=None):
"""Rounds to a nearby integer.
If no rounding mode is specified, take the rounding mode from
the context. This method raises the Rounded and Inexact flags
when appropriate.
See also: to_integral_value, which does exactly the same as
this method except that it doesn't raise Inexact or Rounded.
"""
if self._is_special:
ans = self._check_nans(context=context)
if ans:
return ans
return Decimal(self)
if self._exp >= 0:
return Decimal(self)
if not self:
return _dec_from_triple(self._sign, '0', 0)
if context is None:
context = getcontext()
if rounding is None:
rounding = context.rounding
ans = self._rescale(0, rounding)
if ans != self:
context._raise_error(Inexact)
context._raise_error(Rounded)
return ans
def to_integral_value(self, rounding=None, context=None):
"""Rounds to a nearby integer, without raising Inexact or Rounded."""
if context is None:
context = getcontext()
if rounding is None:
rounding = context.rounding
if self._is_special:
ans = self._check_nans(context=context)
if ans:
return ans
return Decimal(self)
if self._exp >= 0:
return Decimal(self)
else:
return self._rescale(0, rounding)
# the method name changed, but we also provide the old one for compatibility
to_integral = to_integral_value
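# Illustrative sketch, not part of the module: the only observable difference
# between to_integral_value() and to_integral_exact() is the flags they set.
# The local context below (ROUND_HALF_EVEN, Inexact and Rounded untrapped) is
# an assumption made so the flags can be inspected.

```python
from decimal import Decimal, Inexact, Rounded, ROUND_HALF_EVEN, localcontext

with localcontext() as ctx:
    ctx.rounding = ROUND_HALF_EVEN
    ctx.traps[Inexact] = False
    ctx.traps[Rounded] = False
    ctx.clear_flags()

    quiet = Decimal("2.5").to_integral_value()
    quiet_flags = (bool(ctx.flags[Inexact]), bool(ctx.flags[Rounded]))

    loud = Decimal("2.5").to_integral_exact()
    loud_flags = (bool(ctx.flags[Inexact]), bool(ctx.flags[Rounded]))
```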
def sqrt(self, context=None):
"""Return the square root of self."""
if context is None:
context = getcontext()
if self._is_special:
ans = self._check_nans(context=context)
if ans:
return ans
if self._isinfinity() and self._sign == 0:
return Decimal(self)
if not self:
# exponent = self._exp // 2. sqrt(-0) = -0
ans = _dec_from_triple(self._sign, '0', self._exp // 2)
return ans._fix(context)
if self._sign == 1:
return context._raise_error(InvalidOperation, 'sqrt(-x), x > 0')
# At this point self represents a positive number. Let p be
# the desired precision and express self in the form c*100**e
# with c a positive real number and e an integer, c and e
# being chosen so that 100**(p-1) <= c < 100**p. Then the
# (exact) square root of self is sqrt(c)*10**e, and 10**(p-1)
# <= sqrt(c) < 10**p, so the closest representable Decimal at
# precision p is n*10**e where n = round_half_even(sqrt(c)),
# the closest integer to sqrt(c) with the even integer chosen
# in the case of a tie.
#
# To ensure correct rounding in all cases, we use the
# following trick: we compute the square root to an extra
# place (precision p+1 instead of precision p), rounding down.
# Then, if the result is inexact and its last digit is 0 or 5,
# we increase the last digit to 1 or 6 respectively; if it's
# exact we leave the last digit alone. Now the final round to
# p places (or fewer in the case of underflow) will round
# correctly and raise the appropriate flags.
# use an extra digit of precision
prec = context.prec+1
# write argument in the form c*100**e where e = self._exp//2
# is the 'ideal' exponent, to be used if the square root is
# exactly representable. l is the number of 'digits' of c in
# base 100, so that 100**(l-1) <= c < 100**l.
op = _WorkRep(self)
e = op.exp >> 1
if op.exp & 1:
c = op.int * 10
l = (len(self._int) >> 1) + 1
else:
c = op.int
l = len(self._int)+1 >> 1
# rescale so that c has exactly prec base 100 'digits'
shift = prec-l
if shift >= 0:
c *= 100**shift
exact = True
else:
c, remainder = divmod(c, 100**-shift)
exact = not remainder
e -= shift
# find n = floor(sqrt(c)) using Newton's method
n = 10**prec
while True:
q = c//n
if n <= q:
break
else:
n = n + q >> 1
exact = exact and n*n == c
if exact:
# result is exact; rescale to use ideal exponent e
if shift >= 0:
# assert n % 10**shift == 0
n //= 10**shift
else:
n *= 10**-shift
e += shift
else:
# result is not exact; fix last digit as described above
if n % 5 == 0:
n += 1
ans = _dec_from_triple(0, str(n), e)
# round, and fit to current context
context = context._shallow_copy()
rounding = context._set_rounding(ROUND_HALF_EVEN)
ans = ans._fix(context)
context.rounding = rounding
return ans
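# Illustrative sketch, not part of the module: the integer Newton iteration
# that sqrt() uses to find floor(sqrt(c)), followed by the public method at an
# assumed precision of 28. The name floor_sqrt is hypothetical.

```python
from decimal import Decimal, localcontext


def floor_sqrt(c, prec):
    """Return floor(sqrt(c)) for 1 <= c < 100**prec using integers only."""
    n = 10 ** prec  # starting guess, strictly above sqrt(c)
    while True:
        q = c // n
        if n <= q:
            break
        n = (n + q) >> 1  # average of the guess and c // guess
    return n


root = floor_sqrt(2 * 10 ** 20, 11)

with localcontext() as ctx:
    ctx.prec = 28
    irrational = Decimal(2).sqrt()   # inexact, correctly rounded
    exact = Decimal("0.25").sqrt()   # exact, uses the ideal exponent
```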
def max(self, other, context=None):
"""Returns the larger value.
Like max(self, other) except if one is not a number, returns
NaN (and signals if one is sNaN). Also rounds.
"""
other = _convert_other(other, raiseit=True)
if context is None:
context = getcontext()
if self._is_special or other._is_special:
# If one operand is a quiet NaN and the other is a number, then the
# number is always returned
sn = self._isnan()
on = other._isnan()
if sn or on:
if on == 1 and sn == 0:
return self._fix(context)
if sn == 1 and on == 0:
return other._fix(context)
return self._check_nans(other, context)
c = self._cmp(other)
if c == 0:
# If both operands are finite and equal in numerical value
# then an ordering is applied:
#
# If the signs differ then max returns the operand with the
# positive sign and min returns the operand with the negative sign
#
# If the signs are the same then the exponent is used to select
# the result. This is exactly the ordering used in compare_total.
c = self.compare_total(other)
if c == -1:
ans = other
else:
ans = self
return ans._fix(context)
def min(self, other, context=None):
"""Returns the smaller value.
Like min(self, other) except if one is not a number, returns
NaN (and signals if one is sNaN). Also rounds.
"""
other = _convert_other(other, raiseit=True)
if context is None:
context = getcontext()
if self._is_special or other._is_special:
# If one operand is a quiet NaN and the other is a number, then the
# number is always returned
sn = self._isnan()
on = other._isnan()
if sn or on:
if on == 1 and sn == 0:
return self._fix(context)
if sn == 1 and on == 0:
return other._fix(context)
return self._check_nans(other, context)
c = self._cmp(other)
if c == 0:
c = self.compare_total(other)
if c == -1:
ans = self
else:
ans = other
return ans._fix(context)
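# Illustrative sketch, not part of the module: how max() and min() break ties
# between numerically equal operands and how they treat a quiet NaN, assuming
# the default context.

```python
from decimal import Decimal

# equal values, same sign: the exponent decides, as in compare_total
larger = Decimal("1.0").max(Decimal("1.00"))
smaller = Decimal("1.0").min(Decimal("1.00"))

# equal values, different signs: max prefers the positive zero
signed_zero = Decimal("-0").max(Decimal("0"))

# one quiet NaN and one number: the number is returned
with_nan = Decimal("NaN").max(Decimal("3"))
```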
def _isinteger(self):
"""Returns whether self is an integer"""
if self._is_special:
return False
if self._exp >= 0:
return True
rest = self._int[self._exp:]
return rest == '0'*len(rest)
def _iseven(self):
"""Returns True if self is even. Assumes self is an integer."""
if not self or self._exp > 0:
return True
return self._int[-1+self._exp] in '02468'
def adjusted(self):
"""Return the adjusted exponent of self"""
try:
return self._exp + len(self._int) - 1
# If NaN or Infinity, self._exp is string
except TypeError:
return 0
def canonical(self, context=None):
"""Returns the same Decimal object.
As we do not have different encodings for the same number, the
received object already is in its canonical form.
"""
return self
def compare_signal(self, other, context=None):
"""Compares self to the other operand numerically.
It's pretty much like compare(), but all NaNs signal, with signaling
NaNs taking precedence over quiet NaNs.
"""
other = _convert_other(other, raiseit = True)
ans = self._compare_check_nans(other, context)
if ans:
return ans
return self.compare(other, context=context)
def compare_total(self, other):
"""Compares self to other using the abstract representations.
This differs from the standard compare, which uses the operands'
numerical values. Note that a total ordering is defined for all possible
abstract representations.
"""
other = _convert_other(other, raiseit=True)
# if one is negative and the other is positive, it's easy
if self._sign and not other._sign:
return _NegativeOne
if not self._sign and other._sign:
return _One
sign = self._sign
# let's handle both NaN types
self_nan = self._isnan()
other_nan = other._isnan()
if self_nan or other_nan:
if self_nan == other_nan:
# compare payloads as though they're integers
self_key = len(self._int), self._int
other_key = len(other._int), other._int
if self_key < other_key:
if sign:
return _One
else:
return _NegativeOne
if self_key > other_key:
if sign:
return _NegativeOne
else:
return _One
return _Zero
if sign:
if self_nan == 1:
return _NegativeOne
if other_nan == 1:
return _One
if self_nan == 2:
return _NegativeOne
if other_nan == 2:
return _One
else:
if self_nan == 1:
return _One
if other_nan == 1:
return _NegativeOne
if self_nan == 2:
return _One
if other_nan == 2:
return _NegativeOne
if self < other:
return _NegativeOne
if self > other:
return _One
if self._exp < other._exp:
if sign:
return _One
else:
return _NegativeOne
if self._exp > other._exp:
if sign:
return _NegativeOne
else:
return _One
return _Zero
def compare_total_mag(self, other):
"""Compares self to other using abstract repr., ignoring sign.
Like compare_total, but with operand's sign ignored and assumed to be 0.
"""
other = _convert_other(other, raiseit=True)
s = self.copy_abs()
o = other.copy_abs()
return s.compare_total(o)
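# Illustrative sketch, not part of the module: compare_total() orders by
# representation as well as value, and compare_total_mag() does the same after
# dropping the signs.

```python
from decimal import Decimal

# same value, positive sign: the smaller exponent sorts first
positive_pair = Decimal("12.0").compare_total(Decimal("12"))

# same value, negative sign: the order is reversed
negative_pair = Decimal("-12.0").compare_total(Decimal("-12"))

# a NaN sorts above every number, including Infinity
infinity_vs_nan = Decimal("Infinity").compare_total(Decimal("NaN"))

# signs ignored: 5 is compared with 3
by_magnitude = Decimal("-5").compare_total_mag(Decimal("3"))
```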
def copy_abs(self):
"""Returns a copy with the sign set to 0. """
return _dec_from_triple(0, self._int, self._exp, self._is_special)
def copy_negate(self):
"""Returns a copy with the sign inverted."""
if self._sign:
return _dec_from_triple(0, self._int, self._exp, self._is_special)
else:
return _dec_from_triple(1, self._int, self._exp, self._is_special)
def copy_sign(self, other):
"""Returns self with the sign of other."""
other = _convert_other(other, raiseit=True)
return _dec_from_triple(other._sign, self._int,
self._exp, self._is_special)
def exp(self, context=None):
"""Returns e ** self."""
if context is None:
context = getcontext()
# exp(NaN) = NaN
ans = self._check_nans(context=context)
if ans:
return ans
# exp(-Infinity) = 0
if self._isinfinity() == -1:
return _Zero
# exp(0) = 1
if not self:
return _One
# exp(Infinity) = Infinity
if self._isinfinity() == 1:
return Decimal(self)
# the result is now guaranteed to be inexact (the true
# mathematical result is transcendental). There's no need to
# raise Rounded and Inexact here---they'll always be raised as
# a result of the call to _fix.
p = context.prec
adj = self.adjusted()
# we only need to do any computation for quite a small range
# of adjusted exponents---for example, -29 <= adj <= 10 for
# the default context. For smaller exponent the result is
# indistinguishable from 1 at the given precision, while for
# larger exponent the result either overflows or underflows.
if self._sign == 0 and adj > len(str((context.Emax+1)*3)):
# overflow
ans = _dec_from_triple(0, '1', context.Emax+1)
elif self._sign == 1 and adj > len(str((-context.Etiny()+1)*3)):
# underflow to 0
ans = _dec_from_triple(0, '1', context.Etiny()-1)
elif self._sign == 0 and adj < -p:
# p+1 digits; final round will raise correct flags
ans = _dec_from_triple(0, '1' + '0'*(p-1) + '1', -p)
elif self._sign == 1 and adj < -p-1:
# p+1 digits; final round will raise correct flags
ans = _dec_from_triple(0, '9'*(p+1), -p-1)
# general case
else:
op = _WorkRep(self)
c, e = op.int, op.exp
if op.sign == 1:
c = -c
# compute correctly rounded result: increase precision by
# 3 digits at a time until we get an unambiguously
# roundable result
extra = 3
while True:
coeff, exp = _dexp(c, e, p+extra)
if coeff % (5*10**(len(str(coeff))-p-1)):
break
extra += 3
ans = _dec_from_triple(0, str(coeff), exp)
# at this stage, ans should round correctly with *any*
# rounding mode, not just with ROUND_HALF_EVEN
context = context._shallow_copy()
rounding = context._set_rounding(ROUND_HALF_EVEN)
ans = ans._fix(context)
context.rounding = rounding
return ans
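# Illustrative sketch, not part of the module: the special cases and the
# general case of exp() through the public API, at an assumed precision of 28.

```python
from decimal import Decimal, localcontext

with localcontext() as ctx:
    ctx.prec = 28
    e = Decimal(1).exp()                    # general case, correctly rounded
    one = Decimal(0).exp()                  # exp(0) = 1
    zero = Decimal("-Infinity").exp()       # exp(-Infinity) = 0
    infinity = Decimal("Infinity").exp()    # exp(Infinity) = Infinity
```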
def is_canonical(self):
"""Return True if self is canonical; otherwise return False.
Currently, the encoding of a Decimal instance is always
canonical, so this method returns True for any Decimal.
"""
return True
def is_finite(self):
"""Return True if self is finite; otherwise return False.
A Decimal instance is considered finite if it is neither
infinite nor a NaN.
"""
return not self._is_special
def is_infinite(self):
"""Return True if self is infinite; otherwise return False."""
return self._exp == 'F'
def is_nan(self):
"""Return True if self is a qNaN or sNaN; otherwise return False."""
return self._exp in ('n', 'N')
def is_normal(self, context=None):
"""Return True if self is a normal number; otherwise return False."""
if self._is_special or not self:
return False
if context is None:
context = getcontext()
return context.Emin <= self.adjusted()
def is_qnan(self):
"""Return True if self is a quiet NaN; otherwise return False."""
return self._exp == 'n'
def is_signed(self):
"""Return True if self is negative; otherwise return False."""
return self._sign == 1
def is_snan(self):
"""Return True if self is a signaling NaN; otherwise return False."""
return self._exp == 'N'
def is_subnormal(self, context=None):
"""Return True if self is subnormal; otherwise return False."""
if self._is_special or not self:
return False
if context is None:
context = getcontext()
return self.adjusted() < context.Emin
def is_zero(self):
"""Return True if self is a zero; otherwise return False."""
return not self._is_special and self._int == '0'
def _ln_exp_bound(self):
"""Compute a lower bound for the adjusted exponent of self.ln().
In other words, compute r such that abs(self.ln()) >= 10**r. Assumes
that self is finite and positive and that self != 1.
"""
# for 0.1 <= x <= 10 we use the inequalities 1-1/x <= ln(x) <= x-1
adj = self._exp + len(self._int) - 1
if adj >= 1:
# argument >= 10; we use 23/10 = 2.3 as a lower bound for ln(10)
return len(str(adj*23//10)) - 1
if adj <= -2:
# argument <= 0.1
return len(str((-1-adj)*23//10)) - 1
op = _WorkRep(self)
c, e = op.int, op.exp
if adj == 0:
# 1 < self < 10
num = str(c-10**-e)
den = str(c)
return len(num) - len(den) - (num < den)
# adj == -1, 0.1 <= self < 1
return e + len(str(10**-e - c)) - 1
def ln(self, context=None):
"""Returns the natural (base e) logarithm of self."""
if context is None:
context = getcontext()
# ln(NaN) = NaN
ans = self._check_nans(context=context)
if ans:
return ans
# ln(0.0) == -Infinity
if not self:
return _NegativeInfinity
# ln(Infinity) = Infinity
if self._isinfinity() == 1:
return _Infinity
# ln(1.0) == 0.0
if self == _One:
return _Zero
# ln(negative) raises InvalidOperation
if self._sign == 1:
return context._raise_error(InvalidOperation,
'ln of a negative value')
# result is irrational, so necessarily inexact
op = _WorkRep(self)
c, e = op.int, op.exp
p = context.prec
# correctly rounded result: repeatedly increase precision by 3
# until we get an unambiguously roundable result
places = p - self._ln_exp_bound() + 2 # at least p+3 places
while True:
coeff = _dlog(c, e, places)
# assert len(str(abs(coeff)))-p >= 1
if coeff % (5*10**(len(str(abs(coeff)))-p-1)):
break
places += 3
ans = _dec_from_triple(int(coeff<0), str(abs(coeff)), -places)
context = context._shallow_copy()
rounding = context._set_rounding(ROUND_HALF_EVEN)
ans = ans._fix(context)
context.rounding = rounding
return ans
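# Illustrative sketch, not part of the module: ln() for ordinary and special
# operands. The precision of 9 and the untrapped InvalidOperation are
# assumptions made so that every case returns a value.

```python
from decimal import Decimal, InvalidOperation, localcontext

with localcontext() as ctx:
    ctx.prec = 9
    ctx.traps[InvalidOperation] = False

    ln_ten = Decimal(10).ln()        # irrational, rounded to 9 digits
    ln_one = Decimal(1).ln()         # exactly 0
    ln_zero = Decimal(0).ln()        # -Infinity
    ln_negative = Decimal(-1).ln()   # InvalidOperation, so NaN here
```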
def _log10_exp_bound(self):
"""Compute a lower bound for the adjusted exponent of self.log10().
In other words, find r such that abs(self.log10()) >= 10**r.
Assumes that self is finite and positive and that self != 1.
"""
# For x >= 10 or x < 0.1 we only need a bound on the integer
# part of log10(self), and this comes directly from the
# exponent of x. For 0.1 <= x <= 10 we use the inequalities
# 1-1/x <= log(x) <= x-1. If x > 1 we have |log10(x)| >
# (1-1/x)/2.31 > 0. If x < 1 then |log10(x)| > (1-x)/2.31 > 0
adj = self._exp + len(self._int) - 1
if adj >= 1:
# self >= 10
return len(str(adj))-1
if adj <= -2:
# self < 0.1
return len(str(-1-adj))-1
op = _WorkRep(self)
c, e = op.int, op.exp
if adj == 0:
# 1 < self < 10
num = str(c-10**-e)
den = str(231*c)
return len(num) - len(den) - (num < den) + 2
# adj == -1, 0.1 <= self < 1
num = str(10**-e-c)
return len(num) + e - (num < "231") - 1
def log10(self, context=None):
"""Returns the base 10 logarithm of self."""
if context is None:
context = getcontext()
# log10(NaN) = NaN
ans = self._check_nans(context=context)
if ans:
return ans
# log10(0.0) == -Infinity
if not self:
return _NegativeInfinity
# log10(Infinity) = Infinity
if self._isinfinity() == 1:
return _Infinity
# log10(negative or -Infinity) raises InvalidOperation
if self._sign == 1:
return context._raise_error(InvalidOperation,
'log10 of a negative value')
# log10(10**n) = n
if self._int[0] == '1' and self._int[1:] == '0'*(len(self._int) - 1):
# answer may need rounding
ans = Decimal(self._exp + len(self._int) - 1)
else:
# result is irrational, so necessarily inexact
op = _WorkRep(self)
c, e = op.int, op.exp
p = context.prec
# correctly rounded result: repeatedly increase precision
# until result is unambiguously roundable
places = p-self._log10_exp_bound()+2
while True:
coeff = _dlog10(c, e, places)
# assert len(str(abs(coeff)))-p >= 1
if coeff % (5*10**(len(str(abs(coeff)))-p-1)):
break
places += 3
ans = _dec_from_triple(int(coeff<0), str(abs(coeff)), -places)
context = context._shallow_copy()
rounding = context._set_rounding(ROUND_HALF_EVEN)
ans = ans._fix(context)
context.rounding = rounding
return ans
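# Illustrative sketch, not part of the module: log10() takes the exact
# shortcut for powers of ten and the rounded path otherwise. The precision of
# 9 is an assumption.

```python
from decimal import Decimal, localcontext

with localcontext() as ctx:
    ctx.prec = 9
    log_two = Decimal(2).log10()                # irrational, rounded
    log_thousand = Decimal(1000).log10()        # exact: 3
    log_thousandth = Decimal("0.001").log10()   # exact: -3
    log_zero = Decimal(0).log10()               # -Infinity
```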
def logb(self, context=None):
"""Returns the exponent of the magnitude of self's MSD.
The result is the integer which is the exponent of the magnitude
of the most significant digit of self (as though it were truncated
to a single digit while maintaining the value of that digit and
without limiting the resulting exponent).
"""
# logb(NaN) = NaN
ans = self._check_nans(context=context)
if ans:
return ans
if context is None:
context = getcontext()
# logb(+/-Inf) = +Inf
if self._isinfinity():
return _Infinity
# logb(0) = -Inf, DivisionByZero
if not self:
return context._raise_error(DivisionByZero, 'logb(0)', 1)
# otherwise, return the adjusted exponent of self as a Decimal,
# rounded to fit the current context by the call to _fix below.
ans = Decimal(self.adjusted())
return ans._fix(context)
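# Illustrative sketch, not part of the module: logb() returns the adjusted
# exponent, and logb(0) signals DivisionByZero. The trap is switched off here
# (an assumption) so the -Infinity result can be inspected.

```python
from decimal import Decimal, DivisionByZero, localcontext

hundreds = Decimal("250").logb()
units = Decimal("2.50").logb()
hundredths = Decimal("0.03").logb()

with localcontext() as ctx:
    ctx.traps[DivisionByZero] = False
    ctx.clear_flags()
    of_zero = Decimal(0).logb()
    division_by_zero_flagged = bool(ctx.flags[DivisionByZero])
```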
def _islogical(self):
"""Return True if self is a logical operand.
To be logical, it must be a finite number with a sign of 0,
an exponent of 0, and a coefficient whose digits must all be
either 0 or 1.
"""
if self._sign != 0 or self._exp != 0:
return False
for dig in self._int:
if dig not in '01':
return False
return True
def _fill_logical(self, context, opa, opb):
dif = context.prec - len(opa)
if dif > 0:
opa = '0'*dif + opa
elif dif < 0:
opa = opa[-context.prec:]
dif = context.prec - len(opb)
if dif > 0:
opb = '0'*dif + opb
elif dif < 0:
opb = opb[-context.prec:]
return opa, opb
def logical_and(self, other, context=None):
"""Applies an 'and' operation between self and other's digits."""
if context is None:
context = getcontext()
other = _convert_other(other, raiseit=True)
if not self._islogical() or not other._islogical():
return context._raise_error(InvalidOperation)
# fill to context.prec
(opa, opb) = self._fill_logical(context, self._int, other._int)
# make the operation, and clean starting zeroes
result = "".join([str(int(a)&int(b)) for a,b in zip(opa,opb)])
return _dec_from_triple(0, result.lstrip('0') or '0', 0)
def logical_invert(self, context=None):
"""Invert all its digits."""
if context is None:
context = getcontext()
return self.logical_xor(_dec_from_triple(0,'1'*context.prec,0),
context)
def logical_or(self, other, context=None):
"""Applies an 'or' operation between self and other's digits."""
if context is None:
context = getcontext()
other = _convert_other(other, raiseit=True)
if not self._islogical() or not other._islogical():
return context._raise_error(InvalidOperation)
# fill to context.prec
(opa, opb) = self._fill_logical(context, self._int, other._int)
# make the operation, and clean starting zeroes
result = "".join([str(int(a)|int(b)) for a,b in zip(opa,opb)])
return _dec_from_triple(0, result.lstrip('0') or '0', 0)
def logical_xor(self, other, context=None):
"""Applies an 'xor' operation between self and other's digits."""
if context is None:
context = getcontext()
other = _convert_other(other, raiseit=True)
if not self._islogical() or not other._islogical():
return context._raise_error(InvalidOperation)
# fill to context.prec
(opa, opb) = self._fill_logical(context, self._int, other._int)
# make the operation, and clean starting zeroes
result = "".join([str(int(a)^int(b)) for a,b in zip(opa,opb)])
return _dec_from_triple(0, result.lstrip('0') or '0', 0)
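The logical operations above treat each coefficient digit as a bit. A minimal sketch, assuming a 9-digit context with traps disabled (both choices are illustrative):

```python
from decimal import Context, Decimal, InvalidOperation

ctx = Context(prec=9, traps=[])   # illustrative context

a, b = Decimal('1100'), Decimal('1010')
and_result = a.logical_and(b, ctx)             # digit-wise AND
or_result = a.logical_or(b, ctx)               # digit-wise OR
xor_result = a.logical_xor(b, ctx)             # digit-wise XOR
inverted = Decimal('101').logical_invert(ctx)  # operand is padded to prec digits first

# Any digit other than 0 or 1 makes the operand non-logical.
bad = Decimal('12').logical_and(Decimal('1'), ctx)
```

Because logical_invert() pads to context.prec digits before inverting, its result depends on the precision of the context in use.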
def max_mag(self, other, context=None):
"""Returns the operand with the larger magnitude (signs are ignored in the comparison)."""
other = _convert_other(other, raiseit=True)
if context is None:
context = getcontext()
if self._is_special or other._is_special:
# If one operand is a quiet NaN and the other is number, then the
# number is always returned
sn = self._isnan()
on = other._isnan()
if sn or on:
if on == 1 and sn == 0:
return self._fix(context)
if sn == 1 and on == 0:
return other._fix(context)
return self._check_nans(other, context)
c = self.copy_abs()._cmp(other.copy_abs())
if c == 0:
c = self.compare_total(other)
if c == -1:
ans = other
else:
ans = self
return ans._fix(context)
def min_mag(self, other, context=None):
"""Returns the operand with the smaller magnitude (signs are ignored in the comparison)."""
other = _convert_other(other, raiseit=True)
if context is None:
context = getcontext()
if self._is_special or other._is_special:
# If one operand is a quiet NaN and the other is number, then the
# number is always returned
sn = self._isnan()
on = other._isnan()
if sn or on:
if on == 1 and sn == 0:
return self._fix(context)
if sn == 1 and on == 0:
return other._fix(context)
return self._check_nans(other, context)
c = self.copy_abs()._cmp(other.copy_abs())
if c == 0:
c = self.compare_total(other)
if c == -1:
ans = self
else:
ans = other
return ans._fix(context)
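A brief sketch of how max_mag() and min_mag() behave, including the tie and NaN cases handled above. It runs under whatever the current thread context is; the operand values are arbitrary.

```python
from decimal import Decimal

larger = Decimal('-7').max_mag(Decimal('3'))     # |-7| > |3|
smaller = Decimal('-7').min_mag(Decimal('3'))

# Equal magnitudes: compare_total() breaks the tie.
tie_max = Decimal('-2').max_mag(Decimal('2'))
tie_min = Decimal('-2').min_mag(Decimal('2'))

# A quiet NaN paired with a number yields the number.
with_nan = Decimal('NaN').max_mag(Decimal('5'))
```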
def next_minus(self, context=None):
"""Returns the largest representable number smaller than itself."""
if context is None:
context = getcontext()
ans = self._check_nans(context=context)
if ans:
return ans
if self._isinfinity() == -1:
return _NegativeInfinity
if self._isinfinity() == 1:
return _dec_from_triple(0, '9'*context.prec, context.Etop())
context = context.copy()
context._set_rounding(ROUND_FLOOR)
context._ignore_all_flags()
new_self = self._fix(context)
if new_self != self:
return new_self
return self.__sub__(_dec_from_triple(0, '1', context.Etiny()-1),
context)
def next_plus(self, context=None):
"""Returns the smallest representable number larger than itself."""
if context is None:
context = getcontext()
ans = self._check_nans(context=context)
if ans:
return ans
if self._isinfinity() == 1:
return _Infinity
if self._isinfinity() == -1:
return _dec_from_triple(1, '9'*context.prec, context.Etop())
context = context.copy()
context._set_rounding(ROUND_CEILING)
context._ignore_all_flags()
new_self = self._fix(context)
if new_self != self:
return new_self
return self.__add__(_dec_from_triple(0, '1', context.Etiny()-1),
context)
def next_toward(self, other, context=None):
"""Returns the number closest to self, in the direction towards other.
The result is the closest representable number to self
(excluding self) that is in the direction towards other,
unless both have the same value. If the two operands are
numerically equal, then the result is a copy of self with the
sign set to be the same as the sign of other.
"""
other = _convert_other(other, raiseit=True)
if context is None:
context = getcontext()
ans = self._check_nans(other, context)
if ans:
return ans
comparison = self._cmp(other)
if comparison == 0:
return self.copy_sign(other)
if comparison == -1:
ans = self.next_plus(context)
else: # comparison == 1
ans = self.next_minus(context)
# decide which flags to raise using value of ans
if ans._isinfinity():
context._raise_error(Overflow,
'Infinite result from next_toward',
ans._sign)
context._raise_error(Inexact)
context._raise_error(Rounded)
elif ans.adjusted() < context.Emin:
context._raise_error(Underflow)
context._raise_error(Subnormal)
context._raise_error(Inexact)
context._raise_error(Rounded)
# if precision == 1 then we don't raise Clamped for a
# result 0E-Etiny.
if not ans:
context._raise_error(Clamped)
return ans
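The three "next" operations can be tried with a deliberately small precision, which keeps the neighbouring values readable. The 3-digit context is an assumption for illustration only.

```python
from decimal import Context, Decimal

ctx = Context(prec=3)   # illustrative: three significant digits

one = Decimal('1')
up = one.next_plus(ctx)                       # smallest value above 1
down = one.next_minus(ctx)                    # largest value below 1
toward_two = one.next_toward(Decimal('2'), ctx)
toward_zero = one.next_toward(Decimal('0'), ctx)

# Numerically equal operands: the result is self with the sign of other.
same_value = Decimal('0.0').next_toward(Decimal('-0'), ctx)
```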
def number_class(self, context=None):
"""Returns an indication of the class of self.
The class is one of the following strings:
sNaN
NaN
-Infinity
-Normal
-Subnormal
-Zero
+Zero
+Subnormal
+Normal
+Infinity
"""
if self.is_snan():
return "sNaN"
if self.is_qnan():
return "NaN"
inf = self._isinfinity()
if inf == 1:
return "+Infinity"
if inf == -1:
return "-Infinity"
if self.is_zero():
if self._sign:
return "-Zero"
else:
return "+Zero"
if context is None:
context = getcontext()
if self.is_subnormal(context=context):
if self._sign:
return "-Subnormal"
else:
return "+Subnormal"
# just a normal, regular, boring number, :)
if self._sign:
return "-Normal"
else:
return "+Normal"
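A small sketch that classifies a handful of values. The context limits (Emin=-999, Emax=999) are assumed here so that the subnormal case is easy to reach.

```python
from decimal import Context, Decimal

ctx = Context(prec=9, Emin=-999, Emax=999)   # illustrative limits

samples = ('2.50', '-0', 'Infinity', '-Infinity', 'NaN', 'sNaN', '0.1E-999')
classes = [Decimal(text).number_class(ctx) for text in samples]
```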
def radix(self):
"""Just returns 10, as this is Decimal, :)"""
return Decimal(10)
def rotate(self, other, context=None):
"""Returns a rotated copy of self, value-of-other times."""
if context is None:
context = getcontext()
other = _convert_other(other, raiseit=True)
ans = self._check_nans(other, context)
if ans:
return ans
if other._exp != 0:
return context._raise_error(InvalidOperation)
if not (-context.prec <= int(other) <= context.prec):
return context._raise_error(InvalidOperation)
if self._isinfinity():
return Decimal(self)
# get values, pad if necessary
torot = int(other)
rotdig = self._int
topad = context.prec - len(rotdig)
if topad > 0:
rotdig = '0'*topad + rotdig
elif topad < 0:
rotdig = rotdig[-topad:]
# let's rotate!
rotated = rotdig[torot:] + rotdig[:torot]
return _dec_from_triple(self._sign,
rotated.lstrip('0') or '0', self._exp)
def scaleb(self, other, context=None):
"""Returns self operand after adding the second value to its exp."""
if context is None:
context = getcontext()
other = _convert_other(other, raiseit=True)
ans = self._check_nans(other, context)
if ans:
return ans
if other._exp != 0:
return context._raise_error(InvalidOperation)
liminf = -2 * (context.Emax + context.prec)
limsup = 2 * (context.Emax + context.prec)
if not (liminf <= int(other) <= limsup):
return context._raise_error(InvalidOperation)
if self._isinfinity():
return Decimal(self)
d = _dec_from_triple(self._sign, self._int, self._exp + int(other))
d = d._fix(context)
return d
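scaleb() only adjusts the exponent and leaves the coefficient digits alone, as this sketch shows (the 9-digit context is an illustrative choice):

```python
from decimal import Context, Decimal

ctx = Context(prec=9)

down_two = Decimal('7.50').scaleb(Decimal('-2'), ctx)   # exponent -2 becomes -4
up_three = Decimal('7.50').scaleb(3, ctx)               # exponent -2 becomes +1
```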
def shift(self, other, context=None):
"""Returns a shifted copy of self, value-of-other times."""
if context is None:
context = getcontext()
other = _convert_other(other, raiseit=True)
ans = self._check_nans(other, context)
if ans:
return ans
if other._exp != 0:
return context._raise_error(InvalidOperation)
if not (-context.prec <= int(other) <= context.prec):
return context._raise_error(InvalidOperation)
if self._isinfinity():
return Decimal(self)
# get values, pad if necessary
torot = int(other)
rotdig = self._int
topad = context.prec - len(rotdig)
if topad > 0:
rotdig = '0'*topad + rotdig
elif topad < 0:
rotdig = rotdig[-topad:]
# let's shift!
if torot < 0:
shifted = rotdig[:torot]
else:
shifted = rotdig + '0'*torot
shifted = shifted[-context.prec:]
return _dec_from_triple(self._sign,
shifted.lstrip('0') or '0', self._exp)
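rotate() and shift() both work on a window of context.prec digits; rotate() wraps digits around while shift() drops them and fills with zeros. A sketch with an assumed 9-digit context:

```python
from decimal import Context, Decimal

ctx = Context(prec=9)   # the digit window is 9 wide
d = Decimal('123456789')

rot_left = d.rotate(2, ctx)      # positive count rotates left
rot_right = d.rotate(-2, ctx)    # negative count rotates right
shift_left = d.shift(2, ctx)     # digits pushed out on the left are lost
shift_right = d.shift(-2, ctx)   # digits pushed out on the right are lost

# Shorter coefficients are zero-padded to prec digits before rotating.
padded = Decimal('34').rotate(8, ctx)
```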
# Support for pickling, copy, and deepcopy
def __reduce__(self):
return (self.__class__, (str(self),))
def __copy__(self):
if type(self) is Decimal:
return self # I'm immutable; therefore I am my own clone
return self.__class__(str(self))
def __deepcopy__(self, memo):
if type(self) is Decimal:
return self # My components are also immutable
return self.__class__(str(self))
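Since a Decimal is immutable, copying an exact Decimal instance can hand back the same object, and pickling round-trips through the string form. A minimal sketch using the standard copy and pickle modules:

```python
import copy
import pickle
from decimal import Decimal

d = Decimal('3.1400')

shallow = copy.copy(d)        # same object for an exact Decimal
deep = copy.deepcopy(d)       # likewise
restored = pickle.loads(pickle.dumps(d))   # rebuilt from str(d)
```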
# PEP 3101 support. the _localeconv keyword argument should be
# considered private: it's provided for ease of testing only.
def __format__(self, specifier, context=None, _localeconv=None):
"""Format a Decimal instance according to the given specifier.
The specifier should be a standard format specifier, with the
form described in PEP 3101. Formatting types 'e', 'E', 'f',
'F', 'g', 'G', 'n' and '%' are supported. If the formatting
type is omitted it defaults to 'g' or 'G', depending on the
value of context.capitals.
"""
# Note: PEP 3101 says that if the type is not present then
# there should be at least one digit after the decimal point.
# We take the liberty of ignoring this requirement for
# Decimal---it's presumably there to make sure that
# format(float, '') behaves similarly to str(float).
if context is None:
context = getcontext()
spec = _parse_format_specifier(specifier, _localeconv=_localeconv)
# special values don't care about the type or precision
if self._is_special:
sign = _format_sign(self._sign, spec)
body = str(self.copy_abs())
if spec['type'] == '%':
body += '%'
return _format_align(sign, body, spec)
# a type of None defaults to 'g' or 'G', depending on context
if spec['type'] is None:
spec['type'] = ['g', 'G'][context.capitals]
# if type is '%', adjust exponent of self accordingly
if spec['type'] == '%':
self = _dec_from_triple(self._sign, self._int, self._exp+2)
# round if necessary, taking rounding mode from the context
rounding = context.rounding
precision = spec['precision']
if precision is not None:
if spec['type'] in 'eE':
self = self._round(precision+1, rounding)
elif spec['type'] in 'fF%':
self = self._rescale(-precision, rounding)
elif spec['type'] in 'gG' and len(self._int) > precision:
self = self._round(precision, rounding)
# special case: zeros with a positive exponent can't be
# represented in fixed point; rescale them to 0e0.
if not self and self._exp > 0 and spec['type'] in 'fF%':
self = self._rescale(0, rounding)
# figure out placement of the decimal point
leftdigits = self._exp + len(self._int)
if spec['type'] in 'eE':
if not self and precision is not None:
dotplace = 1 - precision
else:
dotplace = 1
elif spec['type'] in 'fF%':
dotplace = leftdigits
elif spec['type'] in 'gG':
if self._exp <= 0 and leftdigits > -6:
dotplace = leftdigits
else:
dotplace = 1
# find digits before and after decimal point, and get exponent
if dotplace < 0:
intpart = '0'
fracpart = '0'*(-dotplace) + self._int
elif dotplace > len(self._int):
intpart = self._int + '0'*(dotplace-len(self._int))
fracpart = ''
else:
intpart = self._int[:dotplace] or '0'
fracpart = self._int[dotplace:]
exp = leftdigits-dotplace
# done with the decimal-specific stuff; hand over the rest
# of the formatting to the _format_number function
return _format_number(self._sign, intpart, fracpart, exp, spec)
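A few format specifiers in action. This sketch assumes the default rounding mode (ROUND_HALF_EVEN) in the current context, since __format__ takes its rounding from there.

```python
from decimal import Decimal

fixed = format(Decimal('1234.5678'), '.2f')        # fixed point, two places
grouped = format(Decimal('1234567.891'), ',.2f')   # thousands separator
percent = format(Decimal('0.125'), '.1%')          # exponent shifted by 2, '%' appended
aligned = format(Decimal('3.14159'), '>10.3f')     # right-aligned in 10 columns
special = format(Decimal('Infinity'), '>10')       # special values skip type/precision
```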
def _dec_from_triple(sign, coefficient, exponent, special=False):
"""Create a decimal instance directly, without any validation,
normalization (e.g. removal of leading zeros) or argument
conversion.
This function is for *internal use only*.
"""
self = object.__new__(Decimal)
self._sign = sign
self._int = coefficient
self._exp = exponent
self._is_special = special
return self
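_dec_from_triple() is internal, but the same sign/coefficient/exponent triple is visible through the public API. A sketch using as_tuple() and the tuple form of the constructor:

```python
from decimal import Decimal

# sign is 0 for positive and 1 for negative; digits is a tuple of ints.
parts = Decimal('-1.23').as_tuple()

# The public constructor accepts the same three components, with validation.
rebuilt = Decimal((0, (1, 2, 3), -2))
```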
# Register Decimal as a kind of Number (an abstract base class).
# However, do not register it as Real (because Decimals are not
# interoperable with floats).
_numbers.Number.register(Decimal)
##### Context class #######################################################
class _ContextManager(object):
"""Context manager class to support localcontext().
Sets a copy of the supplied context in __enter__() and restores
the previous decimal context in __exit__()
"""
def __init__(self, new_context):
self.new_context = new_context.copy()
def __enter__(self):
self.saved_context = getcontext()
setcontext(self.new_context)
return self.new_context
def __exit__(self, t, v, tb):
setcontext(self.saved_context)
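_ContextManager is what localcontext() returns. A minimal sketch of the resulting behaviour; the precision of 5 is arbitrary:

```python
from decimal import Decimal, getcontext, localcontext

prec_before = getcontext().prec

with localcontext() as ctx:
    ctx.prec = 5                      # changes only the temporary copy
    inside = Decimal(1) / Decimal(3)

prec_after = getcontext().prec        # the saved context is active again
```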
class Context(object):
"""Contains the context for a Decimal instance.
Contains:
prec - precision (for use in rounding, division, square roots..)
rounding - rounding type (how you round)
traps - If traps[exception] = 1, then the exception is
raised when it is caused. Otherwise, a value is
substituted in.
flags - When an exception is caused, flags[exception] is set.
(Whether or not the trap_enabler is set)
Should be reset by user of Decimal instance.
Emin - Minimum exponent
Emax - Maximum exponent
capitals - If 1, 1*10^1 is printed as 1E+1.
If 0, printed as 1e+1
_clamp - If 1, change exponents if too high (Default 0)
"""
def __init__(self, prec=None, rounding=None,
traps=None, flags=None,
Emin=None, Emax=None,
capitals=None, _clamp=0,
_ignored_flags=None):
# Set defaults; for everything except flags and _ignored_flags,
# inherit from DefaultContext.
try:
dc = DefaultContext
except NameError:
pass
self.prec = prec if prec is not None else dc.prec
self.rounding = rounding if rounding is not None else dc.rounding
self.Emin = Emin if Emin is not None else dc.Emin
self.Emax = Emax if Emax is not None else dc.Emax
self.capitals = capitals if capitals is not None else dc.capitals
self._clamp = _clamp if _clamp is not None else dc._clamp
if _ignored_flags is None:
self._ignored_flags = []
else:
self._ignored_flags = _ignored_flags
if traps is None:
self.traps = dc.traps.copy()
elif not isinstance(traps, dict):
self.traps = dict((s, int(s in traps)) for s in _signals)
else:
self.traps = traps
if flags is None:
self.flags = dict.fromkeys(_signals, 0)
elif not isinstance(flags, dict):
self.flags = dict((s, int(s in flags)) for s in _signals)
else:
self.flags = flags
def __repr__(self):
"""Show the current context."""
s = []
s.append('Context(prec=%(prec)d, rounding=%(rounding)s, '
'Emin=%(Emin)d, Emax=%(Emax)d, capitals=%(capitals)d'
% vars(self))
names = [f.__name__ for f, v in self.flags.items() if v]
s.append('flags=[' + ', '.join(names) + ']')
names = [t.__name__ for t, v in self.traps.items() if v]
s.append('traps=[' + ', '.join(names) + ']')
return ', '.join(s) + ')'
def clear_flags(self):
"""Reset all flags to zero"""
for flag in self.flags:
self.flags[flag] = 0
def _shallow_copy(self):
"""Returns a shallow copy of self."""
nc = Context(self.prec, self.rounding, self.traps,
self.flags, self.Emin, self.Emax,
self.capitals, self._clamp, self._ignored_flags)
return nc
def copy(self):
"""Returns a deep copy of self."""
nc = Context(self.prec, self.rounding, self.traps.copy(),
self.flags.copy(), self.Emin, self.Emax,
self.capitals, self._clamp, self._ignored_flags)
return nc
__copy__ = copy
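copy() gives the new context its own traps and flags dictionaries, so changes to the copy leave the original alone. A sketch with arbitrary settings:

```python
from decimal import Context, Inexact, ROUND_DOWN

base = Context(prec=5, rounding=ROUND_DOWN, traps=[Inexact])
clone = base.copy()

# Mutate the copy only.
clone.prec = 10
clone.traps[Inexact] = False
clone.flags[Inexact] = True
```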
def _raise_error(self, condition, explanation=None, *args):
"""Handles an error
If the flag is in _ignored_flags, returns the default response.
Otherwise, it sets the flag, then, if the corresponding
trap_enabler is set, it raises the exception. Otherwise, it returns
the default value after setting the flag.
"""
error = _condition_map.get(condition, condition)
if error in self._ignored_flags:
# Don't touch the flag
return error().handle(self, *args)
self.flags[error] = 1
if not self.traps[error]:
# The errors define how to handle themselves.
return condition().handle(self, *args)
# Errors should only be risked on copies of the context
# self._ignored_flags = []
raise error(explanation)
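The flag-versus-trap distinction in _raise_error() can be observed from outside with two contexts that differ only in their traps. Both contexts are illustrative:

```python
from decimal import Context, Decimal, DivisionByZero

# No trap: the flag is set and the default result is returned.
quiet = Context(prec=9, traps=[])
quiet_result = quiet.divide(Decimal(1), Decimal(0))

# Trap enabled: the same condition raises.
strict = Context(prec=9, traps=[DivisionByZero])
try:
    strict.divide(Decimal(1), Decimal(0))
    raised = False
except DivisionByZero:
    raised = True
```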
def _ignore_all_flags(self):
"""Ignore all flags, if they are raised"""
return self._ignore_flags(*_signals)
def _ignore_flags(self, *flags):
"""Ignore the flags, if they are raised"""
# Do not mutate-- This way, copies of a context leave the original
# alone.
self._ignored_flags = (self._ignored_flags + list(flags))
return list(flags)
def _regard_flags(self, *flags):
"""Stop ignoring the flags, if they are raised"""
if flags and isinstance(flags[0], (tuple,list)):
flags = flags[0]
for flag in flags:
self._ignored_flags.remove(flag)
# We inherit object.__hash__, so we must deny this explicitly
__hash__ = None
def Etiny(self):
"""Returns Etiny (= Emin - prec + 1)"""
return int(self.Emin - self.prec + 1)
def Etop(self):
"""Returns Etop (= Emax - prec + 1), the largest exponent a full-precision coefficient can have."""
return int(self.Emax - self.prec + 1)
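A quick arithmetic check of the two formulas, using assumed limits of prec=9, Emin=-999 and Emax=999:

```python
from decimal import Context

ctx = Context(prec=9, Emin=-999, Emax=999)

etiny = ctx.Etiny()   # Emin - prec + 1 = -999 - 9 + 1
etop = ctx.Etop()     # Emax - prec + 1 =  999 - 9 + 1
```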
def _set_rounding(self, type):
"""Sets the rounding type.
Sets the rounding type, and returns the current (previous)
rounding type. Often used like:
context = context.copy()
# so you don't change the calling context
# if an error occurs in the middle.
rounding = context._set_rounding(ROUND_UP)
val = self.__sub__(other, context=context)
context._set_rounding(rounding)
This will make it round up for that operation.
"""
rounding = self.rounding
self.rounding = type
return rounding
def create_decimal(self, num='0'):
"""Creates a new Decimal instance but using self as context.
This method implements the to-number operation of the
IBM Decimal specification."""
if isinstance(num, basestring) and num != num.strip():
return self._raise_error(ConversionSyntax,
"no trailing or leading whitespace is "
"permitted.")
d = Decimal(num, context=self)
if d._isnan() and len(d._int) > self.prec - self._clamp:
return self._raise_error(ConversionSyntax,
"diagnostic info too long in NaN")
return d._fix(self)
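Unlike the plain Decimal constructor, create_decimal() applies the context's precision and rejects surrounding whitespace. A sketch with an assumed 3-digit context that traps InvalidOperation:

```python
from decimal import Context, Decimal, InvalidOperation

ctx = Context(prec=3, traps=[InvalidOperation])   # illustrative context

unrounded = Decimal('3.14159')            # the constructor keeps every digit
rounded = ctx.create_decimal('3.14159')   # rounded to ctx.prec digits

try:
    ctx.create_decimal(' 3.14 ')          # leading/trailing whitespace is an error
    rejected = False
except InvalidOperation:
    rejected = True
```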
def create_decimal_from_float(self, f):
"""Creates a new Decimal instance from a float but rounding using self
as the context.
>>> context = Context(prec=5, rounding=ROUND_DOWN)
>>> context.create_decimal_from_float(3.1415926535897932)
Decimal('3.1415')
>>> context = Context(prec=5, traps=[Inexact])
>>> context.create_decimal_from_float(3.1415926535897932)
Traceback (most recent call last):
...
Inexact: None
"""
d = Decimal.from_float(f) # An exact conversion
return d._fix(self) # Apply the context rounding
# Methods
def abs(self, a):
"""Returns the absolute value of the operand.
If the operand is negative, the result is the same as using the minus
operation on the operand. Otherwise, the result is the same as using
the plus operation on the operand.
>>> ExtendedContext.abs(Decimal('2.1'))
Decimal('2.1')
>>> ExtendedContext.abs(Decimal('-100'))
Decimal('100')
>>> ExtendedContext.abs(Decimal('101.5'))
Decimal('101.5')
>>> ExtendedContext.abs(Decimal('-101.5'))
Decimal('101.5')
>>> ExtendedContext.abs(-1)
Decimal('1')
"""
a = _convert_other(a, raiseit=True)
return a.__abs__(context=self)
def add(self, a, b):
"""Return the sum of the two operands.
>>> ExtendedContext.add(Decimal('12'), Decimal('7.00'))
Decimal('19.00')
>>> ExtendedContext.add(Decimal('1E+2'), Decimal('1.01E+4'))
Decimal('1.02E+4')
>>> ExtendedContext.add(1, Decimal(2))
Decimal('3')
>>> ExtendedContext.add(Decimal(8), 5)
Decimal('13')
>>> ExtendedContext.add(5, 5)
Decimal('10')
"""
a = _convert_other(a, raiseit=True)
r = a.__add__(b, context=self)
if r is NotImplemented:
raise TypeError("Unable to convert %s to Decimal" % b)
else:
return r
def _apply(self, a):
return str(a._fix(self))
def canonical(self, a):
"""Returns the same Decimal object.
As we do not have different encodings for the same number, the
received object already is in its canonical form.
>>> ExtendedContext.canonical(Decimal('2.50'))
Decimal('2.50')
"""
return a.canonical(context=self)
def compare(self, a, b):
"""Compares values numerically.
If the signs of the operands differ, a value representing each operand
('-1' if the operand is less than zero, '0' if the operand is zero or
negative zero, or '1' if the operand is greater than zero) is used in
place of that operand for the comparison instead of the actual
operand.
The comparison is then effected by subtracting the second operand from
the first and then returning a value according to the result of the
subtraction: '-1' if the result is less than zero, '0' if the result is
zero or negative zero, or '1' if the result is greater than zero.
>>> ExtendedContext.compare(Decimal('2.1'), Decimal('3'))
Decimal('-1')
>>> ExtendedContext.compare(Decimal('2.1'), Decimal('2.1'))
Decimal('0')
>>> ExtendedContext.compare(Decimal('2.1'), Decimal('2.10'))
Decimal('0')
>>> ExtendedContext.compare(Decimal('3'), Decimal('2.1'))
Decimal('1')
>>> ExtendedContext.compare(Decimal('2.1'), Decimal('-3'))
Decimal('1')
>>> ExtendedContext.compare(Decimal('-3'), Decimal('2.1'))
Decimal('-1')
>>> ExtendedContext.compare(1, 2)
Decimal('-1')
>>> ExtendedContext.compare(Decimal(1), 2)
Decimal('-1')
>>> ExtendedContext.compare(1, Decimal(2))
Decimal('-1')
"""
a = _convert_other(a, raiseit=True)
return a.compare(b, context=self)
def compare_signal(self, a, b):
"""Compares the values of the two operands numerically.
It's pretty much like compare(), but all NaNs signal, with signaling
NaNs taking precedence over quiet NaNs.
>>> c = ExtendedContext
>>> c.compare_signal(Decimal('2.1'), Decimal('3'))
Decimal('-1')
>>> c.compare_signal(Decimal('2.1'), Decimal('2.1'))
Decimal('0')
>>> c.flags[InvalidOperation] = 0
>>> print c.flags[InvalidOperation]
0
>>> c.compare_signal(Decimal('NaN'), Decimal('2.1'))
Decimal('NaN')
>>> print c.flags[InvalidOperation]
1
>>> c.flags[InvalidOperation] = 0
>>> print c.flags[InvalidOperation]
0
>>> c.compare_signal(Decimal('sNaN'), Decimal('2.1'))
Decimal('NaN')
>>> print c.flags[InvalidOperation]
1
>>> c.compare_signal(-1, 2)
Decimal('-1')
>>> c.compare_signal(Decimal(-1), 2)
Decimal('-1')
>>> c.compare_signal(-1, Decimal(2))
Decimal('-1')
"""
a = _convert_other(a, raiseit=True)
return a.compare_signal(b, context=self)
def compare_total(self, a, b):
"""Compares two operands using their abstract representation.
This is not like the standard compare, which uses the operands' numerical
values. Note that a total ordering is defined for all possible abstract
representations.
>>> ExtendedContext.compare_total(Decimal('12.73'), Decimal('127.9'))
Decimal('-1')
>>> ExtendedContext.compare_total(Decimal('-127'), Decimal('12'))
Decimal('-1')
>>> ExtendedContext.compare_total(Decimal('12.30'), Decimal('12.3'))
Decimal('-1')
>>> ExtendedContext.compare_total(Decimal('12.30'), Decimal('12.30'))
Decimal('0')
>>> ExtendedContext.compare_total(Decimal('12.3'), Decimal('12.300'))
Decimal('1')
>>> ExtendedContext.compare_total(Decimal('12.3'), Decimal('NaN'))
Decimal('-1')
>>> ExtendedContext.compare_total(1, 2)
Decimal('-1')
>>> ExtendedContext.compare_total(Decimal(1), 2)
Decimal('-1')
>>> ExtendedContext.compare_total(1, Decimal(2))
Decimal('-1')
"""
a = _convert_other(a, raiseit=True)
return a.compare_total(b)
def compare_total_mag(self, a, b):
"""Compares two operands using their abstract representation ignoring sign.
Like compare_total, but with operand's sign ignored and assumed to be 0.
"""
a = _convert_other(a, raiseit=True)
return a.compare_total_mag(b)
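compare_total_mag() has no doctest of its own above. The following sketch follows the pattern of the compare_total() examples, with the signs discarded; the operand values are arbitrary.

```python
from decimal import Decimal, ExtendedContext

# Magnitudes differ: 5 > 3 once the sign is ignored.
by_value = ExtendedContext.compare_total_mag(Decimal('-5'), Decimal('3'))

# Same magnitude, different representation: 12.30 orders below 12.3.
by_exponent = ExtendedContext.compare_total_mag(Decimal('-12.30'), Decimal('12.3'))

# Same magnitude and same representation.
identical = ExtendedContext.compare_total_mag(Decimal('-12.30'), Decimal('12.30'))
```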
def copy_abs(self, a):
"""Returns a copy of the operand with the sign set to 0.
>>> ExtendedContext.copy_abs(Decimal('2.1'))
Decimal('2.1')
>>> ExtendedContext.copy_abs(Decimal('-100'))
Decimal('100')
>>> ExtendedContext.copy_abs(-1)
Decimal('1')
"""
a = _convert_other(a, raiseit=True)
return a.copy_abs()
def copy_decimal(self, a):
"""Returns a copy of the decimal object.
>>> ExtendedContext.copy_decimal(Decimal('2.1'))
Decimal('2.1')
>>> ExtendedContext.copy_decimal(Decimal('-1.00'))
Decimal('-1.00')
>>> ExtendedContext.copy_decimal(1)
Decimal('1')
"""
a = _convert_other(a, raiseit=True)
return Decimal(a)
def copy_negate(self, a):
"""Returns a copy of the operand with the sign inverted.
>>> ExtendedContext.copy_negate(Decimal('101.5'))
Decimal('-101.5')
>>> ExtendedContext.copy_negate(Decimal('-101.5'))
Decimal('101.5')
>>> ExtendedContext.copy_negate(1)
Decimal('-1')
"""
a = _convert_other(a, raiseit=True)
return a.copy_negate()
def copy_sign(self, a, b):
"""Copies the second operand's sign to the first one.
In detail, it returns a copy of the first operand with the sign
equal to the sign of the second operand.
>>> ExtendedContext.copy_sign(Decimal( '1.50'), Decimal('7.33'))
Decimal('1.50')
>>> ExtendedContext.copy_sign(Decimal('-1.50'), Decimal('7.33'))
Decimal('1.50')
>>> ExtendedContext.copy_sign(Decimal( '1.50'), Decimal('-7.33'))
Decimal('-1.50')
>>> ExtendedContext.copy_sign(Decimal('-1.50'), Decimal('-7.33'))
Decimal('-1.50')
>>> ExtendedContext.copy_sign(1, -2)
Decimal('-1')
>>> ExtendedContext.copy_sign(Decimal(1), -2)
Decimal('-1')
>>> ExtendedContext.copy_sign(1, Decimal(-2))
Decimal('-1')
"""
a = _convert_other(a, raiseit=True)
return a.copy_sign(b)
def divide(self, a, b):
"""Decimal division in a specified context.
>>> ExtendedContext.divide(Decimal('1'), Decimal('3'))
Decimal('0.333333333')
>>> ExtendedContext.divide(Decimal('2'), Decimal('3'))
Decimal('0.666666667')
>>> ExtendedContext.divide(Decimal('5'), Decimal('2'))
Decimal('2.5')
>>> ExtendedContext.divide(Decimal('1'), Decimal('10'))
Decimal('0.1')
>>> ExtendedContext.divide(Decimal('12'), Decimal('12'))
Decimal('1')
>>> ExtendedContext.divide(Decimal('8.00'), Decimal('2'))
Decimal('4.00')
>>> ExtendedContext.divide(Decimal('2.400'), Decimal('2.0'))
Decimal('1.20')
>>> ExtendedContext.divide(Decimal('1000'), Decimal('100'))
Decimal('10')
>>> ExtendedContext.divide(Decimal('1000'), Decimal('1'))
Decimal('1000')
>>> ExtendedContext.divide(Decimal('2.40E+6'), Decimal('2'))
Decimal('1.20E+6')
>>> ExtendedContext.divide(5, 5)
Decimal('1')
>>> ExtendedContext.divide(Decimal(5), 5)
Decimal('1')
>>> ExtendedContext.divide(5, Decimal(5))
Decimal('1')
"""
a = _convert_other(a, raiseit=True)
r = a.__div__(b, context=self)
if r is NotImplemented:
raise TypeError("Unable to convert %s to Decimal" % b)
else:
return r
def divide_int(self, a, b):
"""Divides two numbers and returns the integer part of the result.
>>> ExtendedContext.divide_int(Decimal('2'), Decimal('3'))
Decimal('0')
>>> ExtendedContext.divide_int(Decimal('10'), Decimal('3'))
Decimal('3')
>>> ExtendedContext.divide_int(Decimal('1'), Decimal('0.3'))
Decimal('3')
>>> ExtendedContext.divide_int(10, 3)
Decimal('3')
>>> ExtendedContext.divide_int(Decimal(10), 3)
Decimal('3')
>>> ExtendedContext.divide_int(10, Decimal(3))
Decimal('3')
"""
a = _convert_other(a, raiseit=True)
r = a.__floordiv__(b, context=self)
if r is NotImplemented:
raise TypeError("Unable to convert %s to Decimal" % b)
else:
return r
def divmod(self, a, b):
"""Return (a // b, a % b).
>>> ExtendedContext.divmod(Decimal(8), Decimal(3))
(Decimal('2'), Decimal('2'))
>>> ExtendedContext.divmod(Decimal(8), Decimal(4))
(Decimal('2'), Decimal('0'))
>>> ExtendedContext.divmod(8, 4)
(Decimal('2'), Decimal('0'))
>>> ExtendedContext.divmod(Decimal(8), 4)
(Decimal('2'), Decimal('0'))
>>> ExtendedContext.divmod(8, Decimal(4))
(Decimal('2'), Decimal('0'))
"""
a = _convert_other(a, raiseit=True)
r = a.__divmod__(b, context=self)
if r is NotImplemented:
raise TypeError("Unable to convert %s to Decimal" % b)
else:
return r
def exp(self, a):
"""Returns e ** a.
>>> c = ExtendedContext.copy()
>>> c.Emin = -999
>>> c.Emax = 999
>>> c.exp(Decimal('-Infinity'))
Decimal('0')
>>> c.exp(Decimal('-1'))
Decimal('0.367879441')
>>> c.exp(Decimal('0'))
Decimal('1')
>>> c.exp(Decimal('1'))
Decimal('2.71828183')
>>> c.exp(Decimal('0.693147181'))
Decimal('2.00000000')
>>> c.exp(Decimal('+Infinity'))
Decimal('Infinity')
>>> c.exp(10)
Decimal('22026.4658')
"""
a = _convert_other(a, raiseit=True)
return a.exp(context=self)
def fma(self, a, b, c):
"""Returns a multiplied by b, plus c.
The first two operands are multiplied together, using multiply,
the third operand is then added to the result of that
multiplication, using add, all with only one final rounding.
>>> ExtendedContext.fma(Decimal('3'), Decimal('5'), Decimal('7'))
Decimal('22')
>>> ExtendedContext.fma(Decimal('3'), Decimal('-5'), Decimal('7'))
Decimal('-8')
>>> ExtendedContext.fma(Decimal('888565290'), Decimal('1557.96930'), Decimal('-86087.7578'))
Decimal('1.38435736E+12')
>>> ExtendedContext.fma(1, 3, 4)
Decimal('7')
>>> ExtendedContext.fma(1, Decimal(3), 4)
Decimal('7')
>>> ExtendedContext.fma(1, 3, Decimal(4))
Decimal('7')
"""
a = _convert_other(a, raiseit=True)
return a.fma(b, c, context=self)
def is_canonical(self, a):
"""Return True if the operand is canonical; otherwise return False.
Currently, the encoding of a Decimal instance is always
canonical, so this method returns True for any Decimal.
>>> ExtendedContext.is_canonical(Decimal('2.50'))
True
"""
return a.is_canonical()
def is_finite(self, a):
"""Return True if the operand is finite; otherwise return False.
A Decimal instance is considered finite if it is neither
infinite nor a NaN.
>>> ExtendedContext.is_finite(Decimal('2.50'))
True
>>> ExtendedContext.is_finite(Decimal('-0.3'))
True
>>> ExtendedContext.is_finite(Decimal('0'))
True
>>> ExtendedContext.is_finite(Decimal('Inf'))
False
>>> ExtendedContext.is_finite(Decimal('NaN'))
False
>>> ExtendedContext.is_finite(1)
True
"""
a = _convert_other(a, raiseit=True)
return a.is_finite()
def is_infinite(self, a):
"""Return True if the operand is infinite; otherwise return False.
>>> ExtendedContext.is_infinite(Decimal('2.50'))
False
>>> ExtendedContext.is_infinite(Decimal('-Inf'))
True
>>> ExtendedContext.is_infinite(Decimal('NaN'))
False
>>> ExtendedContext.is_infinite(1)
False
"""
a = _convert_other(a, raiseit=True)
return a.is_infinite()
def is_nan(self, a):
"""Return True if the operand is a qNaN or sNaN;
otherwise return False.
>>> ExtendedContext.is_nan(Decimal('2.50'))
False
>>> ExtendedContext.is_nan(Decimal('NaN'))
True
>>> Ext